From 97bbc37a7375e835f3db7d2d99f8168f7338abae Mon Sep 17 00:00:00 2001
From: Eric Larson
Date: Fri, 28 Apr 2023 11:12:13 -0400
Subject: [PATCH 1/3] MAINT: Use black

---
 .github/workflows/tests.yml                 |  1 +
 .pre-commit-config.yaml                     | 10 +++++-----
 azure-pipelines.yml                         |  2 +-
 ignore_words.txt                            |  2 ++
 mne/channels/tests/test_montage.py          | 12 ++++++------
 mne/chpi.py                                 | 11 ++++++-----
 mne/conftest.py                             |  6 +++---
 mne/coreg.py                                | 14 ++++++++++----
 mne/io/kit/tests/test_kit.py                | 19 ++++++++++++-------
 mne/io/tests/test_raw.py                    | 12 +++++++++---
 mne/io/tests/test_reference.py              |  3 ++-
 mne/source_space.py                         |  9 +++++++--
 mne/tests/test_annotations.py               |  4 ++--
 mne/tests/test_docstring_parameters.py      |  5 ++++-
 mne/tests/test_source_estimate.py           |  2 +-
 mne/tests/test_source_space.py              |  2 +-
 mne/transforms.py                           |  2 +-
 mne/utils/tests/test_check.py               |  4 +++-
 mne/viz/backends/_notebook.py               | 10 ++++++++--
 mne/viz/backends/_qt.py                     | 12 +++++++++---
 pyproject.toml                              |  4 +++-
 requirements_testing.txt                    |  1 +
 .../forward/50_background_freesurfer_mne.py |  2 +-
 23 files changed, 98 insertions(+), 51 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 3e8a3195c7a..d535e037c0a 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -23,6 +23,7 @@ jobs:
       - uses: actions/setup-python@v4
         with:
           python-version: '3.11'
+      - uses: psf/black@stable
       - uses: pre-commit/action@v3.0.0

   pytest:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7ef76755eae..59a60c19015 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,9 +1,9 @@
 repos:
-# - repo: https://github.com/psf/black
-#   rev: 23.1.0
-#   hooks:
-#   - id: black
-#     args: [--quiet]
+- repo: https://github.com/psf/black
+  rev: 23.3.0
+  hooks:
+  - id: black
+    args: [--quiet]

 # Ruff mne
 - repo: https://github.com/charliermarsh/ruff-pre-commit
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index b050cc191c1..f0665efc164 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -57,7 +57,7 @@ stages:
         displayName: Install dependencies
       - bash: |
           make pre-commit
-        displayName: make ruff
+        displayName: make pre-commit
         condition: always()
       - bash: |
          make nesting
diff --git a/ignore_words.txt b/ignore_words.txt
index 8dde5403c07..c09662e1a1a 100644
--- a/ignore_words.txt
+++ b/ignore_words.txt
@@ -14,6 +14,7 @@ nd
 cas
 thes
 ba
+bu
 ist
 od
 fo
@@ -33,3 +34,4 @@ recuse
 ro
 nam
 shs
+pres
diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py
index 5fe16a2294d..f78e6bb3f2d 100644
--- a/mne/channels/tests/test_montage.py
+++ b/mne/channels/tests/test_montage.py
@@ -223,12 +223,12 @@ def test_documented():
     pytest.param(
         partial(read_custom_montage, head_size=None, coord_frame='mri'),
-        ('// MatLab Sphere coordinates [degrees] Cartesian coordinates\n'  # noqa: E501
-         '// Label Theta Phi Radius X Y Z off sphere surface\n'  # noqa: E501
-         'E1 37.700 -14.000 1.000 0.7677 0.5934 -0.2419 -0.00000000000000011\n'  # noqa: E501
-         'E3 51.700 11.000 1.000 0.6084 0.7704 0.1908 0.00000000000000000\n'  # noqa: E501
-         'E31 90.000 -11.000 1.000 0.0000 0.9816 -0.1908 0.00000000000000000\n'  # noqa: E501
-         'E61 158.000 -17.200 1.000 -0.8857 0.3579 -0.2957 -0.00000000000000022'),  # noqa: E501
+        "// MatLab Sphere coordinates [degrees] Cartesian coordinates\n"  # noqa: E501
+        "// Label Theta Phi Radius X Y Z off sphere surface\n"  # noqa: E501
+        "E1 37.700 -14.000 1.000 0.7677 0.5934 -0.2419 -0.00000000000000011\n"  # noqa: E501
+        "E3 51.700 11.000 1.000 0.6084 0.7704 0.1908 0.00000000000000000\n"  # noqa: E501
+        "E31 90.000 -11.000 1.000 0.0000 0.9816 -0.1908 0.00000000000000000\n"  # noqa: E501
+        "E61 158.000 -17.200 1.000 -0.8857 0.3579 -0.2957 -0.00000000000000022",  # noqa: E501
         make_dig_montage(
             ch_pos={
                 'E1': [0.7677, 0.5934, -0.2419],
diff --git a/mne/chpi.py b/mne/chpi.py
index 648ad6ca78a..9d80fa6efde 100644
--- a/mne/chpi.py
+++ b/mne/chpi.py
@@ -847,11 +847,12 @@ def compute_head_pos(info, chpi_locs, dist_limit=0.005, gof_limit=0.98,
         # 1. Check number of good ones
         #
         if len(use_idx) < 3:
-            msg = (_time_prefix(fit_time) + '%s/%s good HPI fits, cannot '
-                   'determine the transformation (%s GOF)!'
-                   % (len(use_idx), n_coils,
-                      ', '.join('%0.2f' % g for g in g_coils)))
-            warn(msg)
+            gofs = ', '.join(f"{g:0.2f}" for g in g_coils)
+            warn(
+                f"{_time_prefix(fit_time)}{len(use_idx)}/{n_coils} "
+                "good HPI fits, cannot determine the transformation "
+                f"({gofs} GOF)!"
+            )
             continue

         #
diff --git a/mne/conftest.py b/mne/conftest.py
index 9a64066b852..72e95b6e788 100644
--- a/mne/conftest.py
+++ b/mne/conftest.py
@@ -101,8 +101,8 @@ def pytest_configure(config):
         first_kind = 'error'
     else:
         first_kind = 'always'
-    warning_lines = r"""
-    {0}::
+    warning_lines = f" {first_kind}::"
+    warning_lines += r"""
     # matplotlib->traitlets (notebook)
     ignore:Passing unrecognized arguments to super.*:DeprecationWarning
     # notebook tests
@@ -142,7 +142,7 @@ def pytest_configure(config):
     ignore:pkg_resources is deprecated as an API.*:DeprecationWarning
     # h5py
     ignore:`product` is deprecated as of NumPy.*:DeprecationWarning
-    """.format(first_kind)  # noqa: E501
+    """  # noqa: E501
     for warning_line in warning_lines.split('\n'):
         warning_line = warning_line.strip()
         if warning_line and not warning_line.startswith('#'):
diff --git a/mne/coreg.py b/mne/coreg.py
index db0b3645633..3e21f3ff917 100644
--- a/mne/coreg.py
+++ b/mne/coreg.py
@@ -23,11 +23,17 @@ from .io._digitization import _get_data_as_dict_from_dig
 # keep get_mni_fiducials for backward compat (no burden to keep in this
 # namespace, too)
-from ._freesurfer import (_read_mri_info, get_mni_fiducials,  # noqa: F401
-                          estimate_head_mri_t)  # noqa: F401
+from ._freesurfer import (
+    _read_mri_info,
+    get_mni_fiducials,
+    estimate_head_mri_t,  # noqa: F401
+)
 from .label import read_label, Label
-from .source_space import (add_source_space_distances, read_source_spaces,  # noqa: E501,F401
-                           write_source_spaces)
+from .source_space import (
+    add_source_space_distances,
+    read_source_spaces,  # noqa: F401
+    write_source_spaces,
+)
 from .surface import (read_surface, write_surface, _normalize_vectors,
                       complete_surface_info, decimate_surface, _DistanceQuery)
diff --git a/mne/io/kit/tests/test_kit.py b/mne/io/kit/tests/test_kit.py
index d3746012328..696d10a83da 100644
--- a/mne/io/kit/tests/test_kit.py
+++ b/mne/io/kit/tests/test_kit.py
@@ -69,8 +69,9 @@ def test_data(tmp_path):
     # check functionality
     raw_mrk = read_raw_kit(sqd_path, [mrk2_path, mrk3_path], elp_txt_path,
                            hsp_txt_path)
-    assert raw_mrk.info['description'] == \
-        'NYU 160ch System since Jan24 2009 (34) V2R004 EQ1160C'
+    assert (
+        raw_mrk.info['description'] == 'NYU 160ch System since Jan24 2009 (34) V2R004 EQ1160C'  # noqa: E501
+    )
     raw_py = _test_raw_reader(read_raw_kit, input_fname=sqd_path,
                               mrk=mrk_path, elp=elp_txt_path,
                               hsp=hsp_txt_path,
                               stim=list(range(167, 159, -1)), slope='+',
@@ -123,8 +124,9 @@ def test_data(tmp_path):
     # KIT-UMD data
     _test_raw_reader(read_raw_kit, input_fname=sqd_umd_path, test_rank='less')
     raw = read_raw_kit(sqd_umd_path)
-    assert raw.info['description'] == \
-        'University of Maryland/Kanazawa Institute of Technology/160-channel MEG System (53) V2R004 PQ1160R'  # noqa: E501
+    assert (
+        raw.info['description'] == 'University of Maryland/Kanazawa Institute of Technology/160-channel MEG System (53) V2R004 PQ1160R'  # noqa: E501
+    )
     assert_equal(raw.info['kit_system_id'], KIT.SYSTEM_UMD_2014_12)
     # check number/kind of channels
     assert_equal(len(raw.info['chs']), 193)
@@ -135,8 +137,9 @@ def test_data(tmp_path):
     # KIT Academia Sinica
     raw = read_raw_kit(sqd_as_path, slope='+')
-    assert raw.info['description'] == \
-        'Academia Sinica/Institute of Linguistics//Magnetoencephalograph System (261) V2R004 PQ1160R-N2'  # noqa: E501
+    assert (
+        raw.info['description'] == 'Academia Sinica/Institute of Linguistics//Magnetoencephalograph System (261) V2R004 PQ1160R-N2'  # noqa: E501
+    )
     assert_equal(raw.info['kit_system_id'], KIT.SYSTEM_AS_2008)
     assert_equal(raw.info['chs'][100]['ch_name'], 'MEG 101')
     assert_equal(raw.info['chs'][100]['kind'], FIFF.FIFFV_MEG_CH)
@@ -374,7 +377,9 @@ def test_berlin():
     """Test data from Berlin."""
     # gh-8535
     raw = read_raw_kit(berlin_path)
-    assert raw.info['description'] == 'Physikalisch Technische Bundesanstalt, Berlin/128-channel MEG System (124) V2R004 PQ1128R-N2'  # noqa: E501
+    assert (
+        raw.info['description'] == 'Physikalisch Technische Bundesanstalt, Berlin/128-channel MEG System (124) V2R004 PQ1128R-N2'  # noqa: E501
+    )
     assert raw.info['kit_system_id'] == 124
     assert raw.info['highpass'] == 0.
     assert raw.info['lowpass'] == 200.
diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py
index 694cd46c941..4c728df90ef 100644
--- a/mne/io/tests/test_raw.py
+++ b/mne/io/tests/test_raw.py
@@ -739,9 +739,15 @@ def test_describe_print():
     assert re.match(
         r'', s[0]) is not None, s[0]
-    assert s[1] == " ch name type unit min Q1 median Q3 max"  # noqa
-    assert s[2] == " 0 MEG 0113 GRAD fT/cm -221.80 -38.57 -9.64 19.29 414.67"  # noqa
-    assert s[-1] == "375 EOG 061 EOG µV -231.41 271.28 277.16 285.66 334.69"  # noqa
+    assert (
+        s[1] == " ch name type unit min Q1 median Q3 max"  # noqa: E501
+    )
+    assert (
+        s[2] == " 0 MEG 0113 GRAD fT/cm -221.80 -38.57 -9.64 19.29 414.67"  # noqa: E501
+    )
+    assert (
+        s[-1] == "375 EOG 061 EOG µV -231.41 271.28 277.16 285.66 334.69"  # noqa: E501
+    )


 @requires_pandas
diff --git a/mne/io/tests/test_reference.py b/mne/io/tests/test_reference.py
index 8ab37fb5879..0cfb2a5349e 100644
--- a/mne/io/tests/test_reference.py
+++ b/mne/io/tests/test_reference.py
@@ -329,7 +329,8 @@ def test_set_eeg_reference_rest():
     # load('leadfield.mat', 'G');
     # dat_ref = ft_preproc_rereference(dat, 'all', 'rest', true, G);
     # sprintf('%g ', dat_ref(:, 171));
-    want = np.array('-3.3265e-05 -3.2419e-05 -3.18758e-05 -3.24079e-05 -3.39801e-05 -3.40573e-05 -3.24163e-05 -3.26896e-05 -3.33814e-05 -3.54734e-05 -3.51289e-05 -3.53229e-05 -3.51532e-05 -3.53149e-05 -3.4505e-05 -3.03462e-05 -2.81848e-05 -3.08895e-05 -3.27158e-05 -3.4605e-05 -3.47728e-05 -3.2459e-05 -3.06552e-05 -2.53255e-05 -2.69671e-05 -2.83425e-05 -3.12836e-05 -3.30965e-05 -3.34099e-05 -3.32766e-05 -3.32256e-05 -3.36385e-05 -3.20796e-05 -2.7108e-05 -2.47054e-05 -2.49589e-05 -2.7382e-05 -3.09774e-05 -3.12003e-05 -3.1246e-05 -3.07572e-05 -2.64942e-05 -2.25505e-05 -2.67194e-05 -2.86e-05 -2.94903e-05 -2.96249e-05 -2.92653e-05 -2.86472e-05 -2.81016e-05 -2.69737e-05 -2.48076e-05 -3.00473e-05 -2.73404e-05 -2.60153e-05 -2.41608e-05 -2.61937e-05 -2.5539e-05 -2.47104e-05 -2.35194e-05'.split(' '), float)  # noqa: E501
+    data_array = "-3.3265e-05 -3.2419e-05 -3.18758e-05 -3.24079e-05 -3.39801e-05 -3.40573e-05 -3.24163e-05 -3.26896e-05 -3.33814e-05 -3.54734e-05 -3.51289e-05 -3.53229e-05 -3.51532e-05 -3.53149e-05 -3.4505e-05 -3.03462e-05 -2.81848e-05 -3.08895e-05 -3.27158e-05 -3.4605e-05 -3.47728e-05 -3.2459e-05 -3.06552e-05 -2.53255e-05 -2.69671e-05 -2.83425e-05 -3.12836e-05 -3.30965e-05 -3.34099e-05 -3.32766e-05 -3.32256e-05 -3.36385e-05 -3.20796e-05 -2.7108e-05 -2.47054e-05 -2.49589e-05 -2.7382e-05 -3.09774e-05 -3.12003e-05 -3.1246e-05 -3.07572e-05 -2.64942e-05 -2.25505e-05 -2.67194e-05 -2.86e-05 -2.94903e-05 -2.96249e-05 -2.92653e-05 -2.86472e-05 -2.81016e-05 -2.69737e-05 -2.48076e-05 -3.00473e-05 -2.73404e-05 -2.60153e-05 -2.41608e-05 -2.61937e-05 -2.5539e-05 -2.47104e-05 -2.35194e-05"  # noqa: E501
+    want = np.array(data_array.split(" "), float)
     norm = np.linalg.norm(want)
     idx = np.argmin(np.abs(evoked.times - 0.083))
     assert idx == 170
diff --git a/mne/source_space.py b/mne/source_space.py
index 8c7e8899ea1..6eb6c000537 100644
--- a/mne/source_space.py
+++ b/mne/source_space.py
@@ -31,8 +31,13 @@
                       complete_surface_info, _compute_nearest, fast_cross_3d,
                       _CheckInside)
 # keep get_mni_fiducials here just for easy backward compat
-from ._freesurfer import (_get_mri_info_data, _get_atlas_values,  # noqa: F401
-                          read_freesurfer_lut, get_mni_fiducials, _check_mri)
+from ._freesurfer import (
+    _get_mri_info_data,
+    _get_atlas_values,
+    read_freesurfer_lut,
+    get_mni_fiducials,  # noqa: F401
+    _check_mri,
+)
 from .utils import (get_subjects_dir, check_fname, logger, verbose, fill_doc,
                     _ensure_int, _get_call_line, warn, object_size, sizeof_fmt,
                     _check_fname, _path_like, _check_sphere, _import_nibabel,
diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py
index d1a311bc9ae..d1e35cebc0d 100644
--- a/mne/tests/test_annotations.py
+++ b/mne/tests/test_annotations.py
@@ -987,7 +987,7 @@ def test_io_annotation_txt(dummy_annotation_txt_file, tmp_path_factory,
     pytest.param(None, None, id='None'),
     pytest.param(42, 42.0, id='Scalar'),
     pytest.param(3.14, 3.14, id='Float'),
-    pytest.param((3, 140000), 3.14, id='Scalar touple'),
+    pytest.param((3, 140000), 3.14, id="Scalar tuple"),
     pytest.param('2002-12-03 19:01:11.720100', 1038942071.7201,
                  id='valid iso8601 string'),
     pytest.param('2002-12-03T19:01:11.720100', None,
@@ -1355,7 +1355,7 @@ def test_annotation_ch_names():
     assert raw_2.annotations.ch_names[1] == tuple(raw.ch_names[4:5])
     for ch_drop in raw_2.annotations.ch_names:
         assert all(name in raw_2.ch_names for name in ch_drop)
-    with pytest.raises(ValueError, match='channel name in annotations missin'):
+    with pytest.raises(ValueError, match='channel name in annotations miss'):
         raw_2.set_annotations(annot)
     with pytest.warns(RuntimeWarning, match='channel name in annotations mis'):
         raw_2.set_annotations(annot, on_missing='warn')
diff --git a/mne/tests/test_docstring_parameters.py b/mne/tests/test_docstring_parameters.py
index ddfec15686e..7a3f59783bc 100644
--- a/mne/tests/test_docstring_parameters.py
+++ b/mne/tests/test_docstring_parameters.py
@@ -130,7 +130,10 @@ def check_parameters_match(func, cls=None):
         msg = str(exc)
         # E   ValueError: no signature found for builtin type
         #
-        if inspect.isclass(callable_) and 'no signature found for buil' in msg:
+        if (
+            inspect.isclass(callable_) and
+            "no signature found for builtin type" in msg
+        ):
             pass
         else:
             raise
diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py
index ec6da53cf56..02e174556c1 100644
--- a/mne/tests/test_source_estimate.py
+++ b/mne/tests/test_source_estimate.py
@@ -1898,5 +1898,5 @@ def test_label_extraction_subject(kind):
     with pytest.raises(ValueError, match=r'label\.sub.*not match.* stc\.'):
         extract_label_time_course(stc, labels_fs, src)
     stc.subject = None
-    with pytest.raises(ValueError, match=r'label\.sub.*not match.* sourc'):
+    with pytest.raises(ValueError, match=r"label\.sub.*not match.* sour"):
         extract_label_time_course(stc, labels_fs, src)
diff --git a/mne/tests/test_source_space.py b/mne/tests/test_source_space.py
index 364e250284a..83ad939d7ef 100644
--- a/mne/tests/test_source_space.py
+++ b/mne/tests/test_source_space.py
@@ -479,7 +479,7 @@ def test_setup_source_space(tmp_path):
         setup_source_space('sample', spacing='7emm', add_dist=False,
                            subjects_dir=subjects_dir)
     with pytest.raises(ValueError, match='must be a string with values'):
-        setup_source_space('sample', spacing='alls',
+        setup_source_space("sample", spacing="ally",
                            add_dist=False, subjects_dir=subjects_dir)
     # ico 5 (fsaverage) - write to temp file
diff --git a/mne/transforms.py b/mne/transforms.py
index 39ab647f479..1514b2ad2d3 100644
--- a/mne/transforms.py
+++ b/mne/transforms.py
@@ -1822,7 +1822,7 @@ def apply_volume_registration(moving, static, reg_affine, sdr_morph=None,
                               moving.shape, moving_affine)
     reg_data = affine_map.transform(moving, interpolation=interpolation)
     if sdr_morph is not None:
-        logger.info('Appling SDR warp ...')
+        logger.info("Applying SDR warp ...")
         reg_data = sdr_morph.transform(
             reg_data, interpolation=interpolation,
             image_world2grid=np.linalg.inv(static_affine),
diff --git a/mne/utils/tests/test_check.py b/mne/utils/tests/test_check.py
index 44caa61ba10..5763649dd5d 100644
--- a/mne/utils/tests/test_check.py
+++ b/mne/utils/tests/test_check.py
@@ -204,7 +204,9 @@ def test_suggest():
     sug = _suggest('Left-cerebellum', names)
     assert sug == " Did you mean 'Left-Cerebellum-Cortex'?"
     sug = _suggest('Cerebellum-Cortex', names)
-    assert sug == " Did you mean one of ['Left-Cerebellum-Cortex', 'Right-Cerebellum-Cortex', 'Left-Cerebral-Cortex']?"  # noqa: E501
+    assert (
+        sug == " Did you mean one of ['Left-Cerebellum-Cortex', 'Right-Cerebellum-Cortex', 'Left-Cerebral-Cortex']?"  # noqa: E501
+    )


 def test_on_missing():
diff --git a/mne/viz/backends/_notebook.py b/mne/viz/backends/_notebook.py
index c239aa9e42c..187c02e23c9 100644
--- a/mne/viz/backends/_notebook.py
+++ b/mne/viz/backends/_notebook.py
@@ -34,8 +34,14 @@
                         _AbstractWidgetList, _AbstractAction,
                         _AbstractDialog, _AbstractKeyPress)
 from ._pyvista import _PyVistaRenderer, Plotter
-from ._pyvista import (_close_3d_figure, _check_3d_figure, _close_all,  # noqa: F401,E501 analysis:ignore
-                       _set_3d_view, _set_3d_title, _take_3d_screenshot)  # noqa: F401,E501 analysis:ignore
+from ._pyvista import (
+    _close_3d_figure,  # noqa: F401
+    _check_3d_figure,  # noqa: F401
+    _close_all,  # noqa: F401
+    _set_3d_view,  # noqa: F401
+    _set_3d_title,  # noqa: F401
+    _take_3d_screenshot,  # noqa: F401
+)
 from ._utils import _notebook_vtk_works
diff --git a/mne/viz/backends/_qt.py b/mne/viz/backends/_qt.py
index fa8b3b9b9be..d058a505c34 100644
--- a/mne/viz/backends/_qt.py
+++ b/mne/viz/backends/_qt.py
@@ -32,9 +32,15 @@
                             QSpinBox, QStyle, QStyleOptionSlider)
 from ._pyvista import _PyVistaRenderer
-from ._pyvista import (_close_3d_figure, _check_3d_figure, _close_all,  # noqa: F401,E501 analysis:ignore
-                       _set_3d_view, _set_3d_title, _take_3d_screenshot,  # noqa: F401,E501 analysis:ignore
-                       _is_mesa)  # noqa: F401,E501 analysis:ignore
+from ._pyvista import (
+    _close_3d_figure,  # noqa: F401
+    _check_3d_figure,  # noqa: F401
+    _close_all,  # noqa: F401
+    _set_3d_view,  # noqa: F401
+    _set_3d_title,  # noqa: F401
+    _take_3d_screenshot,  # noqa: F401
+    _is_mesa,  # noqa: F401
+)
 from ._abstract import (_AbstractAppWindow, _AbstractHBoxLayout,
                         _AbstractVBoxLayout, _AbstractGridLayout,
                         _AbstractWidget, _AbstractCanvas,
diff --git a/pyproject.toml b/pyproject.toml
index b8b664ab193..8e7f703106d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,5 @@
 [tool.codespell]
 ignore-words = "ignore_words.txt"
-uri-ignore-words-list = "bu"
 builtin = "clear,rare,informal,names,usage"
 skip = "doc/references.bib"
@@ -48,3 +47,6 @@ addopts = """--durations=20 --doctest-modules -ra --cov-report= --tb=short \
     --ignore=mne/report/js_and_css \
     --color=yes --capture=sys"""
 junit_family = "xunit2"
+
+[tool.black]
+exclude = "(dist/)|(build/)|(.*\\.ipynb)"
diff --git a/requirements_testing.txt b/requirements_testing.txt
index fa6c7b86b3f..aad9e7ea206 100644
--- a/requirements_testing.txt
+++ b/requirements_testing.txt
@@ -11,3 +11,4 @@ tomli; python_version<'3.11'
 twine
 wheel
 pre-commit
+black
diff --git a/tutorials/forward/50_background_freesurfer_mne.py b/tutorials/forward/50_background_freesurfer_mne.py
index a204272b57f..4d67e3e19b3 100644
--- a/tutorials/forward/50_background_freesurfer_mne.py
+++ b/tutorials/forward/50_background_freesurfer_mne.py
@@ -128,7 +128,7 @@ def imshow_mri(data, img, vox, xyz, suptitle):
     # Figure out the title based on the code of this axis
     ori_slice = dict(P='Coronal', A='Coronal', I='Axial', S='Axial',
-                     L='Sagittal', R='Saggital')
+                     L='Sagittal', R='Sagittal')
     ori_names = dict(P='posterior', A='anterior', I='inferior', S='superior',
                      L='left', R='right')

From e81ec528a42ac687f3d961ed5cf8e25f236925b0 Mon Sep 17 00:00:00 2001
From: Eric Larson
Date: Tue, 2 May 2023 13:22:53 -0400
Subject: [PATCH 2/3] MAINT: Run black on codebase

---
 doc/conf.py | 1958 ++++++-----
 doc/sphinxext/flow_diagram.py | 184 +-
 doc/sphinxext/gen_commands.py | 52 +-
 doc/sphinxext/gen_names.py | 21 +-
 doc/sphinxext/gh_substitutions.py | 8 +-
 doc/sphinxext/mne_substitutions.py | 51 +-
 doc/sphinxext/newcontrib_substitutions.py | 13 +-
 doc/sphinxext/unit_role.py | 10 +-
 examples/datasets/brainstorm_data.py | 34 +-
 examples/datasets/hf_sef_data.py | 7 +-
 examples/datasets/limo_data.py | 95 +-
 examples/datasets/opm_data.py | 95 +-
 examples/datasets/spm_faces_dataset_sgskip.py | 67 +-
 examples/decoding/decoding_csp_eeg.py | 57 +-
 examples/decoding/decoding_csp_timefreq.py | 103 +-
 examples/decoding/decoding_rsa_sgskip.py | 76 +-
 .../decoding_spatio_temporal_source.py | 95 +-
 examples/decoding/decoding_spoc_CMC.py | 27 +-
 ...decoding_time_generalization_conditions.py | 72 +-
 .../decoding_unsupervised_spatial_filter.py | 53 +-
 examples/decoding/decoding_xdawn_eeg.py | 68 +-
 examples/decoding/ems_filtering.py | 57 +-
 examples/decoding/linear_model_patterns.py | 33 +-
 examples/decoding/receptive_field_mtrf.py | 130 +-
 examples/decoding/ssd_spatial_filters.py | 90 +-
 examples/forward/forward_sensitivity_maps.py | 70 +-
 .../forward/left_cerebellum_volume_source.py | 32 +-
 examples/forward/source_space_morphing.py | 42 +-
 .../compute_mne_inverse_epochs_in_label.py | 83 +-
 .../compute_mne_inverse_raw_in_label.py | 24 +-
 .../inverse/compute_mne_inverse_volume.py | 23 +-
 examples/inverse/custom_inverse_solver.py | 51 +-
 examples/inverse/dics_epochs.py | 67 +-
 examples/inverse/dics_source_power.py | 36 +-
 examples/inverse/evoked_ers_source_power.py | 116 +-
 examples/inverse/gamma_map_inverse.py | 75 +-
 examples/inverse/label_activation_from_stc.py | 42 +-
 examples/inverse/label_from_stc.py | 65 +-
 examples/inverse/label_source_activations.py | 76 +-
 examples/inverse/mixed_norm_inverse.py | 118 +-
 .../inverse/mixed_source_space_inverse.py | 143 +-
 examples/inverse/mne_cov_power.py | 95 +-
 examples/inverse/morph_surface_stc.py | 55 +-
 examples/inverse/morph_volume_stc.py | 35 +-
 examples/inverse/multi_dipole_model.py | 81 +-
 .../inverse/multidict_reweighted_tfmxne.py | 62 +-
 examples/inverse/psf_ctf_label_leakage.py | 85 +-
 examples/inverse/psf_ctf_vertices.py | 69 +-
 examples/inverse/psf_ctf_vertices_lcmv.py | 147 +-
 examples/inverse/psf_volume.py | 62 +-
 examples/inverse/rap_music.py | 30 +-
 examples/inverse/read_inverse.py | 37 +-
 examples/inverse/read_stc.py | 13 +-
 examples/inverse/resolution_metrics.py | 127 +-
 examples/inverse/resolution_metrics_eegmeg.py | 137 +-
 examples/inverse/snr_estimate.py | 6 +-
 examples/inverse/source_space_snr.py | 38 +-
 .../time_frequency_mixed_norm_inverse.py | 120 +-
 examples/inverse/vector_mne_solution.py | 60 +-
 examples/io/elekta_epochs.py | 34 +-
 examples/io/read_neo_format.py | 10 +-
 examples/io/read_noise_covariance_matrix.py | 6 +-
 examples/io/read_xdf.py | 6 +-
 .../contralateral_referencing.py | 46 +-
 examples/preprocessing/css.py | 67 +-
 .../preprocessing/define_target_events.py | 49 +-
 examples/preprocessing/eeg_bridging.py | 191 +-
 examples/preprocessing/eeg_csd.py | 39 +-
 .../preprocessing/eog_artifact_histogram.py | 15 +-
 examples/preprocessing/eog_regression.py | 31 +-
 examples/preprocessing/find_ref_artifacts.py | 24 +-
 .../preprocessing/fnirs_artifact_removal.py | 21 +-
 examples/preprocessing/ica_comparison.py | 27 +-
 .../preprocessing/interpolate_bad_channels.py | 16 +-
 .../preprocessing/movement_compensation.py | 26 +-
 examples/preprocessing/movement_detection.py | 48 +-
 examples/preprocessing/muscle_detection.py | 14 +-
 examples/preprocessing/muscle_ica.py | 31 +-
 examples/preprocessing/otp.py | 43 +-
 examples/preprocessing/shift_evoked.py | 38 +-
 examples/preprocessing/virtual_evoked.py | 26 +-
 examples/preprocessing/xdawn_denoising.py | 34 +-
 examples/simulation/plot_stc_metrics.py | 165 +-
 examples/simulation/simulate_evoked_data.py | 58 +-
 examples/simulation/simulate_raw_data.py | 44 +-
 ...imulated_raw_data_using_subject_anatomy.py | 104 +-
 examples/simulation/source_simulator.py | 25 +-
 examples/stats/cluster_stats_evoked.py | 52 +-
 examples/stats/fdr_stats_evoked.py | 55 +-
 examples/stats/linear_regression_raw.py | 38 +-
 examples/stats/sensor_permutation_test.py | 45 +-
 examples/stats/sensor_regression.py | 14 +-
 examples/time_frequency/compute_csd.py | 32 +-
 .../compute_source_psd_epochs.py | 76 +-
 .../source_label_time_frequency.py | 91 +-
 .../time_frequency/source_power_spectrum.py | 43 +-
 .../source_power_spectrum_opm.py | 162 +-
 .../source_space_time_frequency.py | 47 +-
 examples/time_frequency/temporal_whitening.py | 26 +-
 .../time_frequency/time_frequency_erds.py | 107 +-
 .../time_frequency_global_field_power.py | 73 +-
 .../time_frequency_simulated.py | 170 +-
 examples/visualization/3d_to_2d.py | 27 +-
 examples/visualization/brain.py | 47 +-
 .../visualization/channel_epochs_image.py | 53 +-
 examples/visualization/eeg_on_scalp.py | 21 +-
 examples/visualization/evoked_arrowmap.py | 29 +-
 examples/visualization/evoked_topomap.py | 67 +-
 examples/visualization/evoked_whitening.py | 47 +-
 examples/visualization/meg_sensors.py | 57 +-
 examples/visualization/mne_helmet.py | 49 +-
 examples/visualization/montage_sgskip.py | 29 +-
 examples/visualization/parcellation.py | 55 +-
 examples/visualization/publication_figure.py | 130 +-
 examples/visualization/roi_erpimage_by_rt.py | 76 +-
 examples/visualization/sensor_noise_level.py | 7 +-
 .../ssp_projs_sensitivity_map.py | 20 +-
 .../visualization/topo_compare_conditions.py | 22 +-
 examples/visualization/topo_customized.py | 31 +-
 examples/visualization/xhemi.py | 29 +-
 logo/generate_mne_logos.py | 173 +-
 mne/__init__.py | 261 +-
 mne/__main__.py | 2 +-
 mne/_freesurfer.py | 375 +-
 mne/_ola.py | 286 +-
 mne/annotations.py | 623 ++--
 mne/baseline.py | 89 +-
 mne/beamformer/__init__.py | 18 +-
 mne/beamformer/_compute_beamformer.py | 352 +-
 mne/beamformer/_dics.py | 268 +-
 mne/beamformer/_lcmv.py | 208 +-
 mne/beamformer/_rap_music.py | 59 +-
 mne/beamformer/resolution_matrix.py | 15 +-
 mne/beamformer/tests/test_dics.py | 765 ++--
 mne/beamformer/tests/test_external.py | 72 +-
 mne/beamformer/tests/test_lcmv.py | 977 ++++--
 mne/beamformer/tests/test_rap_music.py | 153 +-
 .../tests/test_resolution_matrix.py | 43 +-
 mne/bem.py | 1586 +++++----
 mne/channels/__init__.py | 100 +-
 mne/channels/_dig_montage_utils.py | 66 +-
 mne/channels/_standard_montage_utils.py | 227 +-
 mne/channels/channels.py | 1373 +++++---
 mne/channels/interpolation.py | 105 +-
 mne/channels/layout.py | 438 ++-
 mne/channels/montage.py | 764 ++--
 mne/channels/tests/test_channels.py | 441 +--
 mne/channels/tests/test_interpolation.py | 211 +-
 mne/channels/tests/test_layout.py | 233 +-
 mne/channels/tests/test_montage.py | 1793 ++++++----
 mne/channels/tests/test_standard_montage.py | 266 +-
 mne/chpi.py | 973 +++---
 mne/commands/mne_anonymize.py | 67 +-
 mne/commands/mne_browse_raw.py | 189 +-
 mne/commands/mne_bti2fiff.py | 97 +-
 mne/commands/mne_clean_eog_ecg.py | 183 +-
 mne/commands/mne_compare_fiff.py | 3 +-
 mne/commands/mne_compute_proj_ecg.py | 331 +-
 mne/commands/mne_compute_proj_eog.py | 306 +-
 mne/commands/mne_coreg.py | 142 +-
 mne/commands/mne_flash_bem.py | 161 +-
 mne/commands/mne_freeview_bem_surfaces.py | 70 +-
 mne/commands/mne_kit2fiff.py | 83 +-
 mne/commands/mne_make_scalp_surfaces.py | 72 +-
 mne/commands/mne_maxfilter.py | 254 +-
 mne/commands/mne_prepare_bem_model.py | 36 +-
 mne/commands/mne_report.py | 114 +-
 mne/commands/mne_setup_forward_model.py | 135 +-
 mne/commands/mne_setup_source_space.py | 175 +-
 mne/commands/mne_show_fiff.py | 12 +-
 mne/commands/mne_show_info.py | 7 +-
 mne/commands/mne_surf2bem.py | 22 +-
 mne/commands/mne_sys_info.py | 36 +-
 mne/commands/mne_watershed_bem.py | 117 +-
 mne/commands/mne_what.py | 3 +-
 mne/commands/tests/test_commands.py | 400 ++-
 mne/commands/utils.py | 35 +-
 mne/conftest.py | 581 +--
 mne/coreg.py | 1201 ++++---
 mne/cov.py | 1337 ++++---
 mne/cuda.py | 150 +-
 mne/datasets/__init__.py | 48 +-
 mne/datasets/_fake/_fake.py | 25 +-
 mne/datasets/_fetch.py | 40 +-
 mne/datasets/_fsaverage/base.py | 22 +-
 mne/datasets/_infant/base.py | 25 +-
 mne/datasets/_phantom/base.py | 18 +-
 mne/datasets/brainstorm/__init__.py | 3 +-
 mne/datasets/brainstorm/bst_auditory.py | 42 +-
 mne/datasets/brainstorm/bst_phantom_ctf.py | 42 +-
 mne/datasets/brainstorm/bst_phantom_elekta.py | 43 +-
 mne/datasets/brainstorm/bst_raw.py | 54 +-
 mne/datasets/brainstorm/bst_resting.py | 42 +-
 mne/datasets/config.py | 355 +-
 mne/datasets/eegbci/eegbci.py | 62 +-
 mne/datasets/eegbci/tests/test_eegbci.py | 3 +-
 mne/datasets/epilepsy_ecog/_data.py | 25 +-
 mne/datasets/erp_core/erp_core.py | 27 +-
 mne/datasets/eyelink/eyelink.py | 27 +-
 mne/datasets/fieldtrip_cmc/fieldtrip_cmc.py | 25 +-
 mne/datasets/fnirs_motor/fnirs_motor.py | 27 +-
 mne/datasets/hf_sef/hf_sef.py | 48 +-
 mne/datasets/kiloword/kiloword.py | 21 +-
 mne/datasets/limo/limo.py | 287 +-
 mne/datasets/misc/_misc.py | 22 +-
 mne/datasets/mtrf/mtrf.py | 23 +-
 mne/datasets/multimodal/multimodal.py | 27 +-
 mne/datasets/opm/opm.py | 25 +-
 mne/datasets/phantom_4dbti/phantom_4dbti.py | 25 +-
 mne/datasets/refmeg_noise/refmeg_noise.py | 25 +-
 mne/datasets/sample/sample.py | 27 +-
 mne/datasets/sleep_physionet/_utils.py | 168 +-
 mne/datasets/sleep_physionet/age.py | 71 +-
 mne/datasets/sleep_physionet/temazepam.py | 51 +-
 .../sleep_physionet/tests/test_physionet.py | 196 +-
 mne/datasets/somato/somato.py | 27 +-
 mne/datasets/spm_face/spm_data.py | 40 +-
 mne/datasets/ssvep/ssvep.py | 25 +-
 mne/datasets/testing/__init__.py | 9 +-
 mne/datasets/testing/_testing.py | 52 +-
 mne/datasets/tests/test_datasets.py | 255 +-
 .../ucl_opm_auditory/ucl_opm_auditory.py | 24 +-
 mne/datasets/utils.py | 623 ++--
 .../visual_92_categories.py | 25 +-
 mne/decoding/__init__.py | 11 +-
 mne/decoding/base.py | 169 +-
 mne/decoding/csp.py | 349 +-
 mne/decoding/ems.py | 49 +-
 mne/decoding/mixin.py | 21 +-
 mne/decoding/receptive_field.py | 168 +-
 mne/decoding/search_light.py | 131 +-
 mne/decoding/ssd.py | 188 +-
 mne/decoding/tests/test_base.py | 175 +-
 mne/decoding/tests/test_csp.py | 209 +-
 mne/decoding/tests/test_ems.py | 39 +-
 mne/decoding/tests/test_receptive_field.py | 399 ++-
 mne/decoding/tests/test_search_light.py | 107 +-
 mne/decoding/tests/test_ssd.py | 335 +-
 mne/decoding/tests/test_time_frequency.py | 5 +-
 mne/decoding/tests/test_transformer.py | 184 +-
 mne/decoding/time_delaying_ridge.py | 136 +-
 mne/decoding/time_frequency.py | 37 +-
 mne/decoding/transformer.py | 261 +-
 mne/defaults.py | 394 ++-
 mne/dipole.py | 1022 ++++--
 mne/epochs.py | 2158 ++++++++----
 mne/event.py | 632 ++--
 mne/evoked.py | 1103 ++++--
 mne/export/_brainvision.py | 1 +
 mne/export/_edf.py | 192 +-
 mne/export/_eeglab.py | 55 +-
 mne/export/_egimff.py | 99 +-
 mne/export/_export.py | 89 +-
 mne/export/tests/test_export.py | 376 +-
 mne/filter.py | 1715 ++++++---
 mne/fixes.py | 282 +-
 mne/forward/__init__.py | 64 +-
 mne/forward/_compute_forward.py | 293 +-
 mne/forward/_field_interpolation.py | 347 +-
 mne/forward/_lead_dots.py | 247 +-
 mne/forward/_make_forward.py | 661 ++--
 mne/forward/forward.py | 1396 ++++----
 mne/forward/tests/test_field_interpolation.py | 253 +-
 mne/forward/tests/test_forward.py | 393 ++-
 mne/forward/tests/test_make_forward.py | 711 ++--
 mne/gui/__init__.py | 199 +-
 mne/gui/_core.py | 398 ++-
 mne/gui/_coreg.py | 744 ++--
 mne/gui/_ieeg_locate.py | 501 ++-
 mne/gui/tests/test_core.py | 37 +-
 mne/gui/tests/test_coreg.py | 198 +-
 mne/gui/tests/test_gui_api.py | 254 +-
 mne/gui/tests/test_ieeg_locate.py | 186 +-
 mne/html_templates/_templates.py | 19 +-
 mne/inverse_sparse/__init__.py | 3 +-
 mne/inverse_sparse/_gamma_map.py | 133 +-
 mne/inverse_sparse/mxne_debiasing.py | 18 +-
 mne/inverse_sparse/mxne_inverse.py | 592 +++-
 mne/inverse_sparse/mxne_optim.py | 772 ++--
 mne/inverse_sparse/tests/test_gamma_map.py | 171 +-
 mne/inverse_sparse/tests/test_mxne_inverse.py | 516 ++-
 mne/inverse_sparse/tests/test_mxne_optim.py | 366 +-
 mne/io/__init__.py | 21 +-
 mne/io/_digitization.py | 372 +-
 mne/io/_read_raw.py | 51 +-
 mne/io/array/array.py | 67 +-
 mne/io/array/tests/test_array.py | 107 +-
 mne/io/artemis123/artemis123.py | 455 +--
 mne/io/artemis123/tests/test_artemis123.py | 88 +-
 mne/io/artemis123/utils.py | 70 +-
 mne/io/base.py | 1522 +++++---
 mne/io/besa/besa.py | 190 +-
 mne/io/besa/tests/test_besa.py | 60 +-
 mne/io/boxy/boxy.py | 182 +-
 mne/io/boxy/tests/test_boxy.py | 138 +-
 mne/io/brainvision/brainvision.py | 619 ++--
 mne/io/brainvision/tests/test_brainvision.py | 744 ++--
 mne/io/bti/bti.py | 1466 ++++----
 mne/io/bti/constants.py | 134 +-
 mne/io/bti/read.py | 39 +-
 mne/io/bti/tests/test_bti.py | 411 ++-
 mne/io/cnt/_utils.py | 70 +-
 mne/io/cnt/cnt.py | 357 +-
 mne/io/cnt/tests/test_cnt.py | 29 +-
 mne/io/compensator.py | 69 +-
 mne/io/constants.py | 1628 +++++----
 mne/io/ctf/constants.py | 2 +-
 mne/io/ctf/ctf.py | 208 +-
 mne/io/ctf/eeg.py | 73 +-
 mne/io/ctf/hc.py | 58 +-
 mne/io/ctf/info.py | 518 +--
 mne/io/ctf/markers.py | 67 +-
 mne/io/ctf/res4.py | 231 +-
 mne/io/ctf/tests/test_ctf.py | 667 ++--
 mne/io/ctf/trans.py | 108 +-
 mne/io/ctf_comp.py | 92 +-
 mne/io/curry/curry.py | 387 +-
 mne/io/curry/tests/test_curry.py | 421 ++-
 mne/io/diff.py | 6 +-
 mne/io/edf/edf.py | 1139 +++---
 mne/io/edf/tests/test_edf.py | 655 ++--
 mne/io/edf/tests/test_gdf.py | 91 +-
 mne/io/eeglab/_eeglab.py | 8 +-
 mne/io/eeglab/eeglab.py | 398 ++-
 mne/io/eeglab/tests/test_eeglab.py | 600 ++--
 mne/io/egi/egi.py | 279 +-
 mne/io/egi/egimff.py | 697 ++--
 mne/io/egi/events.py | 57 +-
 mne/io/egi/general.py | 126 +-
 mne/io/egi/tests/test_egi.py | 429 ++-
 mne/io/eximia/eximia.py | 57 +-
 mne/io/eximia/tests/test_eximia.py | 38 +-
 mne/io/eyelink/eyelink.py | 669 ++--
 mne/io/eyelink/tests/test_eyelink.py | 116 +-
 mne/io/fieldtrip/__init__.py | 3 +-
 mne/io/fieldtrip/fieldtrip.py | 53 +-
 mne/io/fieldtrip/tests/helpers.py | 166 +-
 mne/io/fieldtrip/tests/test_fieldtrip.py | 185 +-
 mne/io/fieldtrip/utils.py | 299 +-
 mne/io/fiff/raw.py | 274 +-
 mne/io/fiff/tests/test_raw_fiff.py | 1166 ++++---
 mne/io/fil/__init__.py | 2 +-
 mne/io/fil/fil.py | 226 +-
 mne/io/fil/sensors.py | 5 +-
 mne/io/fil/tests/test_fil.py | 37 +-
 mne/io/hitachi/hitachi.py | 221 +-
 mne/io/hitachi/tests/test_hitachi.py | 288 +-
 mne/io/kit/constants.py | 98 +-
 mne/io/kit/coreg.py | 128 +-
 mne/io/kit/kit.py | 624 ++--
 mne/io/kit/tests/test_coreg.py | 4 +-
 mne/io/kit/tests/test_kit.py | 357 +-
 mne/io/matrix.py | 95 +-
 mne/io/meas_info.py | 2101 ++++++-----
 mne/io/nedf/nedf.py | 108 +-
 mne/io/nedf/tests/test_nedf.py | 92 +-
 mne/io/nicolet/nicolet.py | 134 +-
 mne/io/nicolet/tests/test_nicolet.py | 18 +-
 mne/io/nihon/nihon.py | 367 +-
 mne/io/nihon/tests/test_nihon.py | 55 +-
 mne/io/nirx/_localized_abbr.py | 108 +-
 mne/io/nirx/nirx.py | 406 ++-
 mne/io/nirx/tests/test_nirx.py | 641 ++--
 mne/io/open.py | 159 +-
 mne/io/persyst/persyst.py | 207 +-
 mne/io/persyst/tests/test_persyst.py | 97 +-
 mne/io/pick.py | 1102 +++---
 mne/io/proc_history.py | 246 +-
 mne/io/proj.py | 608 ++--
 mne/io/reference.py | 377 +-
 mne/io/snirf/_snirf.py | 470 +--
 mne/io/snirf/tests/test_snirf.py | 369 +-
 mne/io/tag.py | 223 +-
 mne/io/tests/__init__.py | 2 +-
 mne/io/tests/test_apply_function.py | 19 +-
 mne/io/tests/test_compensator.py | 36 +-
 mne/io/tests/test_constants.py | 384 +-
 mne/io/tests/test_meas_info.py | 926 ++---
 mne/io/tests/test_pick.py | 602 ++--
 mne/io/tests/test_proc_history.py | 36 +-
 mne/io/tests/test_raw.py | 583 ++--
 mne/io/tests/test_read_raw.py | 61 +-
 mne/io/tests/test_reference.py | 603 ++--
 mne/io/tests/test_show_fiff.py | 21 +-
 mne/io/tests/test_utils.py | 13 +-
 mne/io/tests/test_what.py | 37 +-
 mne/io/tests/test_write.py | 8 +-
 mne/io/tree.py | 82 +-
 mne/io/utils.py | 102 +-
 mne/io/what.py | 43 +-
 mne/io/write.py | 318 +-
 mne/label.py | 1169 ++++---
 mne/minimum_norm/__init__.py | 36 +-
 mne/minimum_norm/_eloreta.py | 106 +-
 mne/minimum_norm/inverse.py | 1456 +++++---
 mne/minimum_norm/resolution_matrix.py | 213 +-
 mne/minimum_norm/spatial_resolution.py | 119 +-
 mne/minimum_norm/tests/test_inverse.py | 1405 +++++---
 .../tests/test_resolution_matrix.py | 207 +-
 .../tests/test_resolution_metrics.py | 135 +-
 mne/minimum_norm/tests/test_snr.py | 24 +-
 mne/minimum_norm/tests/test_time_frequency.py | 277 +-
 mne/minimum_norm/time_frequency.py | 558 ++-
 mne/misc.py | 33 +-
 mne/morph.py | 845 +++--
 mne/morph_map.py | 96 +-
 mne/parallel.py | 64 +-
 mne/preprocessing/__init__.py | 34 +-
 mne/preprocessing/_csd.py | 151 +-
 mne/preprocessing/_css.py | 18 +-
 mne/preprocessing/_fine_cal.py | 296 +-
 mne/preprocessing/_peak_finder.py | 21 +-
 mne/preprocessing/_regress.py | 192 +-
 mne/preprocessing/annotate_amplitude.py | 104 +-
 mne/preprocessing/annotate_nan.py | 7 +-
 mne/preprocessing/artifact_detection.py | 274 +-
 mne/preprocessing/bads.py | 1 +
 mne/preprocessing/ctps_.py | 18 +-
 mne/preprocessing/ecg.py | 295 +-
 mne/preprocessing/eog.py | 214 +-
 mne/preprocessing/eyetracking/eyetracking.py | 115 +-
 mne/preprocessing/hfc.py | 30 +-
 mne/preprocessing/ica.py | 1807 ++++++----
 mne/preprocessing/ieeg/_projection.py | 127 +-
 mne/preprocessing/ieeg/_volume.py | 161 +-
 .../ieeg/tests/test_projection.py | 147 +-
 mne/preprocessing/ieeg/tests/test_volume.py | 104 +-
 mne/preprocessing/infomax_.py | 92 +-
 mne/preprocessing/interpolate.py | 70 +-
 mne/preprocessing/maxfilter.py | 124 +-
 mne/preprocessing/maxwell.py | 1781 ++++++----
 mne/preprocessing/nirs/__init__.py | 17 +-
 mne/preprocessing/nirs/_beer_lambert_law.py | 64 +-
 mne/preprocessing/nirs/_optical_density.py | 6 +-
 .../nirs/_scalp_coupling_index.py | 26 +-
 mne/preprocessing/nirs/_tddr.py | 13 +-
 mne/preprocessing/nirs/nirs.py | 182 +-
 .../nirs/tests/test_beer_lambert_law.py | 75 +-
 mne/preprocessing/nirs/tests/test_nirs.py | 341 +-
 .../nirs/tests/test_optical_density.py | 25 +-
 .../nirs/tests/test_scalp_coupling_index.py | 33 +-
 ...temporal_derivative_distribution_repair.py | 18 +-
 mne/preprocessing/otp.py | 34 +-
 mne/preprocessing/realign.py | 49 +-
 mne/preprocessing/ssp.py | 431 ++-
 mne/preprocessing/stim.py | 56 +-
 .../tests/test_annotate_amplitude.py | 288 +-
 mne/preprocessing/tests/test_annotate_nan.py | 14 +-
 .../tests/test_artifact_detection.py | 138 +-
 mne/preprocessing/tests/test_csd.py | 131 +-
 mne/preprocessing/tests/test_css.py | 31 +-
 mne/preprocessing/tests/test_ctps.py | 54 +-
 mne/preprocessing/tests/test_ecg.py | 78 +-
 .../tests/test_eeglab_infomax.py | 67 +-
 mne/preprocessing/tests/test_eog.py | 6 +-
 mne/preprocessing/tests/test_fine_cal.py | 71 +-
 mne/preprocessing/tests/test_hfc.py | 78 +-
 mne/preprocessing/tests/test_ica.py | 1084 +++---
 mne/preprocessing/tests/test_infomax.py | 17 +-
 mne/preprocessing/tests/test_interpolate.py | 103 +-
 mne/preprocessing/tests/test_maxwell.py | 1551 +++++----
 mne/preprocessing/tests/test_otp.py | 59 +-
 mne/preprocessing/tests/test_peak_finder.py | 12 +-
 mne/preprocessing/tests/test_realign.py | 77 +-
 mne/preprocessing/tests/test_regress.py | 93 +-
 mne/preprocessing/tests/test_ssp.py | 250 +-
 mne/preprocessing/tests/test_stim.py | 82 +-
 mne/preprocessing/tests/test_xdawn.py | 189 +-
 mne/preprocessing/xdawn.py | 189 +-
 mne/proj.py | 264 +-
 mne/rank.py | 294 +-
 .../bootstrap-icons/gen_css_for_mne.py | 40 +-
 mne/report/report.py | 2308 +++++++-----
 mne/report/tests/test_report.py | 765 ++--
 mne/simulation/_metrics.py | 4 +-
 mne/simulation/evoked.py | 64 +-
 mne/simulation/metrics/__init__.py | 22 +-
 mne/simulation/metrics/metrics.py | 97 +-
 mne/simulation/metrics/tests/test_metrics.py | 185 +-
 mne/simulation/raw.py | 498 +--
 mne/simulation/source.py | 206 +-
 mne/simulation/tests/test_evoked.py | 141 +-
 mne/simulation/tests/test_metrics.py | 24 +-
 mne/simulation/tests/test_raw.py | 373 +-
 mne/simulation/tests/test_source.py | 319 +-
 mne/source_estimate.py | 1943 +++++++----
 mne/source_space.py | 2008 ++++++-----
 mne/stats/__init__.py | 24 +-
 mne/stats/_adjacency.py | 38 +-
 mne/stats/cluster_level.py | 661 ++--
 mne/stats/multi_comp.py | 10 +-
 mne/stats/parametric.py | 90 +-
 mne/stats/permutations.py | 42 +-
 mne/stats/regression.py | 183 +-
 mne/stats/tests/test_adjacency.py | 28 +-
 mne/stats/tests/test_cluster_level.py | 645 ++--
 mne/stats/tests/test_multi_comp.py | 9 +-
 mne/stats/tests/test_parametric.py | 144 +-
 mne/stats/tests/test_permutations.py | 40 +-
 mne/stats/tests/test_regression.py | 82 +-
 mne/surface.py | 1111 +++---
 mne/tests/test_annotations.py | 1131 +++---
 mne/tests/test_bem.py | 566 +--
 mne/tests/test_chpi.py | 573 +--
 mne/tests/test_coreg.py | 378 +-
 mne/tests/test_cov.py | 749 ++--
 mne/tests/test_defaults.py | 33 +-
 mne/tests/test_dipole.py | 361 +-
 mne/tests/test_docstring_parameters.py | 267 +-
 mne/tests/test_epochs.py | 3101 ++++++++++-------
 mne/tests/test_event.py | 537 +--
 mne/tests/test_evoked.py | 421 +--
 mne/tests/test_filter.py | 894 +++--
 mne/tests/test_freesurfer.py | 186 +-
 mne/tests/test_import_nesting.py | 5 +-
 mne/tests/test_label.py | 884 +++--
 mne/tests/test_line_endings.py | 77 +-
 mne/tests/test_morph.py | 855 +++--
 mne/tests/test_morph_map.py | 27 +-
 mne/tests/test_ola.py | 63 +-
 mne/tests/test_parallel.py | 20 +-
 mne/tests/test_proj.py | 331 +-
 mne/tests/test_rank.py | 237 +-
 mne/tests/test_read_vectorview_selection.py | 43 +-
 mne/tests/test_source_estimate.py | 1303 +++----
 mne/tests/test_source_space.py | 913 ++---
 mne/tests/test_surface.py | 436 ++-
 mne/tests/test_transforms.py | 370 +-
 mne/time_frequency/__init__.py | 33 +-
 mne/time_frequency/_stft.py | 73 +-
 mne/time_frequency/_stockwell.py | 107 +-
 mne/time_frequency/ar.py | 8 +-
 mne/time_frequency/csd.py | 584 +++-
 mne/time_frequency/multitaper.py | 157 +-
 mne/time_frequency/psd.py | 89 +-
 mne/time_frequency/spectrum.py | 791 +++--
 mne/time_frequency/tests/test_ar.py | 21 +-
 mne/time_frequency/tests/test_csd.py | 406 ++-
 mne/time_frequency/tests/test_multitaper.py | 20 +-
 mne/time_frequency/tests/test_psd.py | 103 +-
 mne/time_frequency/tests/test_spectrum.py | 239 +-
 mne/time_frequency/tests/test_stft.py | 29 +-
 mne/time_frequency/tests/test_stockwell.py | 82 +-
 mne/time_frequency/tests/test_tfr.py | 1083 +++---
 mne/time_frequency/tfr.py | 1514 +++++---
 mne/transforms.py | 937 +++--
 mne/utils/__init__.py | 283 +-
 mne/utils/_bunch.py | 9 +-
 mne/utils/_logging.py | 142 +-
 mne/utils/_testing.py | 198 +-
 mne/utils/check.py | 660 ++--
 mne/utils/config.py | 485 +--
 mne/utils/dataframe.py | 64 +-
 mne/utils/docs.py | 2059 +++++++----
 mne/utils/fetching.py | 5 +-
 mne/utils/linalg.py | 62 +-
 mne/utils/misc.py | 124 +-
 mne/utils/mixin.py | 223 +-
 mne/utils/numerics.py | 326 +-
 mne/utils/progressbar.py | 64 +-
 mne/utils/spectrum.py | 36 +-
 mne/utils/tests/test_bunch.py | 4 +-
 mne/utils/tests/test_check.py | 254 +-
 mne/utils/tests/test_config.py | 91 +-
 mne/utils/tests/test_docs.py | 117 +-
 mne/utils/tests/test_linalg.py | 40 +-
 mne/utils/tests/test_logging.py | 134 +-
 mne/utils/tests/test_misc.py | 88 +-
 mne/utils/tests/test_numerics.py | 394 ++-
 mne/utils/tests/test_progressbar.py | 71 +-
 mne/utils/tests/test_testing.py | 29 +-
 mne/viz/_3d.py | 3016 ++++++++++------
 mne/viz/_3d_overlay.py | 24 +-
 mne/viz/__init__.py | 111 +-
 mne/viz/_brain/__init__.py | 2 +-
 mne/viz/_brain/_brain.py | 2057 ++++++-----
 mne/viz/_brain/_linkviewer.py | 18 +-
 mne/viz/_brain/_scraper.py | 64 +-
 mne/viz/_brain/callback.py | 16 +-
 mne/viz/_brain/colormap.py | 94 +-
 mne/viz/_brain/surface.py | 69 +-
 mne/viz/_brain/tests/test_brain.py | 860 +++--
 mne/viz/_brain/tests/test_notebook.py | 75 +-
 mne/viz/_brain/view.py | 57 +-
 mne/viz/_dipole.py | 170 +-
 mne/viz/_figure.py | 350 +-
 mne/viz/_mpl_figure.py | 1348 ++++---
 mne/viz/_proj.py | 152 +-
 mne/viz/_scraper.py | 31 +-
 mne/viz/backends/_abstract.py | 377 +-
 mne/viz/backends/_notebook.py | 684 ++--
 mne/viz/backends/_pyvista.py | 699 ++--
 mne/viz/backends/_qt.py | 573 +--
 mne/viz/backends/_utils.py | 166 +-
 mne/viz/backends/renderer.py | 88 +-
 mne/viz/backends/tests/_utils.py | 4 +-
 mne/viz/backends/tests/test_abstract.py | 68 +-
 mne/viz/backends/tests/test_renderer.py | 145 +-
 mne/viz/backends/tests/test_utils.py | 59 +-
 mne/viz/circle.py | 195 +-
 mne/viz/conftest.py | 20 +-
 mne/viz/epochs.py | 684 ++--
 mne/viz/evoked.py | 1975 +++++++----
 mne/viz/ica.py | 885 +++--
 mne/viz/misc.py | 866 +++--
 mne/viz/montage.py | 35 +-
 mne/viz/raw.py | 364 +-
 mne/viz/tests/test_3d.py | 1058 +++---
 mne/viz/tests/test_3d_mpl.py | 131 +-
 mne/viz/tests/test_circle.py | 25 +-
 mne/viz/tests/test_epochs.py | 308 +-
 mne/viz/tests/test_evoked.py | 507 +--
 mne/viz/tests/test_figure.py | 4 +-
 mne/viz/tests/test_ica.py | 295 +-
 mne/viz/tests/test_misc.py | 295 +-
 mne/viz/tests/test_montage.py | 63 +-
 mne/viz/tests/test_proj.py | 44 +-
 mne/viz/tests/test_raw.py | 630 ++--
 mne/viz/tests/test_scraper.py | 17 +-
 mne/viz/tests/test_topo.py | 311 +-
 mne/viz/tests/test_topomap.py | 601 ++--
 mne/viz/tests/test_utils.py | 93 +-
 mne/viz/topo.py | 843 +++--
 mne/viz/topomap.py | 2563 +++++++++-----
 mne/viz/utils.py | 1448 +++++---
 setup.py | 200 +-
 tools/check_mne_location.py | 5 +-
 tools/generate_codemeta.py | 136 +-
 tutorials/clinical/20_seeg.py | 126 +-
 tutorials/clinical/30_ecog.py | 114 +-
 tutorials/clinical/60_sleep.py | 132 +-
 tutorials/epochs/10_epochs_overview.py | 86 +-
 tutorials/epochs/15_baseline_regression.py | 112 +-
 tutorials/epochs/20_visualize_epochs.py | 80 +-
 tutorials/epochs/30_epochs_metadata.py | 35 +-
 tutorials/epochs/40_autogenerate_metadata.py | 232 +-
 tutorials/epochs/50_epochs_to_data_frame.py | 91 +-
 .../epochs/60_make_fixed_length_epochs.py | 15 +-
 tutorials/evoked/10_evoked_overview.py | 55 +-
 tutorials/evoked/20_visualize_evoked.py | 97 +-
 tutorials/evoked/30_eeg_erp.py | 150 +-
 tutorials/evoked/40_whitened.py | 31 +-
 tutorials/forward/10_background_freesurfer.py | 9 +-
 tutorials/forward/20_source_alignment.py | 173 +-
 tutorials/forward/25_automated_coreg.py | 29 +-
 tutorials/forward/30_forward.py | 101 +-
 tutorials/forward/35_eeg_no_mri.py | 65 +-
 .../forward/50_background_freesurfer_mne.py | 268 +-
 tutorials/forward/80_fix_bem_in_blender.py | 38 +-
 tutorials/forward/90_compute_covariance.py | 50 +-
 tutorials/intro/10_overview.py | 120 +-
 tutorials/intro/15_inplace.py | 16 +-
 tutorials/intro/20_events_from_raw.py | 40 +-
 tutorials/intro/30_info.py | 22 +-
 tutorials/intro/40_sensor_locations.py | 52 +-
 tutorials/intro/50_configure_mne.py | 50 +-
 tutorials/intro/70_report.py | 289 +-
 tutorials/inverse/10_stc_class.py | 45 +-
 tutorials/inverse/20_dipole_fit.py | 93 +-
 tutorials/inverse/30_mne_dspm_loreta.py | 86 +-
 tutorials/inverse/35_dipole_orientations.py | 148 +-
 tutorials/inverse/40_mne_fixed_free.py | 69 +-
 tutorials/inverse/50_beamformer_lcmv.py | 118 +-
 tutorials/inverse/60_visualize_stc.py | 122 +-
 tutorials/inverse/70_eeg_mri_coords.py | 49 +-
 .../inverse/80_brainstorm_phantom_elekta.py | 92 +-
 .../inverse/85_brainstorm_phantom_ctf.py | 46 +-
 tutorials/inverse/90_phantom_4DBTi.py | 39 +-
 tutorials/io/30_reading_fnirs_data.py | 59 +-
 tutorials/io/60_ctf_bst_auditory.py | 195 +-
 tutorials/io/70_reading_eyetracking_data.py | 10 +-
 tutorials/machine-learning/30_strf.py | 135 +-
 tutorials/machine-learning/50_decoding.py | 147 +-
 .../10_preprocessing_overview.py | 28 +-
 .../preprocessing/15_handling_bad_channels.py | 39 +-
 .../preprocessing/20_rejecting_bad_data.py | 92 +-
 .../preprocessing/25_background_filtering.py | 338 +-
 .../preprocessing/30_filtering_resampling.py | 69 +-
 .../35_artifact_correction_regression.py | 50 +-
 .../40_artifact_correction_ica.py | 90 +-
 .../preprocessing/45_projectors_background.py | 96 +-
 .../50_artifact_correction_ssp.py | 116 +-
 .../preprocessing/55_setting_eeg_reference.py | 38 +-
 tutorials/preprocessing/59_head_positions.py | 12 +-
 .../preprocessing/60_maxwell_filtering_sss.py | 105 +-
 .../preprocessing/70_fnirs_processing.py | 191 +-
 tutorials/preprocessing/80_opm_processing.py | 74 +-
 .../preprocessing/90_eyetracking_data.py | 27 +-
 tutorials/raw/10_raw_overview.py | 64 +-
 tutorials/raw/20_event_arrays.py | 42 +-
 tutorials/raw/30_annotate_raw.py | 51 +-
 tutorials/raw/40_visualize_raw.py | 11 +-
 tutorials/simulation/10_array_objs.py | 89 +-
 tutorials/simulation/70_point_spread.py | 84 +-
 tutorials/simulation/80_dics.py | 146 +-
 .../stats-sensor-space/10_background_stats.py | 156 +-
 tutorials/stats-sensor-space/20_erp_stats.py | 55 +-
 .../40_cluster_1samp_time_freq.py | 124 +-
 .../50_cluster_between_time_freq.py | 129 +-
 .../70_cluster_rmANOVA_time_freq.py | 146 +-
 .../75_cluster_ftest_spatiotemporal.py | 167 +-
 .../20_cluster_1samp_spatiotemporal.py | 101 +-
 .../30_cluster_ftest_spatiotemporal.py | 66 +-
 .../60_cluster_rmANOVA_spatiotemporal.py | 137 +-
 tutorials/time-freq/10_spectrum_class.py | 33 +-
 .../time-freq/20_sensors_time_frequency.py | 122 +-
 tutorials/time-freq/50_ssvep.py | 383 +-
 707 files changed, 107021 insertions(+), 70350 deletions(-)

diff --git a/doc/conf.py b/doc/conf.py
index e0748d37053..8f904e45d85 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -22,17 +22,21 @@ import mne
 from mne.fixes import _compare_version
 from mne.tests.test_docstring_parameters import error_ignores
-from mne.utils import (linkcode_resolve,  # noqa, analysis:ignore
-                       _assert_no_instances, sizeof_fmt, run_subprocess)
+from mne.utils import (
+    linkcode_resolve,  # noqa, analysis:ignore
+    _assert_no_instances,
+    sizeof_fmt,
+    run_subprocess,
+)
 from mne.viz import Brain  # noqa

-matplotlib.use('agg')
+matplotlib.use("agg")
 faulthandler.enable()
-os.environ['_MNE_BROWSER_NO_BLOCK'] = 'true'
-os.environ['MNE_BROWSER_OVERVIEW_MODE'] = 'hidden'
-os.environ['MNE_BROWSER_THEME'] = 'light'
-os.environ['MNE_3D_OPTION_THEME'] = 'light'
-sphinx_logger = sphinx.util.logging.getLogger('mne')
+os.environ["_MNE_BROWSER_NO_BLOCK"] = "true"
+os.environ["MNE_BROWSER_OVERVIEW_MODE"] = "hidden"
+os.environ["MNE_BROWSER_THEME"] = "light"
+os.environ["MNE_3D_OPTION_THEME"] = "light"
+sphinx_logger = sphinx.util.logging.getLogger("mne")


 # -- Path setup --------------------------------------------------------------
@@ -40,22 +44,23 @@
 # add these directories to sys.path here. If the directory is relative to the
 # documentation root, use os.path.abspath to make it absolute, like shown here.
 curdir = os.path.dirname(__file__)
-sys.path.append(os.path.abspath(os.path.join(curdir, '..', 'mne')))
-sys.path.append(os.path.abspath(os.path.join(curdir, 'sphinxext')))
+sys.path.append(os.path.abspath(os.path.join(curdir, "..", "mne")))
+sys.path.append(os.path.abspath(os.path.join(curdir, "sphinxext")))


 # -- Project information -----------------------------------------------------

-project = 'MNE'
+project = "MNE"
 td = datetime.now(tz=timezone.utc)

 # We need to triage which date type we use so that incremental builds work
 # (Sphinx looks at variable changes and rewrites all files if some change)
 copyright = (
     f'2012–{td.year}, MNE Developers. Last updated \n'  # noqa: E501
-    '')  # noqa: E501
-if os.getenv('MNE_FULL_DATE', 'false').lower() != 'true':
-    copyright = f'2012–{td.year}, MNE Developers. Last updated locally.'
+    ''
+)  # noqa: E501
+if os.getenv("MNE_FULL_DATE", "false").lower() != "true":
+    copyright = f"2012–{td.year}, MNE Developers. Last updated locally."

 # The version info for the project you're documenting, acts as replacement for
 # |version| and |release|, also used in various other places throughout the
@@ -63,71 +68,70 @@
 #
 # The full version, including alpha/beta/rc tags.
 release = mne.__version__
-sphinx_logger.info(
-    f'Building documentation for MNE {release} ({mne.__file__})')
+sphinx_logger.info(f"Building documentation for MNE {release} ({mne.__file__})")
 # The short X.Y version.
-version = '.'.join(release.split('.')[:2])
+version = ".".join(release.split(".")[:2])

 # -- General configuration ---------------------------------------------------

 # If your documentation needs a minimal Sphinx version, state it here.
-needs_sphinx = '2.0'
+needs_sphinx = "2.0"

 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.coverage',
-    'sphinx.ext.doctest',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.linkcode',
-    'sphinx.ext.mathjax',
-    'sphinx.ext.todo',
-    'sphinx.ext.graphviz',
-    'numpydoc',
-    'sphinx_gallery.gen_gallery',
-    'gen_commands',
-    'gh_substitutions',
-    'mne_substitutions',
-    'newcontrib_substitutions',
-    'gen_names',
-    'matplotlib.sphinxext.plot_directive',
-    'sphinxcontrib.bibtex',
-    'sphinx_copybutton',
-    'sphinx_design',
-    'sphinxcontrib.youtube',
-    'unit_role',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.coverage",
+    "sphinx.ext.doctest",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.linkcode",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.todo",
+    "sphinx.ext.graphviz",
+    "numpydoc",
+    "sphinx_gallery.gen_gallery",
+    "gen_commands",
+    "gh_substitutions",
+    "mne_substitutions",
+    "newcontrib_substitutions",
+    "gen_names",
+    "matplotlib.sphinxext.plot_directive",
+    "sphinxcontrib.bibtex",
+    "sphinx_copybutton",
+    "sphinx_design",
+    "sphinxcontrib.youtube",
+    "unit_role",
 ]

 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]

 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_includes']
+exclude_patterns = ["_includes"]

 # The suffix of source filenames.
-source_suffix = '.rst'
+source_suffix = ".rst"

 # The main toctree document.
-master_doc = 'index'
+master_doc = "index"

 # List of documents that shouldn't be included in the build.
 unused_docs = []

 # List of directories, relative to source directory, that shouldn't be searched
 # for source files.
-exclude_trees = ['_build']
+exclude_trees = ["_build"]

 # The reST default role (used for this markup: `text`) to use for all
 # documents.
 default_role = "py:obj"

 # A list of ignored prefixes for module index sorting.
-modindex_common_prefix = ['mne.']
+modindex_common_prefix = ["mne."]

 # -- Sphinx-Copybutton configuration -----------------------------------------
 copybutton_prompt_text = r">>> |\.\.\. |\$ "
@@ -136,36 +140,38 @@

 # -- Intersphinx configuration -----------------------------------------------

 intersphinx_mapping = {
-    'python': ('https://docs.python.org/3', None),
-    'numpy': ('https://numpy.org/doc/stable', None),
-    'scipy': ('https://docs.scipy.org/doc/scipy', None),
-    'matplotlib': ('https://matplotlib.org/stable', None),
-    'sklearn': ('https://scikit-learn.org/stable', None),
-    'numba': ('https://numba.readthedocs.io/en/latest', None),
-    'joblib': ('https://joblib.readthedocs.io/en/latest', None),
-    'nibabel': ('https://nipy.org/nibabel', None),
-    'nilearn': ('http://nilearn.github.io/stable', None),
-    'nitime': ('https://nipy.org/nitime/', None),
-    'surfer': ('https://pysurfer.github.io/', None),
-    'mne_bids': ('https://mne.tools/mne-bids/stable', None),
-    'mne-connectivity': ('https://mne.tools/mne-connectivity/stable', None),
-    'mne-gui-addons': ('https://mne.tools/mne-gui-addons', None),
-    'pandas': ('https://pandas.pydata.org/pandas-docs/stable', None),
-    'seaborn': ('https://seaborn.pydata.org/', None),
-    'statsmodels': ('https://www.statsmodels.org/dev', None),
-    'patsy': ('https://patsy.readthedocs.io/en/latest', None),
-    'pyvista': ('https://docs.pyvista.org', None),
-    'imageio': ('https://imageio.readthedocs.io/en/latest', None),
-    'mne_realtime': ('https://mne.tools/mne-realtime', None),
-    'picard': ('https://pierreablin.github.io/picard/', None),
-    'qdarkstyle': ('https://qdarkstylesheet.readthedocs.io/en/latest', None),
-    'eeglabio': ('https://eeglabio.readthedocs.io/en/latest', None),
-    'dipy': ('https://dipy.org/documentation/1.7.0/',
-             'https://dipy.org/documentation/1.7.0/objects.inv/'),
-    'pooch': ('https://www.fatiando.org/pooch/latest/', None),
-    'pybv': ('https://pybv.readthedocs.io/en/latest/', None),
-    'pyqtgraph': ('https://pyqtgraph.readthedocs.io/en/latest/', None),
-    'openmeeg': ('https://openmeeg.github.io', None),
+    "python": ("https://docs.python.org/3", None),
+    "numpy": ("https://numpy.org/doc/stable", None),
+    "scipy": ("https://docs.scipy.org/doc/scipy", None),
+    "matplotlib": ("https://matplotlib.org/stable", None),
+    "sklearn": ("https://scikit-learn.org/stable", None),
+    "numba": ("https://numba.readthedocs.io/en/latest", None),
+    "joblib": ("https://joblib.readthedocs.io/en/latest", None),
+    "nibabel": ("https://nipy.org/nibabel", None),
+    "nilearn": ("http://nilearn.github.io/stable", None),
+    "nitime": ("https://nipy.org/nitime/", None),
+    "surfer": ("https://pysurfer.github.io/", None),
+    "mne_bids": ("https://mne.tools/mne-bids/stable", None),
+    "mne-connectivity": ("https://mne.tools/mne-connectivity/stable", None),
+    "mne-gui-addons": ("https://mne.tools/mne-gui-addons", None),
+    "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None),
+    "seaborn": ("https://seaborn.pydata.org/", None),
+    "statsmodels": ("https://www.statsmodels.org/dev", None),
+    "patsy": ("https://patsy.readthedocs.io/en/latest", None),
+    "pyvista": ("https://docs.pyvista.org", None),
+    "imageio": ("https://imageio.readthedocs.io/en/latest", None),
+    "mne_realtime": ("https://mne.tools/mne-realtime", None),
+    "picard": ("https://pierreablin.github.io/picard/", None),
+    "qdarkstyle": ("https://qdarkstylesheet.readthedocs.io/en/latest", None),
+    "eeglabio": ("https://eeglabio.readthedocs.io/en/latest", None),
+    "dipy": (
+        "https://dipy.org/documentation/1.7.0/",
+        "https://dipy.org/documentation/1.7.0/objects.inv/",
+    ),
+    "pooch": ("https://www.fatiando.org/pooch/latest/", None),
+    "pybv": ("https://pybv.readthedocs.io/en/latest/", None),
+    "pyqtgraph": ("https://pyqtgraph.readthedocs.io/en/latest/", None),
+    "openmeeg": ("https://openmeeg.github.io", None),
 }
@@ -175,127 +181,251 @@
 docscrape.ClassDoc.extra_public_methods = mne.utils._doc_special_members
 numpydoc_class_members_toctree = False
 numpydoc_show_inherited_class_members = {
-    'mne.SourceSpaces': False,
-    'mne.Forward': False,
+    "mne.SourceSpaces": False,
+    "mne.Forward": False,
 }
 numpydoc_attributes_as_param_list = True
 numpydoc_xref_param_type = True
 numpydoc_xref_aliases = {
     # Python
-    'file-like': ':term:`file-like `',
-    'iterator': ':term:`iterator `',
-    'path-like': ':term:`path-like`',
-    'array-like': ':term:`array_like `',
-    'Path': ':class:`python:pathlib.Path`',
-    'bool': ':class:`python:bool`',
+    "file-like": ":term:`file-like `",
+    "iterator": ":term:`iterator `",
+    "path-like": ":term:`path-like`",
+    "array-like": ":term:`array_like `",
+    "Path": ":class:`python:pathlib.Path`",
+    "bool": ":class:`python:bool`",
     # Matplotlib
-    'colormap': ':doc:`colormap `',
-    'color': ':doc:`color `',
-    'Axes': 'matplotlib.axes.Axes',
-    'Figure': 'matplotlib.figure.Figure',
-    'Axes3D': 'mpl_toolkits.mplot3d.axes3d.Axes3D',
-    'ColorbarBase': 'matplotlib.colorbar.ColorbarBase',
+    "colormap": ":doc:`colormap `",
+    "color": ":doc:`color `",
+    "Axes": "matplotlib.axes.Axes",
+    "Figure": "matplotlib.figure.Figure",
+    "Axes3D": "mpl_toolkits.mplot3d.axes3d.Axes3D",
+    "ColorbarBase": "matplotlib.colorbar.ColorbarBase",
     # sklearn
-    'LeaveOneOut': 'sklearn.model_selection.LeaveOneOut',
+    "LeaveOneOut": "sklearn.model_selection.LeaveOneOut",
     # joblib
-    'joblib.Parallel': 'joblib.Parallel',
+    "joblib.Parallel": "joblib.Parallel",
     # nibabel
-    'Nifti1Image': 'nibabel.nifti1.Nifti1Image',
-    'Nifti2Image': 'nibabel.nifti2.Nifti2Image',
-    'SpatialImage': 'nibabel.spatialimages.SpatialImage',
+    "Nifti1Image": "nibabel.nifti1.Nifti1Image",
+    "Nifti2Image": "nibabel.nifti2.Nifti2Image",
+    "SpatialImage": "nibabel.spatialimages.SpatialImage",
     # MNE
-    'Label': 'mne.Label', 'Forward': 'mne.Forward', 'Evoked': 'mne.Evoked',
-    'Info': 'mne.Info', 'SourceSpaces': 'mne.SourceSpaces',
-    'Epochs': 'mne.Epochs', 'Layout': 'mne.channels.Layout',
-    'EvokedArray': 'mne.EvokedArray', 'BiHemiLabel': 'mne.BiHemiLabel',
-    'AverageTFR': 'mne.time_frequency.AverageTFR',
-    'EpochsTFR': 'mne.time_frequency.EpochsTFR',
-    'Raw': 'mne.io.Raw', 'ICA': 'mne.preprocessing.ICA',
-    'Covariance': 'mne.Covariance', 'Annotations': 'mne.Annotations',
-    'DigMontage': 'mne.channels.DigMontage',
-    'VectorSourceEstimate': 'mne.VectorSourceEstimate',
-    'VolSourceEstimate': 'mne.VolSourceEstimate',
-    'VolVectorSourceEstimate': 'mne.VolVectorSourceEstimate',
-    'MixedSourceEstimate': 'mne.MixedSourceEstimate',
-    'MixedVectorSourceEstimate': 'mne.MixedVectorSourceEstimate',
-    'SourceEstimate': 'mne.SourceEstimate', 'Projection': 'mne.Projection',
-    'ConductorModel': 'mne.bem.ConductorModel',
-    'Dipole': 'mne.Dipole', 'DipoleFixed': 'mne.DipoleFixed',
-    'InverseOperator': 'mne.minimum_norm.InverseOperator',
-    'CrossSpectralDensity': 'mne.time_frequency.CrossSpectralDensity',
-    'SourceMorph': 'mne.SourceMorph',
-    'Xdawn': 'mne.preprocessing.Xdawn',
-    'Report': 'mne.Report',
-    'TimeDelayingRidge': 'mne.decoding.TimeDelayingRidge',
-    'Vectorizer': 'mne.decoding.Vectorizer',
-    'UnsupervisedSpatialFilter': 'mne.decoding.UnsupervisedSpatialFilter',
-    'TemporalFilter': 'mne.decoding.TemporalFilter',
-    'SSD': 'mne.decoding.SSD',
-    'Scaler': 'mne.decoding.Scaler', 'SPoC': 'mne.decoding.SPoC',
-    'PSDEstimator': 'mne.decoding.PSDEstimator',
-
'LinearModel': 'mne.decoding.LinearModel', - 'FilterEstimator': 'mne.decoding.FilterEstimator', - 'EMS': 'mne.decoding.EMS', 'CSP': 'mne.decoding.CSP', - 'Beamformer': 'mne.beamformer.Beamformer', - 'Transform': 'mne.transforms.Transform', - 'Coregistration': 'mne.coreg.Coregistration', - 'Figure3D': 'mne.viz.Figure3D', - 'EOGRegression': 'mne.preprocessing.EOGRegression', - 'Spectrum': 'mne.time_frequency.Spectrum', - 'EpochsSpectrum': 'mne.time_frequency.EpochsSpectrum', + "Label": "mne.Label", + "Forward": "mne.Forward", + "Evoked": "mne.Evoked", + "Info": "mne.Info", + "SourceSpaces": "mne.SourceSpaces", + "Epochs": "mne.Epochs", + "Layout": "mne.channels.Layout", + "EvokedArray": "mne.EvokedArray", + "BiHemiLabel": "mne.BiHemiLabel", + "AverageTFR": "mne.time_frequency.AverageTFR", + "EpochsTFR": "mne.time_frequency.EpochsTFR", + "Raw": "mne.io.Raw", + "ICA": "mne.preprocessing.ICA", + "Covariance": "mne.Covariance", + "Annotations": "mne.Annotations", + "DigMontage": "mne.channels.DigMontage", + "VectorSourceEstimate": "mne.VectorSourceEstimate", + "VolSourceEstimate": "mne.VolSourceEstimate", + "VolVectorSourceEstimate": "mne.VolVectorSourceEstimate", + "MixedSourceEstimate": "mne.MixedSourceEstimate", + "MixedVectorSourceEstimate": "mne.MixedVectorSourceEstimate", + "SourceEstimate": "mne.SourceEstimate", + "Projection": "mne.Projection", + "ConductorModel": "mne.bem.ConductorModel", + "Dipole": "mne.Dipole", + "DipoleFixed": "mne.DipoleFixed", + "InverseOperator": "mne.minimum_norm.InverseOperator", + "CrossSpectralDensity": "mne.time_frequency.CrossSpectralDensity", + "SourceMorph": "mne.SourceMorph", + "Xdawn": "mne.preprocessing.Xdawn", + "Report": "mne.Report", + "TimeDelayingRidge": "mne.decoding.TimeDelayingRidge", + "Vectorizer": "mne.decoding.Vectorizer", + "UnsupervisedSpatialFilter": "mne.decoding.UnsupervisedSpatialFilter", + "TemporalFilter": "mne.decoding.TemporalFilter", + "SSD": "mne.decoding.SSD", + "Scaler": "mne.decoding.Scaler", + "SPoC": "mne.decoding.SPoC", + "PSDEstimator": "mne.decoding.PSDEstimator", + "LinearModel": "mne.decoding.LinearModel", + "FilterEstimator": "mne.decoding.FilterEstimator", + "EMS": "mne.decoding.EMS", + "CSP": "mne.decoding.CSP", + "Beamformer": "mne.beamformer.Beamformer", + "Transform": "mne.transforms.Transform", + "Coregistration": "mne.coreg.Coregistration", + "Figure3D": "mne.viz.Figure3D", + "EOGRegression": "mne.preprocessing.EOGRegression", + "Spectrum": "mne.time_frequency.Spectrum", + "EpochsSpectrum": "mne.time_frequency.EpochsSpectrum", # dipy - 'dipy.align.AffineMap': 'dipy.align.imaffine.AffineMap', - 'dipy.align.DiffeomorphicMap': 'dipy.align.imwarp.DiffeomorphicMap', + "dipy.align.AffineMap": "dipy.align.imaffine.AffineMap", + "dipy.align.DiffeomorphicMap": "dipy.align.imwarp.DiffeomorphicMap", } numpydoc_xref_ignore = { # words - 'instance', 'instances', 'of', 'default', 'shape', 'or', - 'with', 'length', 'pair', 'matplotlib', 'optional', 'kwargs', 'in', - 'dtype', 'object', + "instance", + "instances", + "of", + "default", + "shape", + "or", + "with", + "length", + "pair", + "matplotlib", + "optional", + "kwargs", + "in", + "dtype", + "object", # shapes - 'n_vertices', 'n_faces', 'n_channels', 'm', 'n', 'n_events', 'n_colors', - 'n_times', 'obj', 'n_chan', 'n_epochs', 'n_picks', 'n_ch_groups', - 'n_dipoles', 'n_ica_components', 'n_pos', 'n_node_names', 'n_tapers', - 'n_signals', 'n_step', 'n_freqs', 'wsize', 'Tx', 'M', 'N', 'p', 'q', 'r', - 'n_observations', 'n_regressors', 'n_cols', 'n_frequencies', 'n_tests', - 
'n_samples', 'n_permutations', 'nchan', 'n_points', 'n_features', - 'n_parts', 'n_features_new', 'n_components', 'n_labels', 'n_events_in', - 'n_splits', 'n_scores', 'n_outputs', 'n_trials', 'n_estimators', 'n_tasks', - 'nd_features', 'n_classes', 'n_targets', 'n_slices', 'n_hpi', 'n_fids', - 'n_elp', 'n_pts', 'n_tris', 'n_nodes', 'n_nonzero', 'n_events_out', - 'n_segments', 'n_orient_inv', 'n_orient_fwd', 'n_orient', 'n_dipoles_lcmv', - 'n_dipoles_fwd', 'n_picks_ref', 'n_coords', 'n_meg', 'n_good_meg', - 'n_moments', 'n_patterns', 'n_new_events', + "n_vertices", + "n_faces", + "n_channels", + "m", + "n", + "n_events", + "n_colors", + "n_times", + "obj", + "n_chan", + "n_epochs", + "n_picks", + "n_ch_groups", + "n_dipoles", + "n_ica_components", + "n_pos", + "n_node_names", + "n_tapers", + "n_signals", + "n_step", + "n_freqs", + "wsize", + "Tx", + "M", + "N", + "p", + "q", + "r", + "n_observations", + "n_regressors", + "n_cols", + "n_frequencies", + "n_tests", + "n_samples", + "n_permutations", + "nchan", + "n_points", + "n_features", + "n_parts", + "n_features_new", + "n_components", + "n_labels", + "n_events_in", + "n_splits", + "n_scores", + "n_outputs", + "n_trials", + "n_estimators", + "n_tasks", + "nd_features", + "n_classes", + "n_targets", + "n_slices", + "n_hpi", + "n_fids", + "n_elp", + "n_pts", + "n_tris", + "n_nodes", + "n_nonzero", + "n_events_out", + "n_segments", + "n_orient_inv", + "n_orient_fwd", + "n_orient", + "n_dipoles_lcmv", + "n_dipoles_fwd", + "n_picks_ref", + "n_coords", + "n_meg", + "n_good_meg", + "n_moments", + "n_patterns", + "n_new_events", # Undocumented (on purpose) - 'RawKIT', 'RawEximia', 'RawEGI', 'RawEEGLAB', 'RawEDF', 'RawCTF', 'RawBTi', - 'RawBrainVision', 'RawCurry', 'RawNIRX', 'RawGDF', 'RawSNIRF', 'RawBOXY', - 'RawPersyst', 'RawNihon', 'RawNedf', 'RawHitachi', 'RawFIL', 'RawEyelink', + "RawKIT", + "RawEximia", + "RawEGI", + "RawEEGLAB", + "RawEDF", + "RawCTF", + "RawBTi", + "RawBrainVision", + "RawCurry", + "RawNIRX", + "RawGDF", + "RawSNIRF", + "RawBOXY", + "RawPersyst", + "RawNihon", + "RawNedf", + "RawHitachi", + "RawFIL", + "RawEyelink", # sklearn subclasses - 'mapping', 'to', 'any', + "mapping", + "to", + "any", # unlinkable - 'CoregistrationUI', - 'IntracranialElectrodeLocator', - 'mne_qt_browser.figure.MNEQtBrowser', + "CoregistrationUI", + "IntracranialElectrodeLocator", + "mne_qt_browser.figure.MNEQtBrowser", } numpydoc_validate = True -numpydoc_validation_checks = {'all'} | set(error_ignores) +numpydoc_validation_checks = {"all"} | set(error_ignores) numpydoc_validation_exclude = { # set of regex # dict subclasses - r'\.clear', r'\.get$', r'\.copy$', r'\.fromkeys', r'\.items', r'\.keys', - r'\.pop', r'\.popitem', r'\.setdefault', r'\.update', r'\.values', + r"\.clear", + r"\.get$", + r"\.copy$", + r"\.fromkeys", + r"\.items", + r"\.keys", + r"\.pop", + r"\.popitem", + r"\.setdefault", + r"\.update", + r"\.values", # list subclasses - r'\.append', r'\.count', r'\.extend', r'\.index', r'\.insert', r'\.remove', - r'\.sort', + r"\.append", + r"\.count", + r"\.extend", + r"\.index", + r"\.insert", + r"\.remove", + r"\.sort", # we currently don't document these properly (probably okay) - r'\.__getitem__', r'\.__contains__', r'\.__hash__', r'\.__mul__', - r'\.__sub__', r'\.__add__', r'\.__iter__', r'\.__div__', r'\.__neg__', + r"\.__getitem__", + r"\.__contains__", + r"\.__hash__", + r"\.__mul__", + r"\.__sub__", + r"\.__add__", + r"\.__iter__", + r"\.__div__", + r"\.__neg__", # copied from sklearn - r'mne\.utils\.deprecated', + 
r"mne\.utils\.deprecated", } # -- Sphinx-gallery configuration -------------------------------------------- + class Resetter(object): """Simple class to make the str(obj) static for Sphinx build env hash.""" @@ -303,10 +433,11 @@ def __init__(self): self.t0 = time.time() def __repr__(self): - return f'<{self.__class__.__name__}>' + return f"<{self.__class__.__name__}>" def __call__(self, gallery_conf, fname, when): import matplotlib.pyplot as plt + try: from pyvista import Plotter # noqa except ImportError: @@ -324,45 +455,46 @@ def __call__(self, gallery_conf, fname, when): except ImportError: MNEQtBrowser = None from mne.viz.backends.renderer import backend + _Renderer = backend._Renderer if backend is not None else None reset_warnings(gallery_conf, fname) # in case users have interactive mode turned on in matplotlibrc, # turn it off here (otherwise the build can be very slow) plt.ioff() - plt.rcParams['animation.embed_limit'] = 30. - plt.rcParams['figure.raise_window'] = False + plt.rcParams["animation.embed_limit"] = 30.0 + plt.rcParams["figure.raise_window"] = False # neo holds on to an exception, which in turn holds a stack frame, # which will keep alive the global vars during SG execution try: import neo + neo.io.stimfitio.STFIO_ERR = None except Exception: pass gc.collect() - when = f'mne/conf.py:Resetter.__call__:{when}:{fname}' + when = f"mne/conf.py:Resetter.__call__:{when}:{fname}" # Support stuff like # MNE_SKIP_INSTANCE_ASSERTIONS="Brain,Plotter,BackgroundPlotter,vtkPolyData,_Renderer" make html-memory # noqa: E501 # to just test MNEQtBrowser - skips = os.getenv('MNE_SKIP_INSTANCE_ASSERTIONS', '').lower() - prefix = '' - if skips not in ('true', '1', 'all'): - prefix = 'Clean ' - skips = skips.split(',') - if 'brain' not in skips: + skips = os.getenv("MNE_SKIP_INSTANCE_ASSERTIONS", "").lower() + prefix = "" + if skips not in ("true", "1", "all"): + prefix = "Clean " + skips = skips.split(",") + if "brain" not in skips: _assert_no_instances(Brain, when) # calls gc.collect() - if Plotter is not None and 'plotter' not in skips: + if Plotter is not None and "plotter" not in skips: _assert_no_instances(Plotter, when) - if BackgroundPlotter is not None and \ - 'backgroundplotter' not in skips: + if BackgroundPlotter is not None and "backgroundplotter" not in skips: _assert_no_instances(BackgroundPlotter, when) - if vtkPolyData is not None and 'vtkpolydata' not in skips: + if vtkPolyData is not None and "vtkpolydata" not in skips: _assert_no_instances(vtkPolyData, when) - if '_renderer' not in skips: + if "_renderer" not in skips: _assert_no_instances(_Renderer, when) - if MNEQtBrowser is not None and \ - 'mneqtbrowser' not in skips: + if MNEQtBrowser is not None and "mneqtbrowser" not in skips: # Ensure any manual fig.close() events get properly handled from mne_qt_browser._pg_figure import QApplication + inst = QApplication.instance() if inst is not None: for _ in range(2): @@ -370,18 +502,19 @@ def __call__(self, gallery_conf, fname, when): _assert_no_instances(MNEQtBrowser, when) # This will overwrite some Sphinx printing but it's useful # for memory timestamps - if os.getenv('SG_STAMP_STARTS', '').lower() == 'true': + if os.getenv("SG_STAMP_STARTS", "").lower() == "true": import psutil + process = psutil.Process(os.getpid()) mem = sizeof_fmt(process.memory_info().rss) - print(f'{prefix}{time.time() - self.t0:6.1f} s : {mem}'.ljust(22)) + print(f"{prefix}{time.time() - self.t0:6.1f} s : {mem}".ljust(22)) -examples_dirs = ['../tutorials', '../examples'] -gallery_dirs = 
['auto_tutorials', 'auto_examples'] -os.environ['_MNE_BUILDING_DOC'] = 'true' -scrapers = ('matplotlib',) -mne.viz.set_3d_backend('pyvistaqt') +examples_dirs = ["../tutorials", "../examples"] +gallery_dirs = ["auto_tutorials", "auto_examples"] +os.environ["_MNE_BUILDING_DOC"] = "true" +scrapers = ("matplotlib",) +mne.viz.set_3d_backend("pyvistaqt") with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) import pyvista @@ -390,111 +523,118 @@ def __call__(self, gallery_conf, fname, when): report_scraper = mne.report._ReportScraper() scrapers = ( - 'matplotlib', + "matplotlib", mne.gui._GUIScraper(), mne.viz._brain._BrainScraper(), - 'pyvista', + "pyvista", report_scraper, mne.viz._scraper._MNEQtBrowserScraper(), ) -compress_images = ('images', 'thumbnails') +compress_images = ("images", "thumbnails") # let's make things easier on Windows users # (on Linux and macOS it's easy enough to require this) -if sys.platform.startswith('win'): +if sys.platform.startswith("win"): try: - subprocess.check_call(['optipng', '--version']) + subprocess.check_call(["optipng", "--version"]) except Exception: compress_images = () sphinx_gallery_conf = { - 'doc_module': ('mne',), - 'reference_url': dict(mne=None), - 'examples_dirs': examples_dirs, - 'subsection_order': ExplicitOrder(['../examples/io/', - '../examples/simulation/', - '../examples/preprocessing/', - '../examples/visualization/', - '../examples/time_frequency/', - '../examples/stats/', - '../examples/decoding/', - '../examples/connectivity/', - '../examples/forward/', - '../examples/inverse/', - '../examples/realtime/', - '../examples/datasets/', - '../tutorials/intro/', - '../tutorials/io/', - '../tutorials/raw/', - '../tutorials/preprocessing/', - '../tutorials/epochs/', - '../tutorials/evoked/', - '../tutorials/time-freq/', - '../tutorials/forward/', - '../tutorials/inverse/', - '../tutorials/stats-sensor-space/', - '../tutorials/stats-source-space/', - '../tutorials/machine-learning/', - '../tutorials/clinical/', - '../tutorials/simulation/', - '../tutorials/sample-datasets/', - '../tutorials/misc/']), - 'gallery_dirs': gallery_dirs, - 'default_thumb_file': os.path.join('_static', 'mne_helmet.png'), - 'backreferences_dir': 'generated', - 'plot_gallery': 'True', # Avoid annoying Unicode/bool default warning - 'thumbnail_size': (160, 112), - 'remove_config_comments': True, - 'min_reported_time': 1., - 'abort_on_example_error': False, - 'reset_modules': ('matplotlib', Resetter()), # called w/each script - 'reset_modules_order': 'both', - 'image_scrapers': scrapers, - 'show_memory': not sys.platform.startswith(('win', 'darwin')), - 'line_numbers': False, # messes with style - 'within_subsection_order': FileNameSortKey, - 'capture_repr': ('_repr_html_',), - 'junit': os.path.join('..', 'test-results', 'sphinx-gallery', 'junit.xml'), - 'matplotlib_animations': True, - 'compress_images': compress_images, - 'filename_pattern': '^((?!sgskip).)*$', - 'exclude_implicit_doc': { - r'mne\.io\.read_raw_fif', r'mne\.io\.Raw', r'mne\.Epochs', - r'mne.datasets.*', + "doc_module": ("mne",), + "reference_url": dict(mne=None), + "examples_dirs": examples_dirs, + "subsection_order": ExplicitOrder( + [ + "../examples/io/", + "../examples/simulation/", + "../examples/preprocessing/", + "../examples/visualization/", + "../examples/time_frequency/", + "../examples/stats/", + "../examples/decoding/", + "../examples/connectivity/", + "../examples/forward/", + "../examples/inverse/", + "../examples/realtime/", + 
"../examples/datasets/", + "../tutorials/intro/", + "../tutorials/io/", + "../tutorials/raw/", + "../tutorials/preprocessing/", + "../tutorials/epochs/", + "../tutorials/evoked/", + "../tutorials/time-freq/", + "../tutorials/forward/", + "../tutorials/inverse/", + "../tutorials/stats-sensor-space/", + "../tutorials/stats-source-space/", + "../tutorials/machine-learning/", + "../tutorials/clinical/", + "../tutorials/simulation/", + "../tutorials/sample-datasets/", + "../tutorials/misc/", + ] + ), + "gallery_dirs": gallery_dirs, + "default_thumb_file": os.path.join("_static", "mne_helmet.png"), + "backreferences_dir": "generated", + "plot_gallery": "True", # Avoid annoying Unicode/bool default warning + "thumbnail_size": (160, 112), + "remove_config_comments": True, + "min_reported_time": 1.0, + "abort_on_example_error": False, + "reset_modules": ("matplotlib", Resetter()), # called w/each script + "reset_modules_order": "both", + "image_scrapers": scrapers, + "show_memory": not sys.platform.startswith(("win", "darwin")), + "line_numbers": False, # messes with style + "within_subsection_order": FileNameSortKey, + "capture_repr": ("_repr_html_",), + "junit": os.path.join("..", "test-results", "sphinx-gallery", "junit.xml"), + "matplotlib_animations": True, + "compress_images": compress_images, + "filename_pattern": "^((?!sgskip).)*$", + "exclude_implicit_doc": { + r"mne\.io\.read_raw_fif", + r"mne\.io\.Raw", + r"mne\.Epochs", + r"mne.datasets.*", }, - 'show_api_usage': False, # disable for now until graph warning fixed - 'api_usage_ignore': ( - '(' - '.*__.*__|' # built-ins - '.*Base.*|.*Array.*|mne.Vector.*|mne.Mixed.*|mne.Vol.*|' # inherited - 'mne.coreg.Coregistration.*|' # GUI + "show_api_usage": False, # disable for now until graph warning fixed + "api_usage_ignore": ( + "(" + ".*__.*__|" # built-ins + ".*Base.*|.*Array.*|mne.Vector.*|mne.Mixed.*|mne.Vol.*|" # inherited + "mne.coreg.Coregistration.*|" # GUI # common - '.*utils.*|.*verbose()|.*copy()|.*update()|.*save()|' - '.*get_data()|' + ".*utils.*|.*verbose()|.*copy()|.*update()|.*save()|" + ".*get_data()|" # mixins - '.*add_channels()|.*add_reference_channels()|' - '.*anonymize()|.*apply_baseline()|.*apply_function()|' - '.*apply_hilbert()|.*as_type()|.*decimate()|' - '.*drop()|.*drop_channels()|.*drop_log_stats()|' - '.*export()|.*get_channel_types()|' - '.*get_montage()|.*interpolate_bads()|.*next()|' - '.*pick()|.*pick_channels()|.*pick_types()|' - '.*plot_sensors()|.*rename_channels()|' - '.*reorder_channels()|.*savgol_filter()|' - '.*set_eeg_reference()|.*set_channel_types()|' - '.*set_meas_date()|.*set_montage()|.*shift_time()|' - '.*time_as_index()|.*to_data_frame()|' + ".*add_channels()|.*add_reference_channels()|" + ".*anonymize()|.*apply_baseline()|.*apply_function()|" + ".*apply_hilbert()|.*as_type()|.*decimate()|" + ".*drop()|.*drop_channels()|.*drop_log_stats()|" + ".*export()|.*get_channel_types()|" + ".*get_montage()|.*interpolate_bads()|.*next()|" + ".*pick()|.*pick_channels()|.*pick_types()|" + ".*plot_sensors()|.*rename_channels()|" + ".*reorder_channels()|.*savgol_filter()|" + ".*set_eeg_reference()|.*set_channel_types()|" + ".*set_meas_date()|.*set_montage()|.*shift_time()|" + ".*time_as_index()|.*to_data_frame()|" # dictionary inherited - '.*clear()|.*fromkeys()|.*get()|.*items()|' - '.*keys()|.*pop()|.*popitem()|.*setdefault()|' - '.*values()|' + ".*clear()|.*fromkeys()|.*get()|.*items()|" + ".*keys()|.*pop()|.*popitem()|.*setdefault()|" + ".*values()|" # sklearn inherited - 
'.*apply()|.*decision_function()|.*fit()|' - '.*fit_transform()|.*get_params()|.*predict()|' - '.*predict_proba()|.*set_params()|.*transform()|' + ".*apply()|.*decision_function()|.*fit()|" + ".*fit_transform()|.*get_params()|.*predict()|" + ".*predict_proba()|.*set_params()|.*transform()|" # I/O, also related to mixins - '.*.remove.*|.*.write.*)'), - 'copyfile_regex': r'.*index\.rst', # allow custom index.rst files + ".*.remove.*|.*.write.*)" + ), + "copyfile_regex": r".*index\.rst", # allow custom index.rst files } # Files were renamed from plot_* with: # find . -type f -name 'plot_*.py' -exec sh -c 'x="{}"; xn=`basename "${x}"`; git mv "$x" `dirname "${x}"`/${xn:5}' \; # noqa @@ -506,9 +646,12 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # does not respect the autodoc templates that would otherwise insert # the .. include:: lines, so we need to do it. # Eventually this could perhaps live in SG. - if what in ('attribute', 'method'): - size = os.path.getsize(os.path.join( - os.path.dirname(__file__), 'generated', '%s.examples' % (name,))) + if what in ("attribute", "method"): + size = os.path.getsize( + os.path.join( + os.path.dirname(__file__), "generated", "%s.examples" % (name,) + ) + ) if size > 0: lines += """ .. _sphx_glr_backreferences_{1}: @@ -517,12 +660,16 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): .. minigallery:: {1} -""".format(name.split('.')[-1], name).split('\n') +""".format( + name.split(".")[-1], name + ).split( + "\n" + ) # -- Other extension configuration ------------------------------------------- -user_agent = 'Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Mobile Safari/537.36' # noqa: E501 +user_agent = "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Mobile Safari/537.36" # noqa: E501 # Can eventually add linkcheck_request_headers if needed linkcheck_ignore = [ # will be compiled to regex # 403 Client Error: Forbidden @@ -560,12 +707,12 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # autodoc / autosummary autosummary_generate = True -autodoc_default_options = {'inherited-members': None} +autodoc_default_options = {"inherited-members": None} # sphinxcontrib-bibtex -bibtex_bibfiles = ['./references.bib'] -bibtex_style = 'unsrt' -bibtex_footbibliography_header = '' +bibtex_bibfiles = ["./references.bib"] +bibtex_style = "unsrt" +bibtex_footbibliography_header = "" # -- Nitpicky ---------------------------------------------------------------- @@ -575,7 +722,10 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): ("py:class", "None. Remove all items from D."), ("py:class", "a set-like object providing a view on D's items"), ("py:class", "a set-like object providing a view on D's keys"), - ("py:class", "v, remove specified key and return the corresponding value."), # noqa: E501 + ( + "py:class", + "v, remove specified key and return the corresponding value.", + ), # noqa: E501 ("py:class", "None. 
Update D from dict/iterable E and F."), ("py:class", "an object providing a view on D's values"), ("py:class", "a shallow copy of D"), @@ -584,11 +734,14 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): ("py:class", "mne.utils._logging._FuncT"), ] nitpick_ignore_regex = [ - ('py:.*', r"mne\.io\.BaseRaw.*"), - ('py:.*', r"mne\.BaseEpochs.*"), - ('py:obj', "(filename|metadata|proj|times|tmax|tmin|annotations|ch_names|compensation_grade|filenames|first_samp|first_time|last_samp|n_times|proj|times|tmax|tmin)"), # noqa: E501 + ("py:.*", r"mne\.io\.BaseRaw.*"), + ("py:.*", r"mne\.BaseEpochs.*"), + ( + "py:obj", + "(filename|metadata|proj|times|tmax|tmin|annotations|ch_names|compensation_grade|filenames|first_samp|first_time|last_samp|n_times|proj|times|tmax|tmin)", + ), # noqa: E501 ] -suppress_warnings = ['image.nonlocal_uri'] # we intentionally link outside +suppress_warnings = ["image.nonlocal_uri"] # we intentionally link outside # -- Options for HTML output ------------------------------------------------- @@ -596,46 +749,56 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'pydata_sphinx_theme' +html_theme = "pydata_sphinx_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -switcher_version_match = 'dev' if release.endswith('dev0') else version +switcher_version_match = "dev" if release.endswith("dev0") else version html_theme_options = { - 'icon_links': [ - dict(name='GitHub', - url='https://github.com/mne-tools/mne-python', - icon='fa-brands fa-square-github'), - dict(name='Mastodon', - url='https://fosstodon.org/@mne', - icon='fa-brands fa-mastodon', - attributes=dict(rel='me')), - dict(name='Twitter', - url='https://twitter.com/mne_python', - icon='fa-brands fa-square-twitter'), - dict(name='Forum', - url='https://mne.discourse.group/', - icon='fa-brands fa-discourse'), - dict(name='Discord', - url='https://discord.gg/rKfvxTuATa', - icon='fa-brands fa-discord') + "icon_links": [ + dict( + name="GitHub", + url="https://github.com/mne-tools/mne-python", + icon="fa-brands fa-square-github", + ), + dict( + name="Mastodon", + url="https://fosstodon.org/@mne", + icon="fa-brands fa-mastodon", + attributes=dict(rel="me"), + ), + dict( + name="Twitter", + url="https://twitter.com/mne_python", + icon="fa-brands fa-square-twitter", + ), + dict( + name="Forum", + url="https://mne.discourse.group/", + icon="fa-brands fa-discourse", + ), + dict( + name="Discord", + url="https://discord.gg/rKfvxTuATa", + icon="fa-brands fa-discord", + ), ], - 'icon_links_label': 'External Links', # for screen reader - 'use_edit_page_button': False, - 'navigation_with_keys': False, - 'show_toc_level': 1, - 'navbar_end': ['theme-switcher', 'version-switcher', 'navbar-icon-links'], - 'footer_start': ['copyright'], - 'footer_end': [], - 'secondary_sidebar_items': ['page-toc'], - 'analytics': dict(google_analytics_id='G-5TBCPCRB6X'), - 'switcher': { - 'json_url': 'https://mne.tools/dev/_static/versions.json', - 'version_match': switcher_version_match, + "icon_links_label": "External Links", # for screen reader + "use_edit_page_button": False, + "navigation_with_keys": False, + "show_toc_level": 1, + "navbar_end": ["theme-switcher", "version-switcher", "navbar-icon-links"], + "footer_start": ["copyright"], + "footer_end": [], + 
"secondary_sidebar_items": ["page-toc"], + "analytics": dict(google_analytics_id="G-5TBCPCRB6X"), + "switcher": { + "json_url": "https://mne.tools/dev/_static/versions.json", + "version_match": switcher_version_match, }, - 'pygment_light_style': 'default', - 'pygment_dark_style': 'github-dark', + "pygment_light_style": "default", + "pygment_dark_style": "github-dark", } # The name of an image file (relative to this directory) to place at the top @@ -651,24 +814,24 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] html_css_files = [ - 'style.css', + "style.css", ] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. html_extra_path = [ - 'contributing.html', - 'documentation.html', - 'getting_started.html', - 'install_mne_python.html', + "contributing.html", + "documentation.html", + "getting_started.html", + "install_mne_python.html", ] # Custom sidebar templates, maps document names to template names. html_sidebars = { - 'index': ['sidebar-quicklinks.html'], + "index": ["sidebar-quicklinks.html"], } # If true, links to the reST sources are added to the pages. @@ -679,262 +842,346 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): html_show_sphinx = False # accommodate different logo shapes (width values in rem) -xs = '2' -sm = '2.5' -md = '3' -lg = '4.5' -xl = '5' -xxl = '6' +xs = "2" +sm = "2.5" +md = "3" +lg = "4.5" +xl = "5" +xxl = "6" # variables to pass to HTML templating engine html_context = { - 'default_mode': 'auto', - 'pygment_light_style': 'tango', - 'pygment_dark_style': 'native', - 'funders': [ - dict(img='nih.svg', size='3', title='National Institutes of Health'), - dict(img='nsf.png', size='3.5', - title='US National Science Foundation'), - dict(img='erc.svg', size='3.5', title='European Research Council', - klass='only-light'), - dict(img='erc-dark.svg', size='3.5', title='European Research Council', - klass='only-dark'), - dict(img='doe.svg', size='3', title='US Department of Energy'), - dict(img='anr.svg', size='3.5', - title='Agence Nationale de la Recherche'), - dict(img='cds.png', size='2.25', - title='Paris-Saclay Center for Data Science'), - dict(img='google.svg', size='2.25', title='Google'), - dict(img='amazon.svg', size='2.5', title='Amazon'), - dict(img='czi.svg', size='2.5', title='Chan Zuckerberg Initiative'), + "default_mode": "auto", + "pygment_light_style": "tango", + "pygment_dark_style": "native", + "funders": [ + dict(img="nih.svg", size="3", title="National Institutes of Health"), + dict(img="nsf.png", size="3.5", title="US National Science Foundation"), + dict( + img="erc.svg", + size="3.5", + title="European Research Council", + klass="only-light", + ), + dict( + img="erc-dark.svg", + size="3.5", + title="European Research Council", + klass="only-dark", + ), + dict(img="doe.svg", size="3", title="US Department of Energy"), + dict(img="anr.svg", size="3.5", title="Agence Nationale de la Recherche"), + dict(img="cds.png", size="2.25", title="Paris-Saclay Center for Data Science"), + dict(img="google.svg", size="2.25", title="Google"), + dict(img="amazon.svg", size="2.5", title="Amazon"), + 
dict(img="czi.svg", size="2.5", title="Chan Zuckerberg Initiative"), ], - 'institutions': [ - dict(name='Massachusetts General Hospital', - img='MGH.svg', - url='https://www.massgeneral.org/', - size=sm), - dict(name='Athinoula A. Martinos Center for Biomedical Imaging', - img='Martinos.png', - url='https://martinos.org/', - size=md), - dict(name='Harvard Medical School', - img='Harvard.png', - url='https://hms.harvard.edu/', - size=sm), - dict(name='Massachusetts Institute of Technology', - img='MIT.svg', - url='https://web.mit.edu/', - size=md), - dict(name='New York University', - img='NYU.svg', - url='https://www.nyu.edu/', - size=xs, - klass='only-light'), - dict(name='New York University', - img='NYU-dark.svg', - url='https://www.nyu.edu/', - size=xs, - klass='only-dark'), - dict(name='Commissariat à l´énergie atomique et aux énergies alternatives', # noqa E501 - img='CEA.png', - url='http://www.cea.fr/', - size=md), - dict(name='Aalto-yliopiston perustieteiden korkeakoulu', - img='Aalto.svg', - url='https://sci.aalto.fi/', - size=md, - klass='only-light'), - dict(name='Aalto-yliopiston perustieteiden korkeakoulu', - img='Aalto-dark.svg', - url='https://sci.aalto.fi/', - size=md, - klass='only-dark'), - dict(name='Télécom ParisTech', - img='Telecom_Paris_Tech.svg', - url='https://www.telecom-paris.fr/', - size=md), - dict(name='University of Washington', - img='Washington.svg', - url='https://www.washington.edu/', - size=md, - klass='only-light'), - dict(name='University of Washington', - img='Washington-dark.svg', - url='https://www.washington.edu/', - size=md, - klass='only-dark'), - dict(name='Institut du Cerveau et de la Moelle épinière', - img='ICM.jpg', - url='https://icm-institute.org/', - size=md), - dict(name='Boston University', - img='BU.svg', - url='https://www.bu.edu/', - size=lg), - dict(name='Institut national de la santé et de la recherche médicale', - img='Inserm.svg', - url='https://www.inserm.fr/', - size=xl, - klass='only-light'), - dict(name='Institut national de la santé et de la recherche médicale', - img='Inserm-dark.svg', - url='https://www.inserm.fr/', - size=xl, - klass='only-dark'), - dict(name='Forschungszentrum Jülich', - img='Julich.svg', - url='https://www.fz-juelich.de/', - size=xl, - klass='only-light'), - dict(name='Forschungszentrum Jülich', - img='Julich-dark.svg', - url='https://www.fz-juelich.de/', - size=xl, - klass='only-dark'), - dict(name='Technische Universität Ilmenau', - img='Ilmenau.svg', - url='https://www.tu-ilmenau.de/', - size=xxl, - klass='only-light'), - dict(name='Technische Universität Ilmenau', - img='Ilmenau-dark.svg', - url='https://www.tu-ilmenau.de/', - size=xxl, - klass='only-dark'), - dict(name='Berkeley Institute for Data Science', - img='BIDS.svg', - url='https://bids.berkeley.edu/', - size=lg, - klass='only-light'), - dict(name='Berkeley Institute for Data Science', - img='BIDS-dark.svg', - url='https://bids.berkeley.edu/', - size=lg, - klass='only-dark'), - dict(name='Institut national de recherche en informatique et en automatique', # noqa E501 - img='inria.png', - url='https://www.inria.fr/', - size=xl), - dict(name='Aarhus Universitet', - img='Aarhus.svg', - url='https://www.au.dk/', - size=xl, - klass='only-light'), - dict(name='Aarhus Universitet', - img='Aarhus-dark.svg', - url='https://www.au.dk/', - size=xl, - klass='only-dark'), - dict(name='Karl-Franzens-Universität Graz', - img='Graz.svg', - url='https://www.uni-graz.at/', - size=md), - dict(name='SWPS Uniwersytet Humanistycznospołeczny', - img='SWPS.svg', - 
url='https://www.swps.pl/', - size=xl, - klass='only-light'), - dict(name='SWPS Uniwersytet Humanistycznospołeczny', - img='SWPS-dark.svg', - url='https://www.swps.pl/', - size=xl, - klass='only-dark'), - dict(name='Max-Planck-Institut für Bildungsforschung', - img='MPIB.svg', - url='https://www.mpib-berlin.mpg.de/', - size=xxl, - klass='only-light'), - dict(name='Max-Planck-Institut für Bildungsforschung', - img='MPIB-dark.svg', - url='https://www.mpib-berlin.mpg.de/', - size=xxl, - klass='only-dark'), - dict(name='Macquarie University', - img='Macquarie.svg', - url='https://www.mq.edu.au/', - size=lg, - klass='only-light'), - dict(name='Macquarie University', - img='Macquarie-dark.svg', - url='https://www.mq.edu.au/', - size=lg, - klass='only-dark'), - dict(name='Children’s Hospital of Philadelphia Research Institute', - img='CHOP.svg', - url='https://www.research.chop.edu/imaging', - size=xxl, - klass='only-light'), - dict(name='Children’s Hospital of Philadelphia Research Institute', - img='CHOP-dark.svg', - url='https://www.research.chop.edu/imaging', - size=xxl, - klass='only-dark'), - dict(name='Donders Institute for Brain, Cognition and Behaviour at Radboud University', # noqa E501 - img='Donders.png', - url='https://www.ru.nl/donders/', - size=xl), + "institutions": [ + dict( + name="Massachusetts General Hospital", + img="MGH.svg", + url="https://www.massgeneral.org/", + size=sm, + ), + dict( + name="Athinoula A. Martinos Center for Biomedical Imaging", + img="Martinos.png", + url="https://martinos.org/", + size=md, + ), + dict( + name="Harvard Medical School", + img="Harvard.png", + url="https://hms.harvard.edu/", + size=sm, + ), + dict( + name="Massachusetts Institute of Technology", + img="MIT.svg", + url="https://web.mit.edu/", + size=md, + ), + dict( + name="New York University", + img="NYU.svg", + url="https://www.nyu.edu/", + size=xs, + klass="only-light", + ), + dict( + name="New York University", + img="NYU-dark.svg", + url="https://www.nyu.edu/", + size=xs, + klass="only-dark", + ), + dict( + name="Commissariat à l´énergie atomique et aux énergies alternatives", # noqa E501 + img="CEA.png", + url="http://www.cea.fr/", + size=md, + ), + dict( + name="Aalto-yliopiston perustieteiden korkeakoulu", + img="Aalto.svg", + url="https://sci.aalto.fi/", + size=md, + klass="only-light", + ), + dict( + name="Aalto-yliopiston perustieteiden korkeakoulu", + img="Aalto-dark.svg", + url="https://sci.aalto.fi/", + size=md, + klass="only-dark", + ), + dict( + name="Télécom ParisTech", + img="Telecom_Paris_Tech.svg", + url="https://www.telecom-paris.fr/", + size=md, + ), + dict( + name="University of Washington", + img="Washington.svg", + url="https://www.washington.edu/", + size=md, + klass="only-light", + ), + dict( + name="University of Washington", + img="Washington-dark.svg", + url="https://www.washington.edu/", + size=md, + klass="only-dark", + ), + dict( + name="Institut du Cerveau et de la Moelle épinière", + img="ICM.jpg", + url="https://icm-institute.org/", + size=md, + ), + dict( + name="Boston University", img="BU.svg", url="https://www.bu.edu/", size=lg + ), + dict( + name="Institut national de la santé et de la recherche médicale", + img="Inserm.svg", + url="https://www.inserm.fr/", + size=xl, + klass="only-light", + ), + dict( + name="Institut national de la santé et de la recherche médicale", + img="Inserm-dark.svg", + url="https://www.inserm.fr/", + size=xl, + klass="only-dark", + ), + dict( + name="Forschungszentrum Jülich", + img="Julich.svg", + 
url="https://www.fz-juelich.de/", + size=xl, + klass="only-light", + ), + dict( + name="Forschungszentrum Jülich", + img="Julich-dark.svg", + url="https://www.fz-juelich.de/", + size=xl, + klass="only-dark", + ), + dict( + name="Technische Universität Ilmenau", + img="Ilmenau.svg", + url="https://www.tu-ilmenau.de/", + size=xxl, + klass="only-light", + ), + dict( + name="Technische Universität Ilmenau", + img="Ilmenau-dark.svg", + url="https://www.tu-ilmenau.de/", + size=xxl, + klass="only-dark", + ), + dict( + name="Berkeley Institute for Data Science", + img="BIDS.svg", + url="https://bids.berkeley.edu/", + size=lg, + klass="only-light", + ), + dict( + name="Berkeley Institute for Data Science", + img="BIDS-dark.svg", + url="https://bids.berkeley.edu/", + size=lg, + klass="only-dark", + ), + dict( + name="Institut national de recherche en informatique et en automatique", # noqa E501 + img="inria.png", + url="https://www.inria.fr/", + size=xl, + ), + dict( + name="Aarhus Universitet", + img="Aarhus.svg", + url="https://www.au.dk/", + size=xl, + klass="only-light", + ), + dict( + name="Aarhus Universitet", + img="Aarhus-dark.svg", + url="https://www.au.dk/", + size=xl, + klass="only-dark", + ), + dict( + name="Karl-Franzens-Universität Graz", + img="Graz.svg", + url="https://www.uni-graz.at/", + size=md, + ), + dict( + name="SWPS Uniwersytet Humanistycznospołeczny", + img="SWPS.svg", + url="https://www.swps.pl/", + size=xl, + klass="only-light", + ), + dict( + name="SWPS Uniwersytet Humanistycznospołeczny", + img="SWPS-dark.svg", + url="https://www.swps.pl/", + size=xl, + klass="only-dark", + ), + dict( + name="Max-Planck-Institut für Bildungsforschung", + img="MPIB.svg", + url="https://www.mpib-berlin.mpg.de/", + size=xxl, + klass="only-light", + ), + dict( + name="Max-Planck-Institut für Bildungsforschung", + img="MPIB-dark.svg", + url="https://www.mpib-berlin.mpg.de/", + size=xxl, + klass="only-dark", + ), + dict( + name="Macquarie University", + img="Macquarie.svg", + url="https://www.mq.edu.au/", + size=lg, + klass="only-light", + ), + dict( + name="Macquarie University", + img="Macquarie-dark.svg", + url="https://www.mq.edu.au/", + size=lg, + klass="only-dark", + ), + dict( + name="Children’s Hospital of Philadelphia Research Institute", + img="CHOP.svg", + url="https://www.research.chop.edu/imaging", + size=xxl, + klass="only-light", + ), + dict( + name="Children’s Hospital of Philadelphia Research Institute", + img="CHOP-dark.svg", + url="https://www.research.chop.edu/imaging", + size=xxl, + klass="only-dark", + ), + dict( + name="Donders Institute for Brain, Cognition and Behaviour at Radboud University", # noqa E501 + img="Donders.png", + url="https://www.ru.nl/donders/", + size=xl, + ), ], # \u00AD is an optional hyphen (not rendered unless needed) # If these are changed, the Makefile should be updated, too - 'carousel': [ - dict(title='Source Estimation', - text='Distributed, sparse, mixed-norm, beam\u00ADformers, dipole fitting, and more.', # noqa E501 - url='auto_tutorials/inverse/index.html', - img='sphx_glr_30_mne_dspm_loreta_008.gif', - alt='dSPM'), - dict(title='Machine Learning', - text='Advanced decoding models including time general\u00ADiza\u00ADtion.', # noqa E501 - url='auto_tutorials/machine-learning/50_decoding.html', - img='sphx_glr_50_decoding_006.png', - alt='Decoding'), - dict(title='Encoding Models', - text='Receptive field estima\u00ADtion with optional smooth\u00ADness priors.', # noqa E501 - url='auto_tutorials/machine-learning/30_strf.html', - 
img='sphx_glr_30_strf_001.png', - alt='STRF'), - dict(title='Statistics', - text='Parametric and non-parametric, permutation tests and clustering.', # noqa E501 - url='auto_tutorials/stats-source-space/index.html', - img='sphx_glr_20_cluster_1samp_spatiotemporal_001.png', - alt='Clusters'), - dict(title='Connectivity', - text='All-to-all spectral and effective connec\u00ADtivity measures.', # noqa E501 - url='https://mne.tools/mne-connectivity/stable/auto_examples/mne_inverse_label_connectivity.html', # noqa E501 - img='https://mne.tools/mne-connectivity/stable/_images/sphx_glr_mne_inverse_label_connectivity_001.png', # noqa E501 - alt='Connectivity'), - dict(title='Data Visualization', - text='Explore your data from multiple perspectives.', - url='auto_tutorials/evoked/20_visualize_evoked.html', - img='sphx_glr_20_visualize_evoked_010.png', - alt='Visualization'), - ] + "carousel": [ + dict( + title="Source Estimation", + text="Distributed, sparse, mixed-norm, beam\u00ADformers, dipole fitting, and more.", # noqa E501 + url="auto_tutorials/inverse/index.html", + img="sphx_glr_30_mne_dspm_loreta_008.gif", + alt="dSPM", + ), + dict( + title="Machine Learning", + text="Advanced decoding models including time general\u00ADiza\u00ADtion.", # noqa E501 + url="auto_tutorials/machine-learning/50_decoding.html", + img="sphx_glr_50_decoding_006.png", + alt="Decoding", + ), + dict( + title="Encoding Models", + text="Receptive field estima\u00ADtion with optional smooth\u00ADness priors.", # noqa E501 + url="auto_tutorials/machine-learning/30_strf.html", + img="sphx_glr_30_strf_001.png", + alt="STRF", + ), + dict( + title="Statistics", + text="Parametric and non-parametric, permutation tests and clustering.", # noqa E501 + url="auto_tutorials/stats-source-space/index.html", + img="sphx_glr_20_cluster_1samp_spatiotemporal_001.png", + alt="Clusters", + ), + dict( + title="Connectivity", + text="All-to-all spectral and effective connec\u00ADtivity measures.", # noqa E501 + url="https://mne.tools/mne-connectivity/stable/auto_examples/mne_inverse_label_connectivity.html", # noqa E501 + img="https://mne.tools/mne-connectivity/stable/_images/sphx_glr_mne_inverse_label_connectivity_001.png", # noqa E501 + alt="Connectivity", + ), + dict( + title="Data Visualization", + text="Explore your data from multiple perspectives.", + url="auto_tutorials/evoked/20_visualize_evoked.html", + img="sphx_glr_20_visualize_evoked_010.png", + alt="Visualization", + ), + ], } # Output file base name for HTML help builder. 
-htmlhelp_basename = 'mne-doc' +htmlhelp_basename = "mne-doc" # -- Options for plot_directive ---------------------------------------------- # Adapted from SciPy plot_include_source = True -plot_formats = [('png', 96)] +plot_formats = [("png", 96)] plot_html_show_formats = False plot_html_show_source_link = False font_size = 13 * 72 / 96.0 # 13 px plot_rcparams = { - 'font.size': font_size, - 'axes.titlesize': font_size, - 'axes.labelsize': font_size, - 'xtick.labelsize': font_size, - 'ytick.labelsize': font_size, - 'legend.fontsize': font_size, - 'figure.figsize': (6, 5), - 'figure.subplot.bottom': 0.2, - 'figure.subplot.left': 0.2, - 'figure.subplot.right': 0.9, - 'figure.subplot.top': 0.85, - 'figure.subplot.wspace': 0.4, - 'text.usetex': False, + "font.size": font_size, + "axes.titlesize": font_size, + "axes.labelsize": font_size, + "xtick.labelsize": font_size, + "ytick.labelsize": font_size, + "legend.fontsize": font_size, + "figure.figsize": (6, 5), + "figure.subplot.bottom": 0.2, + "figure.subplot.left": 0.2, + "figure.subplot.right": 0.9, + "figure.subplot.top": 0.85, + "figure.subplot.wspace": 0.4, + "text.usetex": False, } @@ -951,13 +1198,14 @@ def append_attr_meth_examples(app, what, name, obj, options, lines): # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -latex_toplevel_sectioning = 'part' +latex_toplevel_sectioning = "part" _np_print_defaults = np.get_printoptions() # -- Warnings management ----------------------------------------------------- + def reset_warnings(gallery_conf, fname): """Ensure we are future compatible and ignore silly warnings.""" # In principle, our examples should produce no warnings. @@ -968,78 +1216,84 @@ def reset_warnings(gallery_conf, fname): # remove tweaks from other module imports or example runs warnings.resetwarnings() # restrict - warnings.filterwarnings('error') + warnings.filterwarnings("error") # allow these, but show them - warnings.filterwarnings('always', '.*non-standard config type: "foo".*') - warnings.filterwarnings('always', '.*config type: "MNEE_USE_CUUDAA".*') - warnings.filterwarnings('always', '.*cannot make axes width small.*') - warnings.filterwarnings('always', '.*Axes that are not compatible.*') - warnings.filterwarnings('always', '.*FastICA did not converge.*') + warnings.filterwarnings("always", '.*non-standard config type: "foo".*') + warnings.filterwarnings("always", '.*config type: "MNEE_USE_CUUDAA".*') + warnings.filterwarnings("always", ".*cannot make axes width small.*") + warnings.filterwarnings("always", ".*Axes that are not compatible.*") + warnings.filterwarnings("always", ".*FastICA did not converge.*") # ECoG BIDS spec violations: - warnings.filterwarnings('always', '.*Fiducial point nasion not found.*') - warnings.filterwarnings('always', '.*DigMontage is only a subset of.*') + warnings.filterwarnings("always", ".*Fiducial point nasion not found.*") + warnings.filterwarnings("always", ".*DigMontage is only a subset of.*") warnings.filterwarnings( # xhemi morph (should probably update sample) - 'always', '.*does not exist, creating it and saving it.*') + "always", ".*does not exist, creating it and saving it.*" + ) # internal warnings - warnings.filterwarnings('default', module='sphinx') + warnings.filterwarnings("default", module="sphinx") # allow these warnings, but don't show them for key in ( - 'The module matplotlib.tight_layout is deprecated', # nilearn - 'invalid version and will not be supported', # pyxdf - 'distutils Version classes are deprecated', # 
seaborn and neo - '`np.object` is a deprecated alias for the builtin `object`', # pyxdf + "The module matplotlib.tight_layout is deprecated", # nilearn + "invalid version and will not be supported", # pyxdf + "distutils Version classes are deprecated", # seaborn and neo + "`np.object` is a deprecated alias for the builtin `object`", # pyxdf # nilearn, should be fixed in > 0.9.1 - 'In future, it will be an error for \'np.bool_\' scalars to', + "In future, it will be an error for 'np.bool_' scalars to", # sklearn hasn't updated to SciPy's sym_pos dep - 'The \'sym_pos\' keyword is deprecated', + "The 'sym_pos' keyword is deprecated", # numba - '`np.MachAr` is deprecated', + "`np.MachAr` is deprecated", # joblib hasn't updated to avoid distutils - 'distutils package is deprecated', + "distutils package is deprecated", # jupyter - 'Jupyter is migrating its paths to use standard', - r'Widget\..* is deprecated\.', + "Jupyter is migrating its paths to use standard", + r"Widget\..* is deprecated\.", # PyQt6 - 'Enum value .* is marked as deprecated', + "Enum value .* is marked as deprecated", # matplotlib PDF output - 'The py23 module has been deprecated', + "The py23 module has been deprecated", # pkg_resources - 'Implementing implicit namespace packages', - 'Deprecated call to `pkg_resources', + "Implementing implicit namespace packages", + "Deprecated call to `pkg_resources", # nilearn - 'pkg_resources is deprecated as an API', - r'The .* was deprecated in Matplotlib 3\.7', + "pkg_resources is deprecated as an API", + r"The .* was deprecated in Matplotlib 3\.7", ): warnings.filterwarnings( # deal with other modules having bad imports - 'ignore', message=".*%s.*" % key, category=DeprecationWarning) - warnings.filterwarnings( - 'ignore', message=( - 'Matplotlib is currently using agg, which is a non-GUI backend.*' + "ignore", message=".*%s.*" % key, category=DeprecationWarning ) + warnings.filterwarnings( + "ignore", + message=("Matplotlib is currently using agg, which is a non-GUI backend.*"), ) # matplotlib 3.6 in nilearn and pyvista - warnings.filterwarnings( - 'ignore', message='.*cmap function will be deprecated.*') + warnings.filterwarnings("ignore", message=".*cmap function will be deprecated.*") # xarray/netcdf4 warnings.filterwarnings( - 'ignore', message=r'numpy\.ndarray size changed, may indicate.*', - category=RuntimeWarning) + "ignore", + message=r"numpy\.ndarray size changed, may indicate.*", + category=RuntimeWarning, + ) # qdarkstyle warnings.filterwarnings( - 'ignore', message=r'.*Setting theme=.*6 in qdarkstyle.*', - category=RuntimeWarning) + "ignore", + message=r".*Setting theme=.*6 in qdarkstyle.*", + category=RuntimeWarning, + ) # pandas, via seaborn (examples/time_frequency/time_frequency_erds.py) warnings.filterwarnings( - 'ignore', message=r'iteritems is deprecated.*Use \.items instead\.', - category=FutureWarning) + "ignore", + message=r"iteritems is deprecated.*Use \.items instead\.", + category=FutureWarning, + ) # pandas in 50_epochs_to_data_frame.py warnings.filterwarnings( - 'ignore', message=r'invalid value encountered in cast', - category=RuntimeWarning) + "ignore", message=r"invalid value encountered in cast", category=RuntimeWarning + ) # xarray _SixMetaPathImporter (?) 
     warnings.filterwarnings(
-        'ignore', message=r'falling back to find_module',
-        category=ImportWarning)
+        "ignore", message=r"falling back to find_module", category=ImportWarning
+    )
 
     # In case we use np.set_printoptions in any tutorials, we only
     # want it to affect those:
@@ -1051,49 +1305,70 @@
 
 # -- Fontawesome support -----------------------------------------------------
 
-brand_icons = ('apple', 'linux', 'windows', 'discourse', 'python')
+brand_icons = ("apple", "linux", "windows", "discourse", "python")
 fixed_width_icons = (
     # homepage:
-    'book', 'code-branch', 'newspaper', 'circle-question', 'quote-left',
+    "book",
+    "code-branch",
+    "newspaper",
+    "circle-question",
+    "quote-left",
    # contrib guide:
-    'bug-slash', 'comment', 'computer-mouse', 'hand-sparkles', 'pencil',
-    'text-slash', 'universal-access', 'wand-magic-sparkles',
-    'discourse', 'python',
+    "bug-slash",
+    "comment",
+    "computer-mouse",
+    "hand-sparkles",
+    "pencil",
+    "text-slash",
+    "universal-access",
+    "wand-magic-sparkles",
+    "discourse",
+    "python",
 )
 other_icons = (
-    'hand-paper', 'question', 'rocket', 'server', 'code', 'desktop',
-    'terminal', 'cloud-arrow-down', 'wrench', 'hourglass-half'
+    "hand-paper",
+    "question",
+    "rocket",
+    "server",
+    "code",
+    "desktop",
+    "terminal",
+    "cloud-arrow-down",
+    "wrench",
+    "hourglass-half",
 )
 icon_class = dict()
 for icon in brand_icons + fixed_width_icons + other_icons:
-    icon_class[icon] = ('fa-brands',) if icon in brand_icons else ('fa-solid',)
-    icon_class[icon] += ('fa-fw',) if icon in fixed_width_icons else ()
+    icon_class[icon] = ("fa-brands",) if icon in brand_icons else ("fa-solid",)
+    icon_class[icon] += ("fa-fw",) if icon in fixed_width_icons else ()
 
-rst_prolog = ''
+rst_prolog = ""
 for icon, classes in icon_class.items():
-    rst_prolog += f'''
+    rst_prolog += f"""
 .. |{icon}| raw:: html
 
     <i class="{' '.join(classes)} fa-{icon}"></i>
-'''
+"""
 
-rst_prolog += '''
+rst_prolog += """
 .. |ensp| unicode:: U+2002 .. EN SPACE
-'''
+"""
 
 # -- Dependency info ----------------------------------------------------------
 
 try:
     from importlib.metadata import metadata  # new in Python 3.8
-    min_py = metadata('mne')['Requires-Python']
+
+    min_py = metadata("mne")["Requires-Python"]
 except ModuleNotFoundError:
     from pkg_resources import get_distribution
-    info = get_distribution('mne').get_metadata_lines('PKG-INFO')
+
+    info = get_distribution("mne").get_metadata_lines("PKG-INFO")
     for line in info:
-        if line.strip().startswith('Requires-Python'):
-            min_py = line.split(':')[1]
-min_py = min_py.lstrip(' =<>')
-rst_prolog += f'\n.. |min_python_version| replace:: {min_py}\n'
+        if line.strip().startswith("Requires-Python"):
+            min_py = line.split(":")[1]
+min_py = min_py.lstrip(" =<>")
+rst_prolog += f"\n.. |min_python_version| replace:: {min_py}\n"
 
 # -- website redirects --------------------------------------------------------
 
@@ -1101,141 +1376,214 @@ def reset_warnings(gallery_conf, fname):
 # since we don't need to add redirects for examples added after this date.
needed_plot_redirects = { # tutorials - '10_epochs_overview.py', '10_evoked_overview.py', '10_overview.py', - '10_preprocessing_overview.py', '10_raw_overview.py', - '10_reading_meg_data.py', '15_handling_bad_channels.py', - '20_event_arrays.py', '20_events_from_raw.py', '20_reading_eeg_data.py', - '20_rejecting_bad_data.py', '20_visualize_epochs.py', - '20_visualize_evoked.py', '30_annotate_raw.py', '30_epochs_metadata.py', - '30_filtering_resampling.py', '30_info.py', '30_reading_fnirs_data.py', - '35_artifact_correction_regression.py', '40_artifact_correction_ica.py', - '40_autogenerate_metadata.py', '40_sensor_locations.py', - '40_visualize_raw.py', '45_projectors_background.py', - '50_artifact_correction_ssp.py', '50_configure_mne.py', - '50_epochs_to_data_frame.py', '55_setting_eeg_reference.py', - '59_head_positions.py', '60_make_fixed_length_epochs.py', - '60_maxwell_filtering_sss.py', '70_fnirs_processing.py', + "10_epochs_overview.py", + "10_evoked_overview.py", + "10_overview.py", + "10_preprocessing_overview.py", + "10_raw_overview.py", + "10_reading_meg_data.py", + "15_handling_bad_channels.py", + "20_event_arrays.py", + "20_events_from_raw.py", + "20_reading_eeg_data.py", + "20_rejecting_bad_data.py", + "20_visualize_epochs.py", + "20_visualize_evoked.py", + "30_annotate_raw.py", + "30_epochs_metadata.py", + "30_filtering_resampling.py", + "30_info.py", + "30_reading_fnirs_data.py", + "35_artifact_correction_regression.py", + "40_artifact_correction_ica.py", + "40_autogenerate_metadata.py", + "40_sensor_locations.py", + "40_visualize_raw.py", + "45_projectors_background.py", + "50_artifact_correction_ssp.py", + "50_configure_mne.py", + "50_epochs_to_data_frame.py", + "55_setting_eeg_reference.py", + "59_head_positions.py", + "60_make_fixed_length_epochs.py", + "60_maxwell_filtering_sss.py", + "70_fnirs_processing.py", # examples - '3d_to_2d.py', 'brainstorm_data.py', 'channel_epochs_image.py', - 'cluster_stats_evoked.py', 'compute_csd.py', - 'compute_mne_inverse_epochs_in_label.py', - 'compute_mne_inverse_raw_in_label.py', 'compute_mne_inverse_volume.py', - 'compute_source_psd_epochs.py', 'covariance_whitening_dspm.py', - 'custom_inverse_solver.py', - 'decoding_csp_eeg.py', 'decoding_csp_timefreq.py', - 'decoding_spatio_temporal_source.py', 'decoding_spoc_CMC.py', - 'decoding_time_generalization_conditions.py', - 'decoding_unsupervised_spatial_filter.py', 'decoding_xdawn_eeg.py', - 'define_target_events.py', 'dics_source_power.py', 'eeg_csd.py', - 'eeg_on_scalp.py', 'eeglab_head_sphere.py', 'elekta_epochs.py', - 'ems_filtering.py', 'eog_artifact_histogram.py', 'evoked_arrowmap.py', - 'evoked_ers_source_power.py', 'evoked_topomap.py', 'evoked_whitening.py', - 'fdr_stats_evoked.py', 'find_ref_artifacts.py', - 'fnirs_artifact_removal.py', 'forward_sensitivity_maps.py', - 'gamma_map_inverse.py', 'hf_sef_data.py', 'ica_comparison.py', - 'interpolate_bad_channels.py', 'label_activation_from_stc.py', - 'label_from_stc.py', 'label_source_activations.py', - 'left_cerebellum_volume_source.py', 'limo_data.py', - 'linear_model_patterns.py', 'linear_regression_raw.py', - 'meg_sensors.py', 'mixed_norm_inverse.py', - 'mixed_source_space_inverse.py', - 'mne_cov_power.py', 'mne_helmet.py', 'mne_inverse_coherence_epochs.py', - 'mne_inverse_envelope_correlation.py', - 'mne_inverse_envelope_correlation_volume.py', - 'mne_inverse_psi_visual.py', - 'morph_surface_stc.py', 'morph_volume_stc.py', 'movement_compensation.py', - 'movement_detection.py', 'multidict_reweighted_tfmxne.py', - 
'muscle_detection.py', 'opm_data.py', 'otp.py', 'parcellation.py', - 'psf_ctf_label_leakage.py', 'psf_ctf_vertices.py', - 'psf_ctf_vertices_lcmv.py', 'publication_figure.py', 'rap_music.py', - 'read_inverse.py', 'read_neo_format.py', 'read_noise_covariance_matrix.py', - 'read_stc.py', 'receptive_field_mtrf.py', 'resolution_metrics.py', - 'resolution_metrics_eegmeg.py', 'roi_erpimage_by_rt.py', - 'sensor_noise_level.py', - 'sensor_permutation_test.py', 'sensor_regression.py', - 'shift_evoked.py', 'simulate_evoked_data.py', 'simulate_raw_data.py', - 'simulated_raw_data_using_subject_anatomy.py', 'snr_estimate.py', - 'source_label_time_frequency.py', 'source_power_spectrum.py', - 'source_power_spectrum_opm.py', 'source_simulator.py', - 'source_space_morphing.py', 'source_space_snr.py', - 'source_space_time_frequency.py', 'ssd_spatial_filters.py', - 'ssp_projs_sensitivity_map.py', 'temporal_whitening.py', - 'time_frequency_erds.py', 'time_frequency_global_field_power.py', - 'time_frequency_mixed_norm_inverse.py', 'time_frequency_simulated.py', - 'topo_compare_conditions.py', 'topo_customized.py', - 'vector_mne_solution.py', 'virtual_evoked.py', 'xdawn_denoising.py', - 'xhemi.py', + "3d_to_2d.py", + "brainstorm_data.py", + "channel_epochs_image.py", + "cluster_stats_evoked.py", + "compute_csd.py", + "compute_mne_inverse_epochs_in_label.py", + "compute_mne_inverse_raw_in_label.py", + "compute_mne_inverse_volume.py", + "compute_source_psd_epochs.py", + "covariance_whitening_dspm.py", + "custom_inverse_solver.py", + "decoding_csp_eeg.py", + "decoding_csp_timefreq.py", + "decoding_spatio_temporal_source.py", + "decoding_spoc_CMC.py", + "decoding_time_generalization_conditions.py", + "decoding_unsupervised_spatial_filter.py", + "decoding_xdawn_eeg.py", + "define_target_events.py", + "dics_source_power.py", + "eeg_csd.py", + "eeg_on_scalp.py", + "eeglab_head_sphere.py", + "elekta_epochs.py", + "ems_filtering.py", + "eog_artifact_histogram.py", + "evoked_arrowmap.py", + "evoked_ers_source_power.py", + "evoked_topomap.py", + "evoked_whitening.py", + "fdr_stats_evoked.py", + "find_ref_artifacts.py", + "fnirs_artifact_removal.py", + "forward_sensitivity_maps.py", + "gamma_map_inverse.py", + "hf_sef_data.py", + "ica_comparison.py", + "interpolate_bad_channels.py", + "label_activation_from_stc.py", + "label_from_stc.py", + "label_source_activations.py", + "left_cerebellum_volume_source.py", + "limo_data.py", + "linear_model_patterns.py", + "linear_regression_raw.py", + "meg_sensors.py", + "mixed_norm_inverse.py", + "mixed_source_space_inverse.py", + "mne_cov_power.py", + "mne_helmet.py", + "mne_inverse_coherence_epochs.py", + "mne_inverse_envelope_correlation.py", + "mne_inverse_envelope_correlation_volume.py", + "mne_inverse_psi_visual.py", + "morph_surface_stc.py", + "morph_volume_stc.py", + "movement_compensation.py", + "movement_detection.py", + "multidict_reweighted_tfmxne.py", + "muscle_detection.py", + "opm_data.py", + "otp.py", + "parcellation.py", + "psf_ctf_label_leakage.py", + "psf_ctf_vertices.py", + "psf_ctf_vertices_lcmv.py", + "publication_figure.py", + "rap_music.py", + "read_inverse.py", + "read_neo_format.py", + "read_noise_covariance_matrix.py", + "read_stc.py", + "receptive_field_mtrf.py", + "resolution_metrics.py", + "resolution_metrics_eegmeg.py", + "roi_erpimage_by_rt.py", + "sensor_noise_level.py", + "sensor_permutation_test.py", + "sensor_regression.py", + "shift_evoked.py", + "simulate_evoked_data.py", + "simulate_raw_data.py", + "simulated_raw_data_using_subject_anatomy.py", + 
"snr_estimate.py", + "source_label_time_frequency.py", + "source_power_spectrum.py", + "source_power_spectrum_opm.py", + "source_simulator.py", + "source_space_morphing.py", + "source_space_snr.py", + "source_space_time_frequency.py", + "ssd_spatial_filters.py", + "ssp_projs_sensitivity_map.py", + "temporal_whitening.py", + "time_frequency_erds.py", + "time_frequency_global_field_power.py", + "time_frequency_mixed_norm_inverse.py", + "time_frequency_simulated.py", + "topo_compare_conditions.py", + "topo_customized.py", + "vector_mne_solution.py", + "virtual_evoked.py", + "xdawn_denoising.py", + "xhemi.py", } -ex = 'auto_examples' -co = 'connectivity' -mne_conn = 'https://mne.tools/mne-connectivity/stable' -tu = 'auto_tutorials' -di = 'discussions' -sm = 'source-modeling' -fw = 'forward' -nv = 'inverse' -sn = 'stats-sensor-space' -sr = 'stats-source-space' -sd = 'sample-datasets' -ml = 'machine-learning' -tf = 'time-freq' -si = 'simulation' +ex = "auto_examples" +co = "connectivity" +mne_conn = "https://mne.tools/mne-connectivity/stable" +tu = "auto_tutorials" +di = "discussions" +sm = "source-modeling" +fw = "forward" +nv = "inverse" +sn = "stats-sensor-space" +sr = "stats-source-space" +sd = "sample-datasets" +ml = "machine-learning" +tf = "time-freq" +si = "simulation" custom_redirects = { # Custom redirects (one HTML path to another, relative to outdir) # can be added here as fr->to key->value mappings - f'{tu}/evoked/plot_eeg_erp.html': f'{tu}/evoked/30_eeg_erp.html', - f'{tu}/evoked/plot_whitened.html': f'{tu}/evoked/40_whitened.html', - f'{tu}/misc/plot_modifying_data_inplace.html': f'{tu}/intro/15_inplace.html', # noqa E501 - f'{tu}/misc/plot_report.html': f'{tu}/intro/70_report.html', - f'{tu}/misc/plot_seeg.html': f'{tu}/clinical/20_seeg.html', - f'{tu}/misc/plot_ecog.html': f'{tu}/clinical/30_ecog.html', - f'{tu}/{ml}/plot_receptive_field.html': f'{tu}/{ml}/30_strf.html', - f'{tu}/{ml}/plot_sensors_decoding.html': f'{tu}/{ml}/50_decoding.html', - f'{tu}/{sm}/plot_background_freesurfer.html': f'{tu}/{fw}/10_background_freesurfer.html', # noqa E501 - f'{tu}/{sm}/plot_source_alignment.html': f'{tu}/{fw}/20_source_alignment.html', # noqa E501 - f'{tu}/{sm}/plot_forward.html': f'{tu}/{fw}/30_forward.html', - f'{tu}/{sm}/plot_eeg_no_mri.html': f'{tu}/{fw}/35_eeg_no_mri.html', - f'{tu}/{sm}/plot_background_freesurfer_mne.html': f'{tu}/{fw}/50_background_freesurfer_mne.html', # noqa E501 - f'{tu}/{sm}/plot_fix_bem_in_blender.html': f'{tu}/{fw}/80_fix_bem_in_blender.html', # noqa E501 - f'{tu}/{sm}/plot_compute_covariance.html': f'{tu}/{fw}/90_compute_covariance.html', # noqa E501 - f'{tu}/{sm}/plot_object_source_estimate.html': f'{tu}/{nv}/10_stc_class.html', # noqa E501 - f'{tu}/{sm}/plot_dipole_fit.html': f'{tu}/{nv}/20_dipole_fit.html', - f'{tu}/{sm}/plot_mne_dspm_source_localization.html': f'{tu}/{nv}/30_mne_dspm_loreta.html', # noqa E501 - f'{tu}/{sm}/plot_dipole_orientations.html': f'{tu}/{nv}/35_dipole_orientations.html', # noqa E501 - f'{tu}/{sm}/plot_mne_solutions.html': f'{tu}/{nv}/40_mne_fixed_free.html', - f'{tu}/{sm}/plot_beamformer_lcmv.html': f'{tu}/{nv}/50_beamformer_lcmv.html', # noqa E501 - f'{tu}/{sm}/plot_visualize_stc.html': f'{tu}/{nv}/60_visualize_stc.html', - f'{tu}/{sm}/plot_eeg_mri_coords.html': f'{tu}/{nv}/70_eeg_mri_coords.html', - f'{tu}/{sd}/plot_brainstorm_phantom_elekta.html': f'{tu}/{nv}/80_brainstorm_phantom_elekta.html', # noqa E501 - f'{tu}/{sd}/plot_brainstorm_phantom_ctf.html': f'{tu}/{nv}/85_brainstorm_phantom_ctf.html', # noqa E501 - 
f'{tu}/{sd}/plot_phantom_4DBTi.html': f'{tu}/{nv}/90_phantom_4DBTi.html', - f'{tu}/{sd}/plot_brainstorm_auditory.html': f'{tu}/io/60_ctf_bst_auditory.html', # noqa E501 - f'{tu}/{sd}/plot_sleep.html': f'{tu}/clinical/60_sleep.html', - f'{tu}/{di}/plot_background_filtering.html': f'{tu}/preprocessing/25_background_filtering.html', # noqa E501 - f'{tu}/{di}/plot_background_statistics.html': f'{tu}/{sn}/10_background_stats.html', # noqa E501 - f'{tu}/{sn}/plot_stats_cluster_erp.html': f'{tu}/{sn}/20_erp_stats.html', - f'{tu}/{sn}/plot_stats_cluster_1samp_test_time_frequency.html': f'{tu}/{sn}/40_cluster_1samp_time_freq.html', # noqa E501 - f'{tu}/{sn}/plot_stats_cluster_time_frequency.html': f'{tu}/{sn}/50_cluster_between_time_freq.html', # noqa E501 - f'{tu}/{sn}/plot_stats_spatio_temporal_cluster_sensors.html': f'{tu}/{sn}/75_cluster_ftest_spatiotemporal.html', # noqa E501 - f'{tu}/{sr}/plot_stats_cluster_spatio_temporal.html': f'{tu}/{sr}/20_cluster_1samp_spatiotemporal.html', # noqa E501 - f'{tu}/{sr}/plot_stats_cluster_spatio_temporal_2samp.html': f'{tu}/{sr}/30_cluster_ftest_spatiotemporal.html', # noqa E501 - f'{tu}/{sr}/plot_stats_cluster_spatio_temporal_repeated_measures_anova.html': f'{tu}/{sr}/60_cluster_rmANOVA_spatiotemporal.html', # noqa E501 - f'{tu}/{sr}/plot_stats_cluster_time_frequency_repeated_measures_anova.html': f'{tu}/{sn}/70_cluster_rmANOVA_time_freq.html', # noqa E501 - f'{tu}/{tf}/plot_sensors_time_frequency.html': f'{tu}/{tf}/20_sensors_time_frequency.html', # noqa E501 - f'{tu}/{tf}/plot_ssvep.html': f'{tu}/{tf}/50_ssvep.html', - f'{tu}/{si}/plot_creating_data_structures.html': f'{tu}/{si}/10_array_objs.html', # noqa E501 - f'{tu}/{si}/plot_point_spread.html': f'{tu}/{si}/70_point_spread.html', - f'{tu}/{si}/plot_dics.html': f'{tu}/{si}/80_dics.html', - f'{tu}/{tf}/plot_eyetracking.html': f'{tu}/preprocessing/90_eyetracking_data.html', # noqa E501 - f'{ex}/{co}/mne_inverse_label_connectivity.html': f'{mne_conn}/{ex}/mne_inverse_label_connectivity.html', # noqa E501 - f'{ex}/{co}/cwt_sensor_connectivity.html': f'{mne_conn}/{ex}/cwt_sensor_connectivity.html', # noqa E501 - f'{ex}/{co}/mixed_source_space_connectivity.html': f'{mne_conn}/{ex}/mixed_source_space_connectivity.html', # noqa E501 - f'{ex}/{co}/mne_inverse_coherence_epochs.html': f'{mne_conn}/{ex}/mne_inverse_coherence_epochs.html', # noqa E501 - f'{ex}/{co}/mne_inverse_connectivity_spectrum.html': f'{mne_conn}/{ex}/mne_inverse_connectivity_spectrum.html', # noqa E501 - f'{ex}/{co}/mne_inverse_envelope_correlation_volume.html': f'{mne_conn}/{ex}/mne_inverse_envelope_correlation_volume.html', # noqa E501 - f'{ex}/{co}/mne_inverse_envelope_correlation.html': f'{mne_conn}/{ex}/mne_inverse_envelope_correlation.html', # noqa E501 - f'{ex}/{co}/mne_inverse_psi_visual.html': f'{mne_conn}/{ex}/mne_inverse_psi_visual.html', # noqa E501 - f'{ex}/{co}/sensor_connectivity.html': f'{mne_conn}/{ex}/sensor_connectivity.html', # noqa E501 + f"{tu}/evoked/plot_eeg_erp.html": f"{tu}/evoked/30_eeg_erp.html", + f"{tu}/evoked/plot_whitened.html": f"{tu}/evoked/40_whitened.html", + f"{tu}/misc/plot_modifying_data_inplace.html": f"{tu}/intro/15_inplace.html", # noqa E501 + f"{tu}/misc/plot_report.html": f"{tu}/intro/70_report.html", + f"{tu}/misc/plot_seeg.html": f"{tu}/clinical/20_seeg.html", + f"{tu}/misc/plot_ecog.html": f"{tu}/clinical/30_ecog.html", + f"{tu}/{ml}/plot_receptive_field.html": f"{tu}/{ml}/30_strf.html", + f"{tu}/{ml}/plot_sensors_decoding.html": f"{tu}/{ml}/50_decoding.html", + 
f"{tu}/{sm}/plot_background_freesurfer.html": f"{tu}/{fw}/10_background_freesurfer.html", # noqa E501 + f"{tu}/{sm}/plot_source_alignment.html": f"{tu}/{fw}/20_source_alignment.html", # noqa E501 + f"{tu}/{sm}/plot_forward.html": f"{tu}/{fw}/30_forward.html", + f"{tu}/{sm}/plot_eeg_no_mri.html": f"{tu}/{fw}/35_eeg_no_mri.html", + f"{tu}/{sm}/plot_background_freesurfer_mne.html": f"{tu}/{fw}/50_background_freesurfer_mne.html", # noqa E501 + f"{tu}/{sm}/plot_fix_bem_in_blender.html": f"{tu}/{fw}/80_fix_bem_in_blender.html", # noqa E501 + f"{tu}/{sm}/plot_compute_covariance.html": f"{tu}/{fw}/90_compute_covariance.html", # noqa E501 + f"{tu}/{sm}/plot_object_source_estimate.html": f"{tu}/{nv}/10_stc_class.html", # noqa E501 + f"{tu}/{sm}/plot_dipole_fit.html": f"{tu}/{nv}/20_dipole_fit.html", + f"{tu}/{sm}/plot_mne_dspm_source_localization.html": f"{tu}/{nv}/30_mne_dspm_loreta.html", # noqa E501 + f"{tu}/{sm}/plot_dipole_orientations.html": f"{tu}/{nv}/35_dipole_orientations.html", # noqa E501 + f"{tu}/{sm}/plot_mne_solutions.html": f"{tu}/{nv}/40_mne_fixed_free.html", + f"{tu}/{sm}/plot_beamformer_lcmv.html": f"{tu}/{nv}/50_beamformer_lcmv.html", # noqa E501 + f"{tu}/{sm}/plot_visualize_stc.html": f"{tu}/{nv}/60_visualize_stc.html", + f"{tu}/{sm}/plot_eeg_mri_coords.html": f"{tu}/{nv}/70_eeg_mri_coords.html", + f"{tu}/{sd}/plot_brainstorm_phantom_elekta.html": f"{tu}/{nv}/80_brainstorm_phantom_elekta.html", # noqa E501 + f"{tu}/{sd}/plot_brainstorm_phantom_ctf.html": f"{tu}/{nv}/85_brainstorm_phantom_ctf.html", # noqa E501 + f"{tu}/{sd}/plot_phantom_4DBTi.html": f"{tu}/{nv}/90_phantom_4DBTi.html", + f"{tu}/{sd}/plot_brainstorm_auditory.html": f"{tu}/io/60_ctf_bst_auditory.html", # noqa E501 + f"{tu}/{sd}/plot_sleep.html": f"{tu}/clinical/60_sleep.html", + f"{tu}/{di}/plot_background_filtering.html": f"{tu}/preprocessing/25_background_filtering.html", # noqa E501 + f"{tu}/{di}/plot_background_statistics.html": f"{tu}/{sn}/10_background_stats.html", # noqa E501 + f"{tu}/{sn}/plot_stats_cluster_erp.html": f"{tu}/{sn}/20_erp_stats.html", + f"{tu}/{sn}/plot_stats_cluster_1samp_test_time_frequency.html": f"{tu}/{sn}/40_cluster_1samp_time_freq.html", # noqa E501 + f"{tu}/{sn}/plot_stats_cluster_time_frequency.html": f"{tu}/{sn}/50_cluster_between_time_freq.html", # noqa E501 + f"{tu}/{sn}/plot_stats_spatio_temporal_cluster_sensors.html": f"{tu}/{sn}/75_cluster_ftest_spatiotemporal.html", # noqa E501 + f"{tu}/{sr}/plot_stats_cluster_spatio_temporal.html": f"{tu}/{sr}/20_cluster_1samp_spatiotemporal.html", # noqa E501 + f"{tu}/{sr}/plot_stats_cluster_spatio_temporal_2samp.html": f"{tu}/{sr}/30_cluster_ftest_spatiotemporal.html", # noqa E501 + f"{tu}/{sr}/plot_stats_cluster_spatio_temporal_repeated_measures_anova.html": f"{tu}/{sr}/60_cluster_rmANOVA_spatiotemporal.html", # noqa E501 + f"{tu}/{sr}/plot_stats_cluster_time_frequency_repeated_measures_anova.html": f"{tu}/{sn}/70_cluster_rmANOVA_time_freq.html", # noqa E501 + f"{tu}/{tf}/plot_sensors_time_frequency.html": f"{tu}/{tf}/20_sensors_time_frequency.html", # noqa E501 + f"{tu}/{tf}/plot_ssvep.html": f"{tu}/{tf}/50_ssvep.html", + f"{tu}/{si}/plot_creating_data_structures.html": f"{tu}/{si}/10_array_objs.html", # noqa E501 + f"{tu}/{si}/plot_point_spread.html": f"{tu}/{si}/70_point_spread.html", + f"{tu}/{si}/plot_dics.html": f"{tu}/{si}/80_dics.html", + f"{tu}/{tf}/plot_eyetracking.html": f"{tu}/preprocessing/90_eyetracking_data.html", # noqa E501 + f"{ex}/{co}/mne_inverse_label_connectivity.html": 
f"{mne_conn}/{ex}/mne_inverse_label_connectivity.html", # noqa E501 + f"{ex}/{co}/cwt_sensor_connectivity.html": f"{mne_conn}/{ex}/cwt_sensor_connectivity.html", # noqa E501 + f"{ex}/{co}/mixed_source_space_connectivity.html": f"{mne_conn}/{ex}/mixed_source_space_connectivity.html", # noqa E501 + f"{ex}/{co}/mne_inverse_coherence_epochs.html": f"{mne_conn}/{ex}/mne_inverse_coherence_epochs.html", # noqa E501 + f"{ex}/{co}/mne_inverse_connectivity_spectrum.html": f"{mne_conn}/{ex}/mne_inverse_connectivity_spectrum.html", # noqa E501 + f"{ex}/{co}/mne_inverse_envelope_correlation_volume.html": f"{mne_conn}/{ex}/mne_inverse_envelope_correlation_volume.html", # noqa E501 + f"{ex}/{co}/mne_inverse_envelope_correlation.html": f"{mne_conn}/{ex}/mne_inverse_envelope_correlation.html", # noqa E501 + f"{ex}/{co}/mne_inverse_psi_visual.html": f"{mne_conn}/{ex}/mne_inverse_psi_visual.html", # noqa E501 + f"{ex}/{co}/sensor_connectivity.html": f"{mne_conn}/{ex}/sensor_connectivity.html", # noqa E501 } @@ -1243,11 +1591,12 @@ def make_redirects(app, exception): """Make HTML redirects.""" # https://www.sphinx-doc.org/en/master/extdev/appapi.html # Adapted from sphinxcontrib/redirects (BSD-2-Clause) - if not (isinstance(app.builder, - sphinx.builders.html.StandaloneHTMLBuilder) and - exception is None): + if not ( + isinstance(app.builder, sphinx.builders.html.StandaloneHTMLBuilder) + and exception is None + ): return - logger = sphinx.util.logging.getLogger('mne') + logger = sphinx.util.logging.getLogger("mne") TEMPLATE = """\ @@ -1263,79 +1612,88 @@ def make_redirects(app, exception): If you are not redirected automatically, follow this link. """ # noqa: E501 - sphinx_gallery_conf = app.config['sphinx_gallery_conf'] - for src_dir, out_dir in zip(sphinx_gallery_conf['examples_dirs'], - sphinx_gallery_conf['gallery_dirs']): + sphinx_gallery_conf = app.config["sphinx_gallery_conf"] + for src_dir, out_dir in zip( + sphinx_gallery_conf["examples_dirs"], sphinx_gallery_conf["gallery_dirs"] + ): root = os.path.abspath(os.path.join(app.srcdir, src_dir)) - fnames = [os.path.join(os.path.relpath(dirpath, root), fname) - for dirpath, _, fnames in os.walk(root) - for fname in fnames - if fname in needed_plot_redirects] + fnames = [ + os.path.join(os.path.relpath(dirpath, root), fname) + for dirpath, _, fnames in os.walk(root) + for fname in fnames + if fname in needed_plot_redirects + ] # plot_ redirects for fname in fnames: dirname = os.path.join(app.outdir, out_dir, os.path.dirname(fname)) - to_fname = os.path.splitext(os.path.basename(fname))[0] + '.html' - fr_fname = f'plot_{to_fname}' + to_fname = os.path.splitext(os.path.basename(fname))[0] + ".html" + fr_fname = f"plot_{to_fname}" to_path = os.path.join(dirname, to_fname) fr_path = os.path.join(dirname, fr_fname) assert os.path.isfile(to_path), (fname, to_path) - with open(fr_path, 'w') as fid: + with open(fr_path, "w") as fid: fid.write(TEMPLATE.format(to=to_fname)) sphinx_logger.info( - f'Added {len(fnames):3d} HTML plot_* redirects for {out_dir}') + f"Added {len(fnames):3d} HTML plot_* redirects for {out_dir}" + ) # custom redirects for fr, to in custom_redirects.items(): - if not to.startswith('http'): + if not to.startswith("http"): assert os.path.isfile(os.path.join(app.outdir, to)), to # handle links to sibling folders - path_parts = to.split('/') + path_parts = to.split("/") assert tu in path_parts, path_parts # need to refactor otherwise - path_parts = ['..'] + path_parts[(path_parts.index(tu) + 1):] + path_parts = [".."] + 
path_parts[(path_parts.index(tu) + 1) :] to = os.path.join(*path_parts) - assert to.endswith('html'), to + assert to.endswith("html"), to fr_path = os.path.join(app.outdir, fr) - assert fr_path.endswith('html'), fr_path + assert fr_path.endswith("html"), fr_path # allow overwrite if existing file is just a redirect if os.path.isfile(fr_path): - with open(fr_path, 'r') as fid: + with open(fr_path, "r") as fid: for _ in range(8): next(fid) line = fid.readline() - assert 'Page Redirection' in line, line + assert "Page Redirection" in line, line # handle folders that no longer exist - if fr_path.split('/')[-2] in ( - 'misc', 'discussions', 'source-modeling', 'sample-datasets', - 'connectivity'): + if fr_path.split("/")[-2] in ( + "misc", + "discussions", + "source-modeling", + "sample-datasets", + "connectivity", + ): os.makedirs(os.path.dirname(fr_path), exist_ok=True) - with open(fr_path, 'w') as fid: + with open(fr_path, "w") as fid: fid.write(TEMPLATE.format(to=to)) - sphinx_logger.info( - f'Added {len(custom_redirects):3d} HTML custom redirects') + sphinx_logger.info(f"Added {len(custom_redirects):3d} HTML custom redirects") def make_version(app, exception): """Make a text file with the git version.""" - if not (isinstance(app.builder, - sphinx.builders.html.StandaloneHTMLBuilder) and - exception is None): + if not ( + isinstance(app.builder, sphinx.builders.html.StandaloneHTMLBuilder) + and exception is None + ): return - logger = sphinx.util.logging.getLogger('mne') + logger = sphinx.util.logging.getLogger("mne") try: - stdout, _ = run_subprocess(['git', 'rev-parse', 'HEAD'], verbose=False) + stdout, _ = run_subprocess(["git", "rev-parse", "HEAD"], verbose=False) except Exception as exc: - sphinx_logger.warning(f'Failed to write _version.txt: {exc}') + sphinx_logger.warning(f"Failed to write _version.txt: {exc}") return - with open(os.path.join(app.outdir, '_version.txt'), 'w') as fid: + with open(os.path.join(app.outdir, "_version.txt"), "w") as fid: fid.write(stdout) sphinx_logger.info(f'Added "{stdout.rstrip()}" > _version.txt') # -- Connect our handlers to the main Sphinx app --------------------------- + def setup(app): """Set up the Sphinx app.""" - app.connect('autodoc-process-docstring', append_attr_meth_examples) + app.connect("autodoc-process-docstring", append_attr_meth_examples) report_scraper.app = app - app.connect('builder-inited', report_scraper.copyfiles) - app.connect('build-finished', make_redirects) - app.connect('build-finished', make_version) + app.connect("builder-inited", report_scraper.copyfiles) + app.connect("build-finished", make_redirects) + app.connect("build-finished", make_version) diff --git a/doc/sphinxext/flow_diagram.py b/doc/sphinxext/flow_diagram.py index 9adb8636e2f..d6a941d7869 100644 --- a/doc/sphinxext/flow_diagram.py +++ b/doc/sphinxext/flow_diagram.py @@ -1,14 +1,14 @@ import os from os import path as op -title = 'mne-python flow diagram' +title = "mne-python flow diagram" -font_face = 'Arial' +font_face = "Arial" node_size = 12 node_small_size = 9 edge_size = 9 -sensor_color = '#7bbeca' -source_color = '#ff6347' +sensor_color = "#7bbeca" +source_color = "#ff6347" legend = """ < @@ -17,62 +17,74 @@ Sensor (M/EEG) space Source (brain) space ->""" % (edge_size, sensor_color, source_color) -legend = ''.join(legend.split('\n')) +>""" % ( + edge_size, + sensor_color, + source_color, +) +legend = "".join(legend.split("\n")) nodes = dict( - T1='T1', - flashes='Flash5/30', - trans='Head-MRI trans', - recon='Freesurfer surfaces', - bem='BEM', - 
src='Source space\nmne.SourceSpaces',
-    cov='Noise covariance\nmne.Covariance',
-    fwd='Forward solution\nmne.forward.Forward',
-    inv='Inverse operator\nmne.minimum_norm.InverseOperator',
-    stc='Source estimate\nmne.SourceEstimate',
-    raw='Raw data\nmne.io.Raw',
-    epo='Epoched data\nmne.Epochs',
-    evo='Averaged data\nmne.Evoked',
-    pre='Preprocessed data\nmne.io.Raw',
+    T1="T1",
+    flashes="Flash5/30",
+    trans="Head-MRI trans",
+    recon="Freesurfer surfaces",
+    bem="BEM",
+    src="Source space\nmne.SourceSpaces",
+    cov="Noise covariance\nmne.Covariance",
+    fwd="Forward solution\nmne.forward.Forward",
+    inv="Inverse operator\nmne.minimum_norm.InverseOperator",
+    stc="Source estimate\nmne.SourceEstimate",
+    raw="Raw data\nmne.io.Raw",
+    epo="Epoched data\nmne.Epochs",
+    evo="Averaged data\nmne.Evoked",
+    pre="Preprocessed data\nmne.io.Raw",
     legend=legend,
 )
 
-sensor_space = ('raw', 'pre', 'epo', 'evo', 'cov')
-source_space = ('src', 'stc', 'bem', 'flashes', 'recon', 'T1')
+sensor_space = ("raw", "pre", "epo", "evo", "cov")
+source_space = ("src", "stc", "bem", "flashes", "recon", "T1")
 
 edges = (
-    ('T1', 'recon'),
-    ('flashes', 'bem'),
-    ('recon', 'bem'),
-    ('recon', 'src', 'mne.setup_source_space'),
-    ('src', 'fwd'),
-    ('bem', 'fwd'),
-    ('trans', 'fwd', 'mne.make_forward_solution'),
-    ('fwd', 'inv'),
-    ('cov', 'inv', 'mne.make_inverse_operator'),
-    ('inv', 'stc'),
-    ('evo', 'stc', 'mne.minimum_norm.apply_inverse'),
-    ('raw', 'pre', 'raw.filter\n'
-     'mne.preprocessing.ICA\n'
-     'mne.preprocessing.compute_proj_eog\n'
-     'mne.preprocessing.compute_proj_ecg\n'
-     '...'),
-    ('pre', 'epo', 'mne.Epochs'),
-    ('epo', 'evo', 'epochs.average'),
-    ('epo', 'cov', 'mne.compute_covariance'),
+    ("T1", "recon"),
+    ("flashes", "bem"),
+    ("recon", "bem"),
+    ("recon", "src", "mne.setup_source_space"),
+    ("src", "fwd"),
+    ("bem", "fwd"),
+    ("trans", "fwd", "mne.make_forward_solution"),
+    ("fwd", "inv"),
+    ("cov", "inv", "mne.make_inverse_operator"),
+    ("inv", "stc"),
+    ("evo", "stc", "mne.minimum_norm.apply_inverse"),
+    (
+        "raw",
+        "pre",
+        "raw.filter\n"
+        "mne.preprocessing.ICA\n"
+        "mne.preprocessing.compute_proj_eog\n"
+        "mne.preprocessing.compute_proj_ecg\n"
+        "...",
+    ),
+    ("pre", "epo", "mne.Epochs"),
+    ("epo", "evo", "epochs.average"),
+    ("epo", "cov", "mne.compute_covariance"),
 )
 
 subgraphs = (
-    [('T1', 'flashes', 'recon', 'bem', 'src'),
-     ('<<FONT POINT-SIZE="%s">'
-      'Freesurfer / MNE-C</FONT>>' % node_small_size)],
+    [
+        ("T1", "flashes", "recon", "bem", "src"),
+        (
+            '<<FONT POINT-SIZE="%s">'
+            "Freesurfer / MNE-C</FONT>>" % node_small_size
+        ),
+    ],
 )
 
 
 def setup(app):
-    app.connect('builder-inited', generate_flow_diagram)
-    app.add_config_value('make_flow_diagram', True, 'html')
+    app.connect("builder-inited", generate_flow_diagram)
+    app.add_config_value("make_flow_diagram", True, "html")
 
 
 def setup_module():
@@ -81,84 +93,88 @@ def setup_module():
 
 
 def generate_flow_diagram(app):
-    out_dir = op.join(app.builder.outdir, '_static')
+    out_dir = op.join(app.builder.outdir, "_static")
     if not op.isdir(out_dir):
         os.makedirs(out_dir)
-    out_fname = op.join(out_dir, 'mne-python_flow.svg')
-    make_flow_diagram = app is None or \
-        bool(app.builder.config.make_flow_diagram)
+    out_fname = op.join(out_dir, "mne-python_flow.svg")
+    make_flow_diagram = app is None or bool(app.builder.config.make_flow_diagram)
     if not make_flow_diagram:
-        print('Skipping flow diagram, webpage will have a missing image')
+        print("Skipping flow diagram, webpage will have a missing image")
         return
 
     import pygraphviz as pgv
+
     g = pgv.AGraph(name=title, directed=True)
 
     for key, label in nodes.items():
-        label = label.split('\n')
+        label = label.split("\n")
         if len(label) > 1:
-            label[0] = ('<<FONT POINT-SIZE="%s">' % node_size
-                        + label[0] + '</FONT>')
+            label[0] = '<<FONT POINT-SIZE="%s">' % node_size + label[0] + "</FONT>"
             for li in range(1, len(label)):
-                label[li] = ('<FONT POINT-SIZE="%s"><I>' % node_small_size
-                             + label[li] + '</I></FONT>')
-            label[-1] = label[-1] + '>'
-            label = '<br/>'.join(label)
+                label[li] = (
+                    '<FONT POINT-SIZE="%s"><I>' % node_small_size
+                    + label[li]
+                    + "</I></FONT>"
+                )
+            label[-1] = label[-1] + ">"
+            label = "<br/>".join(label)
         else:
             label = label[0]
-        g.add_node(key, shape='plaintext', label=label)
+        g.add_node(key, shape="plaintext", label=label)
 
     # Create and customize nodes and edges
     for edge in edges:
         g.add_edge(*edge[:2])
         e = g.get_edge(*edge[:2])
         if len(edge) > 2:
-            e.attr['label'] = ('<' +
-                               '<BR ALIGN="LEFT"/>'.join(edge[2].split('\n')) +
-                               '<BR ALIGN="LEFT"/>>')
-            e.attr['fontsize'] = edge_size
+            e.attr["label"] = (
+                "<"
+                + '<BR ALIGN="LEFT"/>'.join(edge[2].split("\n"))
+                + '<BR ALIGN="LEFT"/>>'
+            )
+            e.attr["fontsize"] = edge_size
 
     # Change colors
-    for these_nodes, color in zip((sensor_space, source_space),
-                                  (sensor_color, source_color)):
+    for these_nodes, color in zip(
+        (sensor_space, source_space), (sensor_color, source_color)
+    ):
         for node in these_nodes:
-            g.get_node(node).attr['fillcolor'] = color
-            g.get_node(node).attr['style'] = 'filled'
+            g.get_node(node).attr["fillcolor"] = color
+            g.get_node(node).attr["style"] = "filled"
 
     # Create subgraphs
     for si, subgraph in enumerate(subgraphs):
-        g.add_subgraph(subgraph[0], 'cluster%s' % si,
-                       label=subgraph[1], color='black')
+        g.add_subgraph(subgraph[0], "cluster%s" % si, label=subgraph[1], color="black")
 
     # Format (sub)graphs
     for gr in g.subgraphs() + [g]:
        for x in [gr.node_attr, gr.edge_attr]:
-            x['fontname'] = font_face
-    g.node_attr['shape'] = 'box'
+            x["fontname"] = font_face
+    g.node_attr["shape"] = "box"
 
     # A couple of special ones
-    for ni, node in enumerate(('fwd', 'inv', 'trans')):
+    for ni, node in enumerate(("fwd", "inv", "trans")):
         node = g.get_node(node)
-        node.attr['gradientangle'] = 270
+        node.attr["gradientangle"] = 270
         colors = (source_color, sensor_color)
         colors = colors if ni == 0 else colors[::-1]
-        node.attr['fillcolor'] = ':'.join(colors)
-        node.attr['style'] = 'filled'
+        node.attr["fillcolor"] = ":".join(colors)
+        node.attr["style"] = "filled"
         del node
-    g.get_node('legend').attr.update(shape='plaintext', margin=0, rank='sink')
+    g.get_node("legend").attr.update(shape="plaintext", margin=0, rank="sink")
     # put legend in same rank/level as inverse
-    leg = g.add_subgraph(['legend', 'inv'], name='legendy')
-    leg.graph_attr['rank'] = 'same'
+    leg = g.add_subgraph(["legend", "inv"], name="legendy")
+    leg.graph_attr["rank"] = "same"
 
-    g.layout('dot')
-    g.draw(out_fname, format='svg')
+    g.layout("dot")
+    g.draw(out_fname, format="svg")
     return g
 
 
 # This is useful for testing/iterating to see what the result looks like
-if __name__ == '__main__':
+if __name__ == "__main__":
     from mne.io.constants import Bunch
-    out_dir = op.abspath(op.join(op.dirname(__file__), '..', '_build', 'html'))
-    app = Bunch(builder=Bunch(outdir=out_dir,
-                              config=Bunch(make_flow_diagram=True)))
+
+    out_dir = op.abspath(op.join(op.dirname(__file__), "..", "_build", "html"))
+    app = Bunch(builder=Bunch(outdir=out_dir, config=Bunch(make_flow_diagram=True)))
     g = generate_flow_diagram(app)
diff --git a/doc/sphinxext/gen_commands.py b/doc/sphinxext/gen_commands.py
index 0339160b2bb..0ca15319d36 100644
--- a/doc/sphinxext/gen_commands.py
+++ b/doc/sphinxext/gen_commands.py
@@ -7,7 +7,7 @@
 
 
 def setup(app):
-    app.connect('builder-inited', generate_commands_rst)
+    app.connect("builder-inited", generate_commands_rst)
 
 
 def setup_module():
@@ -52,23 +52,24 @@ def generate_commands_rst(app=None):
     except Exception:
         from sphinx.util import status_iterator
     root = Path(__file__).parent.parent.parent.absolute()
-    out_dir = (root / 'doc' / 'generated').absolute()
+    out_dir = (root / "doc" / "generated").absolute()
     out_dir.mkdir(exist_ok=True)
-    out_fname = out_dir / 'commands.rst.new'
+    out_fname = out_dir / "commands.rst.new"
 
-    command_path = root / 'mne' / 'commands'
+    command_path = root / "mne" / "commands"
     fnames = sorted(
-        Path(fname).name
-        for fname in glob.glob(str(command_path / 'mne_*.py')))
+        Path(fname).name for fname in glob.glob(str(command_path / "mne_*.py"))
+    )
     assert len(fnames)
     iterator = status_iterator(
-        fnames, 'generating MNE command help ... 
', length=len(fnames)) - with open(out_fname, 'w', encoding='utf8') as f: + fnames, "generating MNE command help ... ", length=len(fnames) + ) + with open(out_fname, "w", encoding="utf8") as f: f.write(header) for fname in iterator: cmd_name = fname[:-3] - module = import_module('.' + cmd_name, 'mne.commands') - with ArgvSetter(('mne', cmd_name, '--help')) as out: + module = import_module("." + cmd_name, "mne.commands") + with ArgvSetter(("mne", cmd_name, "--help")) as out: try: module.run() except SystemExit: # this is how these terminate @@ -80,29 +81,30 @@ def generate_commands_rst(app=None): # Add header marking for idx in (1, 0): - output.insert(idx, '-' * len(output[0])) + output.insert(idx, "-" * len(output[0])) # Add code styling for the "Usage: " line for li, line in enumerate(output): - if line.startswith('Usage: mne '): - output[li] = 'Usage: ``%s``' % line[7:] + if line.startswith("Usage: mne "): + output[li] = "Usage: ``%s``" % line[7:] break # Turn "Options:" into field list - if 'Options:' in output: - ii = output.index('Options:') - output[ii] = 'Options' - output.insert(ii + 1, '-------') - output.insert(ii + 2, '') - output.insert(ii + 3, '.. rst-class:: field-list cmd-list') - output.insert(ii + 4, '') - output = '\n'.join(output) - cmd_name_space = cmd_name.replace('mne_', 'mne ') - f.write(command_rst.format( - cmd_name_space, '=' * len(cmd_name_space), output)) + if "Options:" in output: + ii = output.index("Options:") + output[ii] = "Options" + output.insert(ii + 1, "-------") + output.insert(ii + 2, "") + output.insert(ii + 3, ".. rst-class:: field-list cmd-list") + output.insert(ii + 4, "") + output = "\n".join(output) + cmd_name_space = cmd_name.replace("mne_", "mne ") + f.write( + command_rst.format(cmd_name_space, "=" * len(cmd_name_space), output) + ) _replace_md5(str(out_fname)) # This is useful for testing/iterating to see what the result looks like -if __name__ == '__main__': +if __name__ == "__main__": generate_commands_rst() diff --git a/doc/sphinxext/gen_names.py b/doc/sphinxext/gen_names.py index 92c155b8f52..c5cc7f9f9ea 100644 --- a/doc/sphinxext/gen_names.py +++ b/doc/sphinxext/gen_names.py @@ -3,7 +3,7 @@ def setup(app): - app.connect('builder-inited', generate_name_links_rst) + app.connect("builder-inited", generate_name_links_rst) def setup_module(): @@ -12,17 +12,18 @@ def setup_module(): def generate_name_links_rst(app=None): - if 'linkcheck' not in str(app.builder).lower(): + if "linkcheck" not in str(app.builder).lower(): return - out_dir = op.abspath(op.join(op.dirname(__file__), '..', 'generated')) + out_dir = op.abspath(op.join(op.dirname(__file__), "..", "generated")) if not op.isdir(out_dir): os.mkdir(out_dir) - out_fname = op.join(out_dir, '_names.rst') + out_fname = op.join(out_dir, "_names.rst") names_path = op.abspath( - op.join(os.path.dirname(__file__), '..', 'changes', 'names.inc')) - with open(out_fname, 'w', encoding='utf8') as fout: - fout.write(':orphan:\n\n') - with open(names_path, 'r') as fin: + op.join(os.path.dirname(__file__), "..", "changes", "names.inc") + ) + with open(out_fname, "w", encoding="utf8") as fout: + fout.write(":orphan:\n\n") + with open(names_path, "r") as fin: for line in fin: - if line.startswith('.. _'): - fout.write(f'- {line[4:]}') + if line.startswith(".. 
_"): + fout.write(f"- {line[4:]}") diff --git a/doc/sphinxext/gh_substitutions.py b/doc/sphinxext/gh_substitutions.py index f0c6a05c5ba..4463425867d 100644 --- a/doc/sphinxext/gh_substitutions.py +++ b/doc/sphinxext/gh_substitutions.py @@ -15,14 +15,14 @@ def gh_role(name, rawtext, text, lineno, inliner, options={}, content=[]): # direct link mode slug = text else: - slug = 'issues/' + text - text = '#' + text - ref = 'https://github.com/mne-tools/mne-python/' + slug + slug = "issues/" + text + text = "#" + text + ref = "https://github.com/mne-tools/mne-python/" + slug set_classes(options) node = reference(rawtext, text, refuri=ref, **options) return [node], [] def setup(app): - app.add_role('gh', gh_role) + app.add_role("gh", gh_role) return diff --git a/doc/sphinxext/mne_substitutions.py b/doc/sphinxext/mne_substitutions.py index a9309baaf42..a1b8627edf9 100644 --- a/doc/sphinxext/mne_substitutions.py +++ b/doc/sphinxext/mne_substitutions.py @@ -3,46 +3,57 @@ from docutils.statemachine import StringList from mne.defaults import DEFAULTS -from mne.io.pick import (_PICK_TYPES_DATA_DICT, _DATA_CH_TYPES_SPLIT, - _DATA_CH_TYPES_ORDER_DEFAULT) +from mne.io.pick import ( + _PICK_TYPES_DATA_DICT, + _DATA_CH_TYPES_SPLIT, + _DATA_CH_TYPES_ORDER_DEFAULT, +) class MNESubstitution(Directive): # noqa: D101 - has_content = False required_arguments = 1 final_argument_whitespace = True def run(self, **kwargs): # noqa: D102 env = self.state.document.settings.env - if self.arguments[0] == 'data channels list': + if self.arguments[0] == "data channels list": keys = list() for key in _DATA_CH_TYPES_ORDER_DEFAULT: if key in _DATA_CH_TYPES_SPLIT: keys.append(key) - elif key not in ('meg', 'fnirs') and \ - _PICK_TYPES_DATA_DICT.get(key, False): + elif key not in ("meg", "fnirs") and _PICK_TYPES_DATA_DICT.get( + key, False + ): keys.append(key) - rst = '- ' + '\n- '.join( - '``%r``: **%s** (scaled by %g to plot in *%s*)' - % (key, DEFAULTS['titles'][key], DEFAULTS['scalings'][key], - DEFAULTS['units'][key]) - for key in keys) + rst = "- " + "\n- ".join( + "``%r``: **%s** (scaled by %g to plot in *%s*)" + % ( + key, + DEFAULTS["titles"][key], + DEFAULTS["scalings"][key], + DEFAULTS["units"][key], + ) + for key in keys + ) else: raise self.error( - 'MNE directive unknown in %s: %r' - % (env.doc2path(env.docname, base=None), - self.arguments[0],)) + "MNE directive unknown in %s: %r" + % ( + env.doc2path(env.docname, base=None), + self.arguments[0], + ) + ) node = nodes.compound(rst) # General(Body), Element content = StringList( - rst.split('\n'), parent=self.content.parent, - parent_offset=self.content.parent_offset) + rst.split("\n"), + parent=self.content.parent, + parent_offset=self.content.parent_offset, + ) self.state.nested_parse(content, self.content_offset, node) return [node] def setup(app): # noqa: D103 - app.add_directive('mne', MNESubstitution) - return {'version': '0.1', - 'parallel_read_safe': True, - 'parallel_write_safe': True} + app.add_directive("mne", MNESubstitution) + return {"version": "0.1", "parallel_read_safe": True, "parallel_write_safe": True} diff --git a/doc/sphinxext/newcontrib_substitutions.py b/doc/sphinxext/newcontrib_substitutions.py index 68595e74bdb..8c31e8ca0e2 100644 --- a/doc/sphinxext/newcontrib_substitutions.py +++ b/doc/sphinxext/newcontrib_substitutions.py @@ -1,18 +1,17 @@ from docutils.nodes import reference, strong, target -def newcontrib_role(name, rawtext, text, lineno, inliner, options={}, - content=[]): +def newcontrib_role(name, rawtext, text, lineno, inliner, 
options={}, content=[]): """Create a role to highlight new contributors in changelog entries.""" - newcontrib = f'new contributor {text}' - alias_text = f' <{text}_>' - rawtext = f'`{newcontrib}{alias_text}`_' + newcontrib = f"new contributor {text}" + alias_text = f" <{text}_>" + rawtext = f"`{newcontrib}{alias_text}`_" refname = text.lower() strong_node = strong(rawtext, newcontrib) target_node = target(alias_text, refname=refname, names=[newcontrib]) target_node.indirect_reference_name = text options.update(refname=refname, name=newcontrib) - ref_node = reference('', '', strong_node, **options) + ref_node = reference("", "", strong_node, **options) ref_node[0].rawsource = rawtext inliner.document.note_indirect_target(target_node) inliner.document.note_refname(ref_node) @@ -20,5 +19,5 @@ def newcontrib_role(name, rawtext, text, lineno, inliner, options={}, def setup(app): - app.add_role('newcontrib', newcontrib_role) + app.add_role("newcontrib", newcontrib_role) return diff --git a/doc/sphinxext/unit_role.py b/doc/sphinxext/unit_role.py index d912786b474..83b82c223e4 100644 --- a/doc/sphinxext/unit_role.py +++ b/doc/sphinxext/unit_role.py @@ -6,8 +6,10 @@ def unit_role(name, rawtext, text, lineno, inliner, options={}, content=[]): def pass_error_to_sphinx(rawtext, text, lineno, inliner): msg = inliner.reporter.error( - 'The :unit: role requires a space-separated number and unit; ' - f'got {text}', line=lineno) + "The :unit: role requires a space-separated number and unit; " + f"got {text}", + line=lineno, + ) prb = inliner.problematic(rawtext, rawtext, msg) return [prb], [msg] @@ -20,10 +22,10 @@ def pass_error_to_sphinx(rawtext, text, lineno, inliner): except ValueError: return pass_error_to_sphinx(rawtext, text, lineno, inliner) # input is well-formatted: proceed - node = nodes.Text('\u202F'.join(parts)) + node = nodes.Text("\u202F".join(parts)) return [node], [] def setup(app): - app.add_role('unit', unit_role) + app.add_role("unit", unit_role) return diff --git a/examples/datasets/brainstorm_data.py b/examples/datasets/brainstorm_data.py index 949f2511a88..df08d0d383a 100644 --- a/examples/datasets/brainstorm_data.py +++ b/examples/datasets/brainstorm_data.py @@ -30,29 +30,38 @@ data_path = bst_raw.data_path() -raw_path = (data_path / 'MEG' / 'bst_raw' / - 'subj001_somatosensory_20111109_01_AUX-f.ds') +raw_path = data_path / "MEG" / "bst_raw" / "subj001_somatosensory_20111109_01_AUX-f.ds" # Here we crop to half the length to save memory raw = read_raw_ctf(raw_path).crop(0, 120).load_data() raw.plot() # set EOG channel -raw.set_channel_types({'EEG058': 'eog'}) -raw.set_eeg_reference('average', projection=True) +raw.set_channel_types({"EEG058": "eog"}) +raw.set_eeg_reference("average", projection=True) # show power line interference and remove it raw.compute_psd(tmax=60).plot(average=False) -raw.notch_filter(np.arange(60, 181, 60), fir_design='firwin') +raw.notch_filter(np.arange(60, 181, 60), fir_design="firwin") -events = mne.find_events(raw, stim_channel='UPPT001') +events = mne.find_events(raw, stim_channel="UPPT001") # pick MEG channels -picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True, - exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, stim=False, eog=True, exclude="bads" +) # Compute epochs -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=reject, preload=False) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=reject, 
+ preload=False, +) # compute evoked evoked = epochs.average() @@ -68,11 +77,10 @@ evoked.shift_time(-0.004) # plot the result -evoked.plot(time_unit='s') +evoked.plot(time_unit="s") # show topomaps -evoked.plot_topomap(times=np.array([0.016, 0.030, 0.060, 0.070]), - time_unit='s') +evoked.plot_topomap(times=np.array([0.016, 0.030, 0.060, 0.070]), time_unit="s") # %% # References diff --git a/examples/datasets/hf_sef_data.py b/examples/datasets/hf_sef_data.py index 9857d22d09d..36ea2cbc2bb 100644 --- a/examples/datasets/hf_sef_data.py +++ b/examples/datasets/hf_sef_data.py @@ -18,8 +18,7 @@ import os from mne.datasets import hf_sef -fname_evoked = os.path.join(hf_sef.data_path(), - 'MEG/subject_b/hf_sef_15min-ave.fif') +fname_evoked = os.path.join(hf_sef.data_path(), "MEG/subject_b/hf_sef_15min-ave.fif") print(__doc__) @@ -34,7 +33,7 @@ # %% # Compare high-pass filtered and unfiltered data on a single channel -ch = 'MEG0443' +ch = "MEG0443" pick = evoked.ch_names.index(ch) -edi = {'HF': evoked_hp, 'Regular': evoked} +edi = {"HF": evoked_hp, "Regular": evoked} mne.viz.plot_compare_evokeds(edi, picks=pick) diff --git a/examples/datasets/limo_data.py b/examples/datasets/limo_data.py index d5670f62ffe..4285411dd6c 100644 --- a/examples/datasets/limo_data.py +++ b/examples/datasets/limo_data.py @@ -112,7 +112,7 @@ # metadata. # We want include all columns in the summary table -epochs_summary = limo_epochs.metadata.describe(include='all').round(3) +epochs_summary = limo_epochs.metadata.describe(include="all").round(3) print(epochs_summary) # %% @@ -137,13 +137,13 @@ ts_args = dict(xlim=(-0.25, 0.5)) # plot evoked response for face A -limo_epochs['Face/A'].average().plot_joint(times=[0.15], - title='Evoked response: Face A', - ts_args=ts_args) +limo_epochs["Face/A"].average().plot_joint( + times=[0.15], title="Evoked response: Face A", ts_args=ts_args +) # and face B -limo_epochs['Face/B'].average().plot_joint(times=[0.15], - title='Evoked response: Face B', - ts_args=ts_args) +limo_epochs["Face/B"].average().plot_joint( + times=[0.15], title="Evoked response: Face B", ts_args=ts_args +) # %% # We can also compute the difference wave contrasting Face A and Face B. @@ -151,12 +151,12 @@ # differences among these face-stimuli. # Face A minus Face B -difference_wave = combine_evoked([limo_epochs['Face/A'].average(), - limo_epochs['Face/B'].average()], - weights=[1, -1]) +difference_wave = combine_evoked( + [limo_epochs["Face/A"].average(), limo_epochs["Face/B"].average()], weights=[1, -1] +) # plot difference wave -difference_wave.plot_joint(times=[0.15], title='Difference Face A - Face B') +difference_wave.plot_joint(times=[0.15], title="Difference Face A - Face B") # %% # As expected, no clear pattern appears when contrasting @@ -167,11 +167,10 @@ # Create a dictionary containing the evoked responses conditions = ["Face/A", "Face/B"] -evokeds = {condition: limo_epochs[condition].average() - for condition in conditions} +evokeds = {condition: limo_epochs[condition].average() for condition in conditions} # concentrate analysis an occipital electrodes (e.g. B11) -pick = evokeds["Face/A"].ch_names.index('B11') +pick = evokeds["Face/A"].ch_names.index("B11") # compare evoked responses plot_compare_evokeds(evokeds, picks=pick, ylim=dict(eeg=(-15, 7.5))) @@ -188,26 +187,30 @@ # one could expect that faces with high phase-coherence should evoke stronger # activation patterns along occipital electrodes. 
-phase_coh = limo_epochs.metadata['phase-coherence'] +phase_coh = limo_epochs.metadata["phase-coherence"] # get levels of phase coherence levels = sorted(phase_coh.unique()) # create labels for levels of phase coherence (i.e., 0 - 85%) -labels = ["{0:.2f}".format(i) for i in np.arange(0., 0.90, 0.05)] +labels = ["{0:.2f}".format(i) for i in np.arange(0.0, 0.90, 0.05)] # create dict of evokeds for each level of phase-coherence -evokeds = {label: limo_epochs[phase_coh == level].average() - for level, label in zip(levels, labels)} +evokeds = { + label: limo_epochs[phase_coh == level].average() + for level, label in zip(levels, labels) +} # pick channel to plot -electrodes = ['C22', 'B11'] +electrodes = ["C22", "B11"] # create figures for electrode in electrodes: fig, ax = plt.subplots(figsize=(8, 4)) - plot_compare_evokeds(evokeds, - axes=ax, - ylim=dict(eeg=(-20, 15)), - picks=electrode, - cmap=("Phase coherence", "magma")) + plot_compare_evokeds( + evokeds, + axes=ax, + ylim=dict(eeg=(-20, 15)), + picks=electrode, + cmap=("Phase coherence", "magma"), + ) # %% # As shown above, there are some considerable differences between the @@ -225,7 +228,7 @@ # present in the data: limo_epochs.interpolate_bads(reset_bads=True) -limo_epochs.drop_channels(['EXG1', 'EXG2', 'EXG3', 'EXG4']) +limo_epochs.drop_channels(["EXG1", "EXG2", "EXG3", "EXG4"]) # %% # Define predictor variables and design matrix @@ -238,21 +241,19 @@ # ``limo_epochs.metadata``: phase-coherence and Face A vs. Face B. # name of predictors + intercept -predictor_vars = ['face a - face b', 'phase-coherence', 'intercept'] +predictor_vars = ["face a - face b", "phase-coherence", "intercept"] # create design matrix -design = limo_epochs.metadata[['phase-coherence', 'face']].copy() -design['face a - face b'] = np.where(design['face'] == 'A', 1, -1) -design['intercept'] = 1 +design = limo_epochs.metadata[["phase-coherence", "face"]].copy() +design["face a - face b"] = np.where(design["face"] == "A", 1, -1) +design["intercept"] = 1 design = design[predictor_vars] # %% # Now we can set up the linear model to be used in the analysis using # MNE-Python's func:`~mne.stats.linear_regression` function. -reg = linear_regression(limo_epochs, - design_matrix=design, - names=predictor_vars) +reg = linear_regression(limo_epochs, design_matrix=design, names=predictor_vars) # %% # Extract regression coefficients @@ -262,8 +263,8 @@ # which is a dictionary of evoked objects containing # multiple inferential measures for each predictor in the design matrix. -print('predictors are:', list(reg)) -print('fields are:', [field for field in getattr(reg['intercept'], '_fields')]) +print("predictors are:", list(reg)) +print("fields are:", [field for field in getattr(reg["intercept"], "_fields")]) # %% # Plot model results @@ -279,25 +280,23 @@ # the activity measured at occipital electrodes around 200 to 250 ms following # stimulus onset. -reg['phase-coherence'].beta.plot_joint(ts_args=ts_args, - title='Effect of Phase-coherence', - times=[0.23]) +reg["phase-coherence"].beta.plot_joint( + ts_args=ts_args, title="Effect of Phase-coherence", times=[0.23] +) # %% # We can also plot the corresponding T values. # use unit=False and scale=1 to keep values at their original # scale (i.e., avoid conversion to micro-volt). 
-ts_args = dict(xlim=(-0.25, 0.5), - unit=False) -topomap_args = dict(scalings=dict(eeg=1), - average=0.05) +ts_args = dict(xlim=(-0.25, 0.5), unit=False) +topomap_args = dict(scalings=dict(eeg=1), average=0.05) # sphinx_gallery_thumbnail_number = 9 -fig = reg['phase-coherence'].t_val.plot_joint(ts_args=ts_args, - topomap_args=topomap_args, - times=[0.23]) -fig.axes[0].set_ylabel('T-value') +fig = reg["phase-coherence"].t_val.plot_joint( + ts_args=ts_args, topomap_args=topomap_args, times=[0.23] +) +fig.axes[0].set_ylabel("T-value") # %% # Conversely, there appears to be no (or very small) systematic effects when @@ -305,9 +304,9 @@ # difference wave approach presented above. ts_args = dict(xlim=(-0.25, 0.5)) -reg['face a - face b'].beta.plot_joint(ts_args=ts_args, - title='Effect of Face A vs. Face B', - times=[0.23]) +reg["face a - face b"].beta.plot_joint( + ts_args=ts_args, title="Effect of Face A vs. Face B", times=[0.23] +) # %% # References diff --git a/examples/datasets/opm_data.py b/examples/datasets/opm_data.py index ec6daab1037..184ea216866 100644 --- a/examples/datasets/opm_data.py +++ b/examples/datasets/opm_data.py @@ -16,13 +16,12 @@ import mne data_path = mne.datasets.opm.data_path() -subject = 'OPM_sample' -subjects_dir = data_path / 'subjects' -raw_fname = data_path / 'MEG' / 'OPM' / 'OPM_SEF_raw.fif' -bem_fname = (subjects_dir / subject / 'bem' / - f'{subject}-5120-5120-5120-bem-sol.fif') -fwd_fname = data_path / 'MEG' / 'OPM' / 'OPM_sample-fwd.fif' -coil_def_fname = data_path / 'MEG' / 'OPM' / 'coil_def.dat' +subject = "OPM_sample" +subjects_dir = data_path / "subjects" +raw_fname = data_path / "MEG" / "OPM" / "OPM_SEF_raw.fif" +bem_fname = subjects_dir / subject / "bem" / f"{subject}-5120-5120-5120-bem-sol.fif" +fwd_fname = data_path / "MEG" / "OPM" / "OPM_sample-fwd.fif" +coil_def_fname = data_path / "MEG" / "OPM" / "coil_def.dat" # %% # Prepare data for localization @@ -30,8 +29,8 @@ # First we filter and epoch the data: raw = mne.io.read_raw_fif(raw_fname, preload=True) -raw.filter(None, 90, h_trans_bandwidth=10.) -raw.notch_filter(50., notch_widths=1) +raw.filter(None, 90, h_trans_bandwidth=10.0) +raw.notch_filter(50.0, notch_widths=1) # Set epoch rejection threshold a bit larger than for SQUIDs @@ -40,16 +39,26 @@ # Find median nerve stimulator trigger event_id = dict(Median=257) -events = mne.find_events(raw, stim_channel='STI101', mask=257, mask_type='and') +events = mne.find_events(raw, stim_channel="STI101", mask=257, mask_type="and") picks = mne.pick_types(raw.info, meg=True, eeg=False) # We use verbose='error' to suppress warning about decimation causing aliasing, # ideally we would low-pass and then decimate instead -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, verbose='error', - reject=reject, picks=picks, proj=False, decim=10, - preload=True) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + verbose="error", + reject=reject, + picks=picks, + proj=False, + decim=10, + preload=True, +) evoked = epochs.average() evoked.plot() -cov = mne.compute_covariance(epochs, tmax=0.) +cov = mne.compute_covariance(epochs, tmax=0.0) del epochs, raw # %% @@ -63,7 +72,7 @@ # but should be fine for these analyses. 
bem = mne.read_bem_solution(bem_fname) -trans = mne.transforms.Transform('head', 'mri') # identity transformation +trans = mne.transforms.Transform("head", "mri") # identity transformation # To compute the forward solution, we must # provide our temporary/custom coil definitions, which can be done as:: @@ -78,12 +87,18 @@ mne.convert_forward_solution(fwd, force_fixed=True, copy=False) with mne.use_coil_def(coil_def_fname): - fig = mne.viz.plot_alignment(evoked.info, trans=trans, subject=subject, - subjects_dir=subjects_dir, - surfaces=('head', 'pial'), bem=bem) - -mne.viz.set_3d_view(figure=fig, azimuth=45, elevation=60, distance=0.4, - focalpoint=(0.02, 0, 0.04)) + fig = mne.viz.plot_alignment( + evoked.info, + trans=trans, + subject=subject, + subjects_dir=subjects_dir, + surfaces=("head", "pial"), + bem=bem, + ) + +mne.viz.set_3d_view( + figure=fig, azimuth=45, elevation=60, distance=0.4, focalpoint=(0.02, 0, 0.04) +) # %% # Perform dipole fitting @@ -91,15 +106,17 @@ # Fit dipoles on a subset of time points with mne.use_coil_def(coil_def_fname): - dip_opm, _ = mne.fit_dipole(evoked.copy().crop(0.040, 0.080), - cov, bem, trans, verbose=True) + dip_opm, _ = mne.fit_dipole( + evoked.copy().crop(0.040, 0.080), cov, bem, trans, verbose=True + ) idx = np.argmax(dip_opm.gof) -print('Best dipole at t=%0.1f ms with %0.1f%% GOF' - % (1000 * dip_opm.times[idx], dip_opm.gof[idx])) +print( + "Best dipole at t=%0.1f ms with %0.1f%% GOF" + % (1000 * dip_opm.times[idx], dip_opm.gof[idx]) +) # Plot N20m dipole as an example -dip_opm.plot_locations(trans, subject, subjects_dir, - mode='orthoview', idx=idx) +dip_opm.plot_locations(trans, subject, subjects_dir, mode="orthoview", idx=idx) # %% # Perform minimum-norm localization @@ -109,18 +126,24 @@ # areas we are sensitive to might be a good idea. inverse_operator = mne.minimum_norm.make_inverse_operator( - evoked.info, fwd, cov, loose=0., depth=None) + evoked.info, fwd, cov, loose=0.0, depth=None +) del fwd, cov method = "MNE" -snr = 3. -lambda2 = 1. / snr ** 2 +snr = 3.0 +lambda2 = 1.0 / snr**2 stc = mne.minimum_norm.apply_inverse( - evoked, inverse_operator, lambda2, method=method, - pick_ori=None, verbose=True) + evoked, inverse_operator, lambda2, method=method, pick_ori=None, verbose=True +) # Plot source estimate at time of best dipole fit -brain = stc.plot(hemi='rh', views='lat', subjects_dir=subjects_dir, - initial_time=dip_opm.times[idx], - clim=dict(kind='percent', lims=[99, 99.9, 99.99]), - size=(400, 300), background='w') +brain = stc.plot( + hemi="rh", + views="lat", + subjects_dir=subjects_dir, + initial_time=dip_opm.times[idx], + clim=dict(kind="percent", lims=[99, 99.9, 99.99]), + size=(400, 300), + background="w", +) diff --git a/examples/datasets/spm_faces_dataset_sgskip.py b/examples/datasets/spm_faces_dataset_sgskip.py index 875cc2eb5d5..8806059b395 100644 --- a/examples/datasets/spm_faces_dataset_sgskip.py +++ b/examples/datasets/spm_faces_dataset_sgskip.py @@ -35,26 +35,26 @@ print(__doc__) data_path = spm_face.data_path() -subjects_dir = data_path / 'subjects' -spm_path = data_path / 'MEG' / 'spm' +subjects_dir = data_path / "subjects" +spm_path = data_path / "MEG" / "spm" # %% # Load and filter data, set up epochs -raw_fname = spm_path / 'SPM_CTF_MEG_example_faces%d_3D.ds' +raw_fname = spm_path / "SPM_CTF_MEG_example_faces%d_3D.ds" raw = io.read_raw_ctf(raw_fname % 1, preload=True) # Take first run # Here to save memory and time we'll downsample heavily -- this is not # advised for real data as it can effectively jitter events! 
-raw.resample(120., npad='auto') +raw.resample(120.0, npad="auto") -picks = mne.pick_types(raw.info, meg=True, exclude='bads') -raw.filter(1, 30, method='fir', fir_design='firwin') +picks = mne.pick_types(raw.info, meg=True, exclude="bads") +raw.filter(1, 30, method="fir", fir_design="firwin") -events = mne.find_events(raw, stim_channel='UPPT001') +events = mne.find_events(raw, stim_channel="UPPT001") # plot the events to get an idea of the paradigm -mne.viz.plot_events(events, raw.info['sfreq']) +mne.viz.plot_events(events, raw.info["sfreq"]) event_ids = {"faces": 1, "scrambled": 2} @@ -62,16 +62,25 @@ baseline = None # no baseline as high-pass is applied reject = dict(mag=5e-12) -epochs = mne.Epochs(raw, events, event_ids, tmin, tmax, picks=picks, - baseline=baseline, preload=True, reject=reject) +epochs = mne.Epochs( + raw, + events, + event_ids, + tmin, + tmax, + picks=picks, + baseline=baseline, + preload=True, + reject=reject, +) # Fit ICA, find and remove major artifacts -ica = ICA(n_components=0.95, max_iter='auto', random_state=0) +ica = ICA(n_components=0.95, max_iter="auto", random_state=0) ica.fit(raw, decim=1, reject=reject) # compute correlation scores, get bad indices sorted by score -eog_epochs = create_eog_epochs(raw, ch_name='MRT31-2908', reject=reject) -eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name='MRT31-2908') +eog_epochs = create_eog_epochs(raw, ch_name="MRT31-2908", reject=reject) +eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name="MRT31-2908") ica.plot_scores(eog_scores, eog_inds) # see scores the selection is based on ica.plot_components(eog_inds) # view topographic sensitivity of components ica.exclude += eog_inds[:1] # we saw the 2nd ECG component looked too dipolar @@ -90,18 +99,18 @@ plt.show() # estimate noise covarariance -noise_cov = mne.compute_covariance(epochs, tmax=0, method='shrunk', - rank=None) +noise_cov = mne.compute_covariance(epochs, tmax=0, method="shrunk", rank=None) # %% # Visualize fields on MEG helmet # The transformation here was aligned using the dig-montage. It's included in # the spm_faces dataset and is named SPM_dig_montage.fif. 
-trans_fname = spm_path / 'SPM_CTF_MEG_example_faces1_3D_raw-trans.fif' +trans_fname = spm_path / "SPM_CTF_MEG_example_faces1_3D_raw-trans.fif" -maps = mne.make_field_map(evoked[0], trans_fname, subject='spm', - subjects_dir=subjects_dir, n_jobs=None) +maps = mne.make_field_map( + evoked[0], trans_fname, subject="spm", subjects_dir=subjects_dir, n_jobs=None +) evoked[0].plot_field(maps, time=0.170) @@ -113,25 +122,31 @@ # %% # Compute forward model -src = subjects_dir / 'spm' / 'bem' / 'spm-oct-6-src.fif' -bem = subjects_dir / 'spm' / 'bem' / 'spm-5120-5120-5120-bem-sol.fif' +src = subjects_dir / "spm" / "bem" / "spm-oct-6-src.fif" +bem = subjects_dir / "spm" / "bem" / "spm-5120-5120-5120-bem-sol.fif" forward = mne.make_forward_solution(contrast.info, trans_fname, src, bem) # %% # Compute inverse solution snr = 3.0 -lambda2 = 1.0 / snr ** 2 -method = 'dSPM' +lambda2 = 1.0 / snr**2 +method = "dSPM" -inverse_operator = make_inverse_operator(contrast.info, forward, noise_cov, - loose=0.2, depth=0.8) +inverse_operator = make_inverse_operator( + contrast.info, forward, noise_cov, loose=0.2, depth=0.8 +) # Compute inverse solution on contrast stc = apply_inverse(contrast, inverse_operator, lambda2, method, pick_ori=None) # stc.save('spm_%s_dSPM_inverse' % contrast.comment) # Plot contrast in 3D with mne.viz.Brain if available -brain = stc.plot(hemi='both', subjects_dir=subjects_dir, initial_time=0.170, - views=['ven'], clim={'kind': 'value', 'lims': [3., 6., 9.]}) +brain = stc.plot( + hemi="both", + subjects_dir=subjects_dir, + initial_time=0.170, + views=["ven"], + clim={"kind": "value", "lims": [3.0, 6.0, 9.0]}, +) # brain.save_image('dSPM_map.png') diff --git a/examples/decoding/decoding_csp_eeg.py b/examples/decoding/decoding_csp_eeg.py index beef85bbdc0..1ee2f0ce87d 100644 --- a/examples/decoding/decoding_csp_eeg.py +++ b/examples/decoding/decoding_csp_eeg.py @@ -40,7 +40,7 @@ # avoid classification of evoked responses by using epochs that start 1s after # cue onset. -tmin, tmax = -1., 4. +tmin, tmax = -1.0, 4.0 event_id = dict(hands=2, feet=3) subject = 1 runs = [6, 10, 14] # motor imagery: hands vs feet @@ -48,22 +48,30 @@ raw_fnames = eegbci.load_data(subject, runs) raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames]) eegbci.standardize(raw) # set channel names -montage = make_standard_montage('standard_1005') +montage = make_standard_montage("standard_1005") raw.set_montage(montage) # Apply band-pass filter -raw.filter(7., 30., fir_design='firwin', skip_by_annotation='edge') +raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge") events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3)) -picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, - exclude='bads') +picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads") # Read epochs (train will be done only between 1 and 2s) # Testing will be done with a running classifier -epochs = Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks, - baseline=None, preload=True) -epochs_train = epochs.copy().crop(tmin=1., tmax=2.) 
+epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + proj=True, + picks=picks, + baseline=None, + preload=True, +) +epochs_train = epochs.copy().crop(tmin=1.0, tmax=2.0) labels = epochs.events[:, -1] - 2 # %% @@ -81,25 +89,26 @@ csp = CSP(n_components=4, reg=None, log=True, norm_trace=False) # Use scikit-learn Pipeline with cross_val_score function -clf = Pipeline([('CSP', csp), ('LDA', lda)]) +clf = Pipeline([("CSP", csp), ("LDA", lda)]) scores = cross_val_score(clf, epochs_data_train, labels, cv=cv, n_jobs=None) # Printing the results class_balance = np.mean(labels == labels[0]) -class_balance = max(class_balance, 1. - class_balance) -print("Classification accuracy: %f / Chance level: %f" % (np.mean(scores), - class_balance)) +class_balance = max(class_balance, 1.0 - class_balance) +print( + "Classification accuracy: %f / Chance level: %f" % (np.mean(scores), class_balance) +) # plot CSP patterns estimated on full data for visualization csp.fit_transform(epochs_data, labels) -csp.plot_patterns(epochs.info, ch_type='eeg', units='Patterns (AU)', size=1.5) +csp.plot_patterns(epochs.info, ch_type="eeg", units="Patterns (AU)", size=1.5) # %% # Look at performance over time -sfreq = raw.info['sfreq'] -w_length = int(sfreq * 0.5) # running classifier: window length +sfreq = raw.info["sfreq"] +w_length = int(sfreq * 0.5) # running classifier: window length w_step = int(sfreq * 0.1) # running classifier: window step size w_start = np.arange(0, epochs_data.shape[2] - w_length, w_step) @@ -117,21 +126,21 @@ # running classifier: test classifier on sliding window score_this_window = [] for n in w_start: - X_test = csp.transform(epochs_data[test_idx][:, :, n:(n + w_length)]) + X_test = csp.transform(epochs_data[test_idx][:, :, n : (n + w_length)]) score_this_window.append(lda.score(X_test, y_test)) scores_windows.append(score_this_window) # Plot scores over time -w_times = (w_start + w_length / 2.) 
/ sfreq + epochs.tmin +w_times = (w_start + w_length / 2.0) / sfreq + epochs.tmin plt.figure() -plt.plot(w_times, np.mean(scores_windows, 0), label='Score') -plt.axvline(0, linestyle='--', color='k', label='Onset') -plt.axhline(0.5, linestyle='-', color='k', label='Chance') -plt.xlabel('time (s)') -plt.ylabel('classification accuracy') -plt.title('Classification score over time') -plt.legend(loc='lower right') +plt.plot(w_times, np.mean(scores_windows, 0), label="Score") +plt.axvline(0, linestyle="--", color="k", label="Onset") +plt.axhline(0.5, linestyle="-", color="k", label="Chance") +plt.xlabel("time (s)") +plt.ylabel("classification accuracy") +plt.title("Classification score over time") +plt.legend(loc="lower right") plt.show() ############################################################################## diff --git a/examples/decoding/decoding_csp_timefreq.py b/examples/decoding/decoding_csp_timefreq.py index 6407646910b..3b048587ec1 100644 --- a/examples/decoding/decoding_csp_timefreq.py +++ b/examples/decoding/decoding_csp_timefreq.py @@ -44,22 +44,24 @@ raw = concatenate_raws([read_raw_edf(f) for f in raw_fnames]) # Extract information from the raw file -sfreq = raw.info['sfreq'] +sfreq = raw.info["sfreq"] events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3)) -raw.pick_types(meg=False, eeg=True, stim=False, eog=False, exclude='bads') +raw.pick_types(meg=False, eeg=True, stim=False, eog=False, exclude="bads") raw.load_data() # Assemble the classifier using scikit-learn pipeline -clf = make_pipeline(CSP(n_components=4, reg=None, log=True, norm_trace=False), - LinearDiscriminantAnalysis()) +clf = make_pipeline( + CSP(n_components=4, reg=None, log=True, norm_trace=False), + LinearDiscriminantAnalysis(), +) n_splits = 3 # for cross-validation, 5 is better, here we use 3 for speed cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42) # Classification & time-frequency parameters -tmin, tmax = -.200, 2.000 -n_cycles = 10. # how many complete cycles: used to define window size -min_freq = 8. -max_freq = 20. +tmin, tmax = -0.200, 2.000 +n_cycles = 10.0 # how many complete cycles: used to define window size +min_freq = 8.0 +max_freq = 20.0 n_freqs = 6 # how many frequency bins to use # Assemble list of frequency range tuples @@ -67,7 +69,7 @@ freq_ranges = list(zip(freqs[:-1], freqs[1:])) # make freqs list of tuples # Infer window spacing from the max freq and number of cycles to avoid gaps -window_spacing = (n_cycles / np.max(freqs) / 2.) +window_spacing = n_cycles / np.max(freqs) / 2.0 centered_w_times = np.arange(tmin, tmax, window_spacing)[1:] n_windows = len(centered_w_times) @@ -82,39 +84,50 @@ # Loop through each frequency range of interest for freq, (fmin, fmax) in enumerate(freq_ranges): - # Infer window size based on the frequency being used - w_size = n_cycles / ((fmax + fmin) / 2.) 
# in seconds + w_size = n_cycles / ((fmax + fmin) / 2.0) # in seconds # Apply band-pass filter to isolate the specified frequencies - raw_filter = raw.copy().filter(fmin, fmax, fir_design='firwin', - skip_by_annotation='edge') + raw_filter = raw.copy().filter( + fmin, fmax, fir_design="firwin", skip_by_annotation="edge" + ) # Extract epochs from filtered data, padded by window size - epochs = Epochs(raw_filter, events, event_id, tmin - w_size, tmax + w_size, - proj=False, baseline=None, preload=True) + epochs = Epochs( + raw_filter, + events, + event_id, + tmin - w_size, + tmax + w_size, + proj=False, + baseline=None, + preload=True, + ) epochs.drop_bad() y = le.fit_transform(epochs.events[:, 2]) X = epochs.get_data() # Save mean scores over folds for each frequency and time window - freq_scores[freq] = np.mean(cross_val_score( - estimator=clf, X=X, y=y, scoring='roc_auc', cv=cv), axis=0) + freq_scores[freq] = np.mean( + cross_val_score(estimator=clf, X=X, y=y, scoring="roc_auc", cv=cv), axis=0 + ) # %% # Plot frequency results -plt.bar(freqs[:-1], freq_scores, width=np.diff(freqs)[0], - align='edge', edgecolor='black') +plt.bar( + freqs[:-1], freq_scores, width=np.diff(freqs)[0], align="edge", edgecolor="black" +) plt.xticks(freqs) plt.ylim([0, 1]) -plt.axhline(len(epochs['feet']) / len(epochs), color='k', linestyle='--', - label='chance level') +plt.axhline( + len(epochs["feet"]) / len(epochs), color="k", linestyle="--", label="chance level" +) plt.legend() -plt.xlabel('Frequency (Hz)') -plt.ylabel('Decoding Scores') -plt.title('Frequency Decoding Scores') +plt.xlabel("Frequency (Hz)") +plt.ylabel("Decoding Scores") +plt.title("Frequency Decoding Scores") # %% # Loop through frequencies and time, apply classifier and save scores @@ -124,41 +137,53 @@ # Loop through each frequency range of interest for freq, (fmin, fmax) in enumerate(freq_ranges): - # Infer window size based on the frequency being used - w_size = n_cycles / ((fmax + fmin) / 2.) # in seconds + w_size = n_cycles / ((fmax + fmin) / 2.0) # in seconds # Apply band-pass filter to isolate the specified frequencies - raw_filter = raw.copy().filter(fmin, fmax, fir_design='firwin', - skip_by_annotation='edge') + raw_filter = raw.copy().filter( + fmin, fmax, fir_design="firwin", skip_by_annotation="edge" + ) # Extract epochs from filtered data, padded by window size - epochs = Epochs(raw_filter, events, event_id, tmin - w_size, tmax + w_size, - proj=False, baseline=None, preload=True) + epochs = Epochs( + raw_filter, + events, + event_id, + tmin - w_size, + tmax + w_size, + proj=False, + baseline=None, + preload=True, + ) epochs.drop_bad() y = le.fit_transform(epochs.events[:, 2]) # Roll covariance, csp and lda over time for t, w_time in enumerate(centered_w_times): - # Center the min and max of the window - w_tmin = w_time - w_size / 2. - w_tmax = w_time + w_size / 2. 
+ w_tmin = w_time - w_size / 2.0 + w_tmax = w_time + w_size / 2.0 # Crop data into time-window of interest X = epochs.copy().crop(w_tmin, w_tmax).get_data() # Save mean scores over folds for each frequency and time window - tf_scores[freq, t] = np.mean(cross_val_score( - estimator=clf, X=X, y=y, scoring='roc_auc', cv=cv), axis=0) + tf_scores[freq, t] = np.mean( + cross_val_score(estimator=clf, X=X, y=y, scoring="roc_auc", cv=cv), axis=0 + ) # %% # Plot time-frequency results # Set up time frequency object -av_tfr = AverageTFR(create_info(['freq'], sfreq), tf_scores[np.newaxis, :], - centered_w_times, freqs[1:], 1) +av_tfr = AverageTFR( + create_info(["freq"], sfreq), + tf_scores[np.newaxis, :], + centered_w_times, + freqs[1:], + 1, +) chance = np.mean(y) # set chance level to white in the plot -av_tfr.plot([0], vmin=chance, title="Time-Frequency Decoding Scores", - cmap=plt.cm.Reds) +av_tfr.plot([0], vmin=chance, title="Time-Frequency Decoding Scores", cmap=plt.cm.Reds) diff --git a/examples/decoding/decoding_rsa_sgskip.py b/examples/decoding/decoding_rsa_sgskip.py index ba1be187372..7cc6dbfbb01 100644 --- a/examples/decoding/decoding_rsa_sgskip.py +++ b/examples/decoding/decoding_rsa_sgskip.py @@ -50,7 +50,7 @@ data_path = visual_92_categories.data_path() # Define stimulus - trigger mapping -fname = data_path / 'visual_stimuli.csv' +fname = data_path / "visual_stimuli.csv" conds = read_csv(fname) print(conds.head(5)) @@ -64,38 +64,48 @@ conditions = [] for c in conds.values: cond_tags = list(c[:2]) - cond_tags += [('not-' if i == 0 else '') + conds.columns[k] - for k, i in enumerate(c[2:], 2)] - conditions.append('/'.join(map(str, cond_tags))) + cond_tags += [ + ("not-" if i == 0 else "") + conds.columns[k] for k, i in enumerate(c[2:], 2) + ] + conditions.append("/".join(map(str, cond_tags))) print(conditions[:10]) ############################################################################## # Let's make the event_id dictionary event_id = dict(zip(conditions, conds.trigger + 1)) -event_id['0/human bodypart/human/not-face/animal/natural'] +event_id["0/human bodypart/human/not-face/animal/natural"] ############################################################################## # Read MEG data n_runs = 4 # 4 for full data (use less to speed up computations) -fnames = [data_path / f'sample_subject_{b}_tsss_mc.fif' for b in range(n_runs)] -raws = [read_raw_fif(fname, verbose='error', on_split_missing='ignore') - for fname in fnames] # ignore filename warnings +fnames = [data_path / f"sample_subject_{b}_tsss_mc.fif" for b in range(n_runs)] +raws = [ + read_raw_fif(fname, verbose="error", on_split_missing="ignore") for fname in fnames +] # ignore filename warnings raw = concatenate_raws(raws) -events = mne.find_events(raw, min_duration=.002) +events = mne.find_events(raw, min_duration=0.002) events = events[events[:, 2] <= max_trigger] ############################################################################## # Epoch data picks = mne.pick_types(raw.info, meg=True) -epochs = mne.Epochs(raw, events=events, event_id=event_id, baseline=None, - picks=picks, tmin=-.1, tmax=.500, preload=True) +epochs = mne.Epochs( + raw, + events=events, + event_id=event_id, + baseline=None, + picks=picks, + tmin=-0.1, + tmax=0.500, + preload=True, +) ############################################################################## # Let's plot some conditions -epochs['face'].average().plot() -epochs['not-face'].average().plot() +epochs["face"].average().plot() +epochs["not-face"].average().plot() 
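A note on the two selections just above: because the condition names were built as '/'-joined tags, `epochs["face"]` pools every condition whose tag list contains `face`, while `epochs["not-face"]` pools the complement (`not-face` is a distinct tag, so the two sets do not overlap). A quick sketch of checking the split, reusing the `epochs` object from this example:

n_face, n_not_face = len(epochs["face"]), len(epochs["not-face"])
print(n_face, n_not_face, len(epochs))  # the first two should sum to the third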
############################################################################## # Representational Similarity Analysis (RSA) is a neuroimaging-specific @@ -112,9 +122,9 @@ # Classify using the average signal in the window 50ms to 300ms # to focus the classifier on the time interval with best SNR. -clf = make_pipeline(StandardScaler(), - LogisticRegression(C=1, solver='liblinear', - multi_class='auto')) +clf = make_pipeline( + StandardScaler(), LogisticRegression(C=1, solver="liblinear", multi_class="auto") +) X = epochs.copy().crop(0.05, 0.3).get_data().mean(axis=2) y = epochs.events[:, 2] @@ -139,15 +149,15 @@ ############################################################################## # Plot -labels = [''] * 5 + ['face'] + [''] * 11 + ['bodypart'] + [''] * 6 +labels = [""] * 5 + ["face"] + [""] * 11 + ["bodypart"] + [""] * 6 fig, ax = plt.subplots(1) -im = ax.matshow(confusion, cmap='RdBu_r', clim=[0.3, 0.7]) +im = ax.matshow(confusion, cmap="RdBu_r", clim=[0.3, 0.7]) ax.set_yticks(range(len(classes))) ax.set_yticklabels(labels) ax.set_xticks(range(len(classes))) -ax.set_xticklabels(labels, rotation=40, ha='left') -ax.axhline(11.5, color='k') -ax.axvline(11.5, color='k') +ax.set_xticklabels(labels, rotation=40, ha="left") +ax.axhline(11.5, color="k") +ax.axvline(11.5, color="k") plt.colorbar(im) plt.tight_layout() plt.show() @@ -157,19 +167,25 @@ # summarized with dimensionality reduction using multi-dimensional scaling [1]. # See how the face samples cluster together. fig, ax = plt.subplots(1) -mds = MDS(2, random_state=0, dissimilarity='precomputed') +mds = MDS(2, random_state=0, dissimilarity="precomputed") chance = 0.5 summary = mds.fit_transform(chance - confusion) -cmap = plt.colormaps['rainbow'] -colors = ['r', 'b'] -names = list(conds['condition'].values) +cmap = plt.colormaps["rainbow"] +colors = ["r", "b"] +names = list(conds["condition"].values) for color, name in zip(colors, set(names)): sel = np.where([this_name == name for this_name in names])[0] - size = 500 if name == 'human face' else 100 - ax.scatter(summary[sel, 0], summary[sel, 1], s=size, - facecolors=color, label=name, edgecolors='k') -ax.axis('off') -ax.legend(loc='lower right', scatterpoints=1, ncol=2) + size = 500 if name == "human face" else 100 + ax.scatter( + summary[sel, 0], + summary[sel, 1], + s=size, + facecolors=color, + label=name, + edgecolors="k", + ) +ax.axis("off") +ax.legend(loc="lower right", scatterpoints=1, ncol=2) plt.tight_layout() plt.show() diff --git a/examples/decoding/decoding_spatio_temporal_source.py b/examples/decoding/decoding_spatio_temporal_source.py index 476b4d170c6..ad96720f640 100644 --- a/examples/decoding/decoding_spatio_temporal_source.py +++ b/examples/decoding/decoding_spatio_temporal_source.py @@ -31,42 +31,51 @@ import mne from mne.minimum_norm import apply_inverse_epochs, read_inverse_operator -from mne.decoding import (cross_val_multiscore, LinearModel, SlidingEstimator, - get_coef) +from mne.decoding import cross_val_multiscore, LinearModel, SlidingEstimator, get_coef print(__doc__) data_path = mne.datasets.sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname_fwd = meg_path / 'sample_audvis-meg-oct-6-fwd.fif' -fname_evoked = meg_path / 'sample_audvis-ave.fif' -subjects_dir = data_path / 'subjects' +meg_path = data_path / "MEG" / "sample" +fname_fwd = meg_path / "sample_audvis-meg-oct-6-fwd.fif" +fname_evoked = meg_path / "sample_audvis-ave.fif" +subjects_dir = data_path / "subjects" # %% # Set parameters -raw_fname = meg_path / 
'sample_audvis_filt-0-40_raw.fif'
-event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif'
-fname_cov = meg_path / 'sample_audvis-cov.fif'
-fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif'
+raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif"
+event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif"
+fname_cov = meg_path / "sample_audvis-cov.fif"
+fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif"
 
 tmin, tmax = -0.2, 0.8
 event_id = dict(aud_r=2, vis_r=4)  # load contra-lateral conditions
 
 # Setup for reading the raw data
 raw = mne.io.read_raw_fif(raw_fname, preload=True)
-raw.filter(None, 10., fir_design='firwin')
+raw.filter(None, 10.0, fir_design="firwin")
 events = mne.read_events(event_fname)
 
 # Set up pick list: MEG - bad channels (modify to your needs)
-raw.info['bads'] += ['MEG 2443']  # mark bads
-picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=True, eog=True,
-                       exclude='bads')
+raw.info["bads"] += ["MEG 2443"]  # mark bads
+picks = mne.pick_types(
+    raw.info, meg=True, eeg=False, stim=True, eog=True, exclude="bads"
+)
 
 # Read epochs
-epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
-                    picks=picks, baseline=(None, 0), preload=True,
-                    reject=dict(grad=4000e-13, eog=150e-6),
-                    decim=5)  # decimate to save memory and increase speed
+epochs = mne.Epochs(
+    raw,
+    events,
+    event_id,
+    tmin,
+    tmax,
+    proj=True,
+    picks=picks,
+    baseline=(None, 0),
+    preload=True,
+    reject=dict(grad=4000e-13, eog=150e-6),
+    decim=5,
+)  # decimate to save memory and increase speed
 
 # %%
 # Compute inverse solution
@@ -74,9 +83,14 @@
 noise_cov = mne.read_cov(fname_cov)
 inverse_operator = read_inverse_operator(fname_inv)
 
-stcs = apply_inverse_epochs(epochs, inverse_operator,
-                            lambda2=1.0 / snr ** 2, verbose=False,
-                            method="dSPM", pick_ori="normal")
+stcs = apply_inverse_epochs(
+    epochs,
+    inverse_operator,
+    lambda2=1.0 / snr**2,
+    verbose=False,
+    method="dSPM",
+    pick_ori="normal",
+)
 
 # %%
 # Decoding in sensor space using a logistic regression
@@ -86,19 +100,21 @@
 y = epochs.events[:, 2]
 
 # prepare a series of classifiers applied at each time sample
-clf = make_pipeline(StandardScaler(),  # z-score normalization
-                    SelectKBest(f_classif, k=500),  # select features for speed
-                    LinearModel(LogisticRegression(C=1, solver='liblinear')))
-time_decod = SlidingEstimator(clf, scoring='roc_auc')
+clf = make_pipeline(
+    StandardScaler(),  # z-score normalization
+    SelectKBest(f_classif, k=500),  # select features for speed
+    LinearModel(LogisticRegression(C=1, solver="liblinear")),
+)
+time_decod = SlidingEstimator(clf, scoring="roc_auc")
 
 # Run cross-validated decoding analyses:
 scores = cross_val_multiscore(time_decod, X, y, cv=5, n_jobs=None)
 
 # Plot average decoding scores of 5 splits
 fig, ax = plt.subplots(1)
-ax.plot(epochs.times, scores.mean(0), label='score')
-ax.axhline(.5, color='k', linestyle='--', label='chance')
-ax.axvline(0, color='k')
+ax.plot(epochs.times, scores.mean(0), label="score")
+ax.axhline(0.5, color="k", linestyle="--", label="chance")
+ax.axvline(0, color="k")
 plt.legend()
 
 # %%
@@ -109,13 +125,22 @@
 time_decod.fit(X, y)
 
 # Retrieve patterns after inverting the z-score normalization step:
-patterns = get_coef(time_decod, 'patterns_', inverse_transform=True)
+patterns = get_coef(time_decod, "patterns_", inverse_transform=True)
 
 stc = stcs[0]  # for convenience, look up parameters from first stc
 vertices = [stc.lh_vertno, np.array([], int)]  # empty array for right hemi
-stc_feat = mne.SourceEstimate(np.abs(patterns), vertices=vertices,
-                              tmin=stc.tmin, tstep=stc.tstep, subject='sample')
-
-brain = stc_feat.plot(views=['lat'], transparent=True,
-                      initial_time=0.1, time_unit='s',
-                      subjects_dir=subjects_dir)
+stc_feat = mne.SourceEstimate(
+    np.abs(patterns),
+    vertices=vertices,
+    tmin=stc.tmin,
+    tstep=stc.tstep,
+    subject="sample",
+)
+
+brain = stc_feat.plot(
+    views=["lat"],
+    transparent=True,
+    initial_time=0.1,
+    time_unit="s",
+    subjects_dir=subjects_dir,
+)
diff --git a/examples/decoding/decoding_spoc_CMC.py b/examples/decoding/decoding_spoc_CMC.py
index f1fb8c86400..81acb0b9cc4 100644
--- a/examples/decoding/decoding_spoc_CMC.py
+++ b/examples/decoding/decoding_spoc_CMC.py
@@ -35,32 +35,31 @@
 from sklearn.model_selection import KFold, cross_val_predict
 
 # Define parameters
-fname = data_path() / 'SubjectCMC.ds'
+fname = data_path() / "SubjectCMC.ds"
 raw = mne.io.read_raw_ctf(fname)
-raw.crop(50., 200.)  # crop for memory purposes
+raw.crop(50.0, 200.0)  # crop for memory purposes
 
 # Filter muscular activity to only keep high frequencies
-emg = raw.copy().pick_channels(['EMGlft']).load_data()
-emg.filter(20., None)
+emg = raw.copy().pick_channels(["EMGlft"]).load_data()
+emg.filter(20.0, None)
 
 # Filter MEG data to focus on beta band
 raw.pick_types(meg=True, ref_meg=True, eeg=False, eog=False).load_data()
-raw.filter(15., 30.)
+raw.filter(15.0, 30.0)
 
 # Build epochs as sliding windows over the continuous raw file
 events = mne.make_fixed_length_events(raw, id=1, duration=0.75)
 
 # Epoch length is 1.5 seconds
-meg_epochs = Epochs(raw, events, tmin=0., tmax=1.5, baseline=None,
-                    detrend=1, decim=12)
-emg_epochs = Epochs(emg, events, tmin=0., tmax=1.5, baseline=None)
+meg_epochs = Epochs(raw, events, tmin=0.0, tmax=1.5, baseline=None, detrend=1, decim=12)
+emg_epochs = Epochs(emg, events, tmin=0.0, tmax=1.5, baseline=None)
 
 # Prepare classification
 X = meg_epochs.get_data()
 y = emg_epochs.get_data().var(axis=2)[:, 0]  # target is EMG power
 
 # Classification pipeline with SPoC spatial filtering and Ridge Regression
-spoc = SPoC(n_components=2, log=True, reg='oas', rank='full')
+spoc = SPoC(n_components=2, log=True, reg="oas", rank="full")
 clf = make_pipeline(spoc, Ridge())
 # Define a two-fold cross-validation
 cv = KFold(n_splits=2, shuffle=False)
@@ -71,11 +70,11 @@
 # Plot the true EMG power and the EMG power predicted from MEG data
 fig, ax = plt.subplots(1, 1, figsize=[10, 4])
 times = raw.times[meg_epochs.events[:, 0] - raw.first_samp]
-ax.plot(times, y_preds, color='b', label='Predicted EMG')
-ax.plot(times, y, color='r', label='True EMG')
-ax.set_xlabel('Time (s)')
-ax.set_ylabel('EMG Power')
-ax.set_title('SPoC MEG Predictions')
+ax.plot(times, y_preds, color="b", label="Predicted EMG")
+ax.plot(times, y, color="r", label="True EMG")
+ax.set_xlabel("Time (s)")
+ax.set_ylabel("EMG Power")
+ax.set_title("SPoC MEG Predictions")
 plt.legend()
 mne.viz.tight_layout()
 plt.show()
diff --git a/examples/decoding/decoding_time_generalization_conditions.py b/examples/decoding/decoding_time_generalization_conditions.py
index d39797e6561..08ca0d9c0c3 100644
--- a/examples/decoding/decoding_time_generalization_conditions.py
+++ b/examples/decoding/decoding_time_generalization_conditions.py
@@ -34,56 +34,78 @@
 # Preprocess data
 data_path = sample.data_path()
 # Load and filter data, set up epochs
-meg_path = data_path / 'MEG' / 'sample'
-raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif'
-events_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif'
+meg_path = data_path / "MEG" / "sample"
+raw_fname = meg_path / 
"sample_audvis_filt-0-40_raw.fif" +events_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" raw = mne.io.read_raw_fif(raw_fname, preload=True) -picks = mne.pick_types(raw.info, meg=True, exclude='bads') # Pick MEG channels -raw.filter(1., 30., fir_design='firwin') # Band pass filtering signals +picks = mne.pick_types(raw.info, meg=True, exclude="bads") # Pick MEG channels +raw.filter(1.0, 30.0, fir_design="firwin") # Band pass filtering signals events = mne.read_events(events_fname) -event_id = {'Auditory/Left': 1, 'Auditory/Right': 2, - 'Visual/Left': 3, 'Visual/Right': 4} +event_id = { + "Auditory/Left": 1, + "Auditory/Right": 2, + "Visual/Left": 3, + "Visual/Right": 4, +} tmin = -0.050 tmax = 0.400 # decimate to make the example faster to run, but then use verbose='error' in # the Epochs constructor to suppress warning about decimation causing aliasing decim = 2 -epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax, - proj=True, picks=picks, baseline=None, preload=True, - reject=dict(mag=5e-12), decim=decim, verbose='error') +epochs = mne.Epochs( + raw, + events, + event_id=event_id, + tmin=tmin, + tmax=tmax, + proj=True, + picks=picks, + baseline=None, + preload=True, + reject=dict(mag=5e-12), + decim=decim, + verbose="error", +) # %% # We will train the classifier on all left visual vs auditory trials # and test on all right visual vs auditory trials. clf = make_pipeline( StandardScaler(), - LogisticRegression(solver='liblinear') # liblinear is faster than lbfgs + LogisticRegression(solver="liblinear"), # liblinear is faster than lbfgs ) -time_gen = GeneralizingEstimator(clf, scoring='roc_auc', n_jobs=None, - verbose=True) +time_gen = GeneralizingEstimator(clf, scoring="roc_auc", n_jobs=None, verbose=True) # Fit classifiers on the epochs where the stimulus was presented to the left. # Note that the experimental condition y indicates auditory or visual -time_gen.fit(X=epochs['Left'].get_data(), - y=epochs['Left'].events[:, 2] > 2) +time_gen.fit(X=epochs["Left"].get_data(), y=epochs["Left"].events[:, 2] > 2) # %% # Score on the epochs where the stimulus was presented to the right. 
-scores = time_gen.score(X=epochs['Right'].get_data(),
-                        y=epochs['Right'].events[:, 2] > 2)
+scores = time_gen.score(
+    X=epochs["Right"].get_data(), y=epochs["Right"].events[:, 2] > 2
+)
 
 # %%
 # Plot
 
 fig, ax = plt.subplots(constrained_layout=True)
-im = ax.matshow(scores, vmin=0, vmax=1., cmap='RdBu_r', origin='lower',
-                extent=epochs.times[[0, -1, 0, -1]])
-ax.axhline(0., color='k')
-ax.axvline(0., color='k')
-ax.xaxis.set_ticks_position('bottom')
-ax.set_xlabel('Condition: "Right"\nTesting Time (s)',)
+im = ax.matshow(
+    scores,
+    vmin=0,
+    vmax=1.0,
+    cmap="RdBu_r",
+    origin="lower",
+    extent=epochs.times[[0, -1, 0, -1]],
+)
+ax.axhline(0.0, color="k")
+ax.axvline(0.0, color="k")
+ax.xaxis.set_ticks_position("bottom")
+ax.set_xlabel(
+    'Condition: "Right"\nTesting Time (s)',
+)
 ax.set_ylabel('Condition: "Left"\nTraining Time (s)')
-ax.set_title('Generalization across time and condition', fontweight='bold')
-fig.colorbar(im, ax=ax, label='Performance (ROC AUC)')
+ax.set_title("Generalization across time and condition", fontweight="bold")
+fig.colorbar(im, ax=ax, label="Performance (ROC AUC)")
 plt.show()
 
##############################################################################
diff --git a/examples/decoding/decoding_unsupervised_spatial_filter.py b/examples/decoding/decoding_unsupervised_spatial_filter.py
index a3fab432ace..d215203ac3c 100644
--- a/examples/decoding/decoding_unsupervised_spatial_filter.py
+++ b/examples/decoding/decoding_unsupervised_spatial_filter.py
@@ -32,22 +32,32 @@
 
 data_path = sample.data_path()
 
 # Load and filter data, set up epochs
-meg_path = data_path / 'MEG' / 'sample'
-raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif'
-event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif'
+meg_path = data_path / "MEG" / "sample"
+raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif"
+event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif"
 tmin, tmax = -0.1, 0.3
 event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)
 
 raw = mne.io.read_raw_fif(raw_fname, preload=True)
-raw.filter(1, 20, fir_design='firwin')
+raw.filter(1, 20, fir_design="firwin")
 events = mne.read_events(event_fname)
 
-picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
-                       exclude='bads')
-
-epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
-                    picks=picks, baseline=None, preload=True,
-                    verbose=False)
+picks = mne.pick_types(
+    raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads"
+)
+
+epochs = mne.Epochs(
+    raw,
+    events,
+    event_id,
+    tmin,
+    tmax,
+    proj=False,
+    picks=picks,
+    baseline=None,
+    preload=True,
+    verbose=False,
+)
 
 X = epochs.get_data()
 
@@ -55,19 +65,22 @@
 # Transform data with PCA computed on the average, i.e., evoked response
 pca = UnsupervisedSpatialFilter(PCA(30), average=False)
 pca_data = pca.fit_transform(X)
-ev = mne.EvokedArray(np.mean(pca_data, axis=0),
-                     mne.create_info(30, epochs.info['sfreq'],
-                                     ch_types='eeg'), tmin=tmin)
-ev.plot(show=False, window_title="PCA", time_unit='s')
+ev = mne.EvokedArray(
+    np.mean(pca_data, axis=0),
+    mne.create_info(30, epochs.info["sfreq"], ch_types="eeg"),
+    tmin=tmin,
+)
+ev.plot(show=False, window_title="PCA", time_unit="s")
 
##############################################################################
 # Transform data with ICA computed on the raw epochs (no averaging)
-ica = UnsupervisedSpatialFilter(
-    FastICA(30, whiten='unit-variance'), average=False)
+ica = UnsupervisedSpatialFilter(FastICA(30, whiten="unit-variance"), average=False)
ica_data = ica.fit_transform(X)
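A shape note on the transform just above: `UnsupervisedSpatialFilter.fit_transform` keeps the epochs structure and returns an array of shape (n_epochs, n_components, n_times), which is why the next lines average over axis 0 before wrapping the result in an `EvokedArray`. A quick check, reusing `X` and `ica_data`:

print(X.shape, "->", ica_data.shape)  # the channel axis is replaced by 30 components
assert ica_data.shape == (X.shape[0], 30, X.shape[2])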
-ev1 = mne.EvokedArray(np.mean(ica_data, axis=0), - mne.create_info(30, epochs.info['sfreq'], - ch_types='eeg'), tmin=tmin) -ev1.plot(show=False, window_title='ICA', time_unit='s') +ev1 = mne.EvokedArray( + np.mean(ica_data, axis=0), + mne.create_info(30, epochs.info["sfreq"], ch_types="eeg"), + tmin=tmin, +) +ev1.plot(show=False, window_title="ICA", time_unit="s") plt.show() diff --git a/examples/decoding/decoding_xdawn_eeg.py b/examples/decoding/decoding_xdawn_eeg.py index 9ec65f54976..3bdff716228 100644 --- a/examples/decoding/decoding_xdawn_eeg.py +++ b/examples/decoding/decoding_xdawn_eeg.py @@ -37,32 +37,45 @@ # %% # Set parameters and read data -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" tmin, tmax = -0.1, 0.3 -event_id = {'Auditory/Left': 1, 'Auditory/Right': 2, - 'Visual/Left': 3, 'Visual/Right': 4} +event_id = { + "Auditory/Left": 1, + "Auditory/Right": 2, + "Visual/Left": 3, + "Visual/Right": 4, +} n_filter = 3 # Setup for reading the raw data raw = io.read_raw_fif(raw_fname, preload=True) -raw.filter(1, 20, fir_design='firwin') +raw.filter(1, 20, fir_design="firwin") events = read_events(event_fname) -picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, - exclude='bads') - -epochs = Epochs(raw, events, event_id, tmin, tmax, proj=False, - picks=picks, baseline=None, preload=True, - verbose=False) +picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads") + +epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + proj=False, + picks=picks, + baseline=None, + preload=True, + verbose=False, +) # Create classification pipeline -clf = make_pipeline(Xdawn(n_components=n_filter), - Vectorizer(), - MinMaxScaler(), - LogisticRegression(penalty='l1', solver='liblinear', - multi_class='auto')) +clf = make_pipeline( + Xdawn(n_components=n_filter), + Vectorizer(), + MinMaxScaler(), + LogisticRegression(penalty="l1", solver="liblinear", multi_class="auto"), +) # Get the labels labels = epochs.events[:, -1] @@ -77,7 +90,7 @@ preds[test] = clf.predict(epochs[test]) # Classification report -target_names = ['aud_l', 'aud_r', 'vis_l', 'vis_r'] +target_names = ["aud_l", "aud_r", "vis_l", "vis_r"] report = classification_report(labels, preds, target_names=target_names) print(report) @@ -87,21 +100,22 @@ # Plot confusion matrix fig, ax = plt.subplots(1) -im = ax.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Blues) -ax.set(title='Normalized Confusion matrix') +im = ax.imshow(cm_normalized, interpolation="nearest", cmap=plt.cm.Blues) +ax.set(title="Normalized Confusion matrix") fig.colorbar(im) tick_marks = np.arange(len(target_names)) plt.xticks(tick_marks, target_names, rotation=45) plt.yticks(tick_marks, target_names) fig.tight_layout() -ax.set(ylabel='True label', xlabel='Predicted label') +ax.set(ylabel="True label", xlabel="Predicted label") # %% # The ``patterns_`` attribute of a fitted Xdawn instance (here from the last # cross-validation fold) can be used for visualization. 
-fig, axes = plt.subplots(nrows=len(event_id), ncols=n_filter,
-                         figsize=(n_filter, len(event_id) * 2))
+fig, axes = plt.subplots(
+    nrows=len(event_id), ncols=n_filter, figsize=(n_filter, len(event_id) * 2)
+)
 fitted_xdawn = clf.steps[0][1]
 info = create_info(epochs.ch_names, 1, epochs.get_channel_types())
 info.set_montage(epochs.get_montage())
@@ -110,8 +124,12 @@
     pattern_evoked = EvokedArray(cur_patterns[:n_filter].T, info, tmin=0)
     pattern_evoked.plot_topomap(
         times=np.arange(n_filter),
-        time_format='Component %d' if ii == 0 else '', colorbar=False,
-        show_names=False, axes=axes[ii], show=False)
+        time_format="Component %d" if ii == 0 else "",
+        colorbar=False,
+        show_names=False,
+        axes=axes[ii],
+        show=False,
+    )
     axes[ii, 0].set(ylabel=cur_class)
 fig.tight_layout(h_pad=1.0, w_pad=1.0, pad=0.1)
diff --git a/examples/decoding/ems_filtering.py b/examples/decoding/ems_filtering.py
index 8807bf57079..34b3bcf8489 100644
--- a/examples/decoding/ems_filtering.py
+++ b/examples/decoding/ems_filtering.py
@@ -39,24 +39,33 @@
 data_path = sample.data_path()
 
 # Preprocess the data
-meg_path = data_path / 'MEG' / 'sample'
-raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif'
-event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif'
-event_ids = {'AudL': 1, 'VisL': 3}
+meg_path = data_path / "MEG" / "sample"
+raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif"
+event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif"
+event_ids = {"AudL": 1, "VisL": 3}
 
 # Read data and create epochs
 raw = io.read_raw_fif(raw_fname, preload=True)
-raw.filter(0.5, 45, fir_design='firwin')
+raw.filter(0.5, 45, fir_design="firwin")
 events = mne.read_events(event_fname)
 
-picks = mne.pick_types(raw.info, meg='grad', eeg=False, stim=False, eog=True,
-                       exclude='bads')
-
-epochs = mne.Epochs(raw, events, event_ids, tmin=-0.2, tmax=0.5, picks=picks,
-                    baseline=None, reject=dict(grad=4000e-13, eog=150e-6),
-                    preload=True)
+picks = mne.pick_types(
+    raw.info, meg="grad", eeg=False, stim=False, eog=True, exclude="bads"
+)
+
+epochs = mne.Epochs(
+    raw,
+    events,
+    event_ids,
+    tmin=-0.2,
+    tmax=0.5,
+    picks=picks,
+    baseline=None,
+    reject=dict(grad=4000e-13, eog=150e-6),
+    preload=True,
+)
 epochs.drop_bad()
-epochs.pick_types(meg='grad')
+epochs.pick_types(meg="grad")
 
 # Set up the data to use it in a scikit-learn way:
 X = epochs.get_data()  # The MEG data
@@ -98,23 +107,27 @@
 
 # Plot individual trials
 plt.figure()
-plt.title('single trial surrogates')
-plt.imshow(X_transform[y.argsort()], origin='lower', aspect='auto',
-           extent=[epochs.times[0], epochs.times[-1], 1, len(X_transform)],
-           cmap='RdBu_r')
-plt.xlabel('Time (ms)')
-plt.ylabel('Trials (reordered by condition)')
+plt.title("single trial surrogates")
+plt.imshow(
+    X_transform[y.argsort()],
+    origin="lower",
+    aspect="auto",
+    extent=[epochs.times[0], epochs.times[-1], 1, len(X_transform)],
+    cmap="RdBu_r",
+)
+plt.xlabel("Time (ms)")
+plt.ylabel("Trials (reordered by condition)")
 
 # Plot average response
 plt.figure()
-plt.title('Average EMS signal')
+plt.title("Average EMS signal")
 mappings = [(key, value) for key, value in event_ids.items()]
 for key, value in mappings:
     ems_ave = X_transform[y == value]
     plt.plot(epochs.times, ems_ave.mean(0), label=key)
-plt.xlabel('Time (ms)')
-plt.ylabel('a.u.')
-plt.legend(loc='best')
+plt.xlabel("Time (ms)")
+plt.ylabel("a.u.")
+plt.legend(loc="best")
 plt.show()
 
 # Visualize spatial filters across time
diff --git a/examples/decoding/linear_model_patterns.py b/examples/decoding/linear_model_patterns.py
index 
f708503214b..1786df4a4b8 100644 --- a/examples/decoding/linear_model_patterns.py +++ b/examples/decoding/linear_model_patterns.py @@ -37,23 +37,24 @@ print(__doc__) data_path = sample.data_path() -sample_path = data_path / 'MEG' / 'sample' +sample_path = data_path / "MEG" / "sample" # %% # Set parameters -raw_fname = sample_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = sample_path / 'sample_audvis_filt-0-40_raw-eve.fif' +raw_fname = sample_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = sample_path / "sample_audvis_filt-0-40_raw-eve.fif" tmin, tmax = -0.1, 0.4 event_id = dict(aud_l=1, vis_l=3) # Setup for reading the raw data raw = io.read_raw_fif(raw_fname, preload=True) -raw.filter(.5, 25, fir_design='firwin') +raw.filter(0.5, 25, fir_design="firwin") events = mne.read_events(event_fname) # Read epochs -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True, - decim=2, baseline=None, preload=True) +epochs = mne.Epochs( + raw, events, event_id, tmin, tmax, proj=True, decim=2, baseline=None, preload=True +) del raw labels = epochs.events[:, -1] @@ -66,7 +67,7 @@ # Decoding in sensor space using a LogisticRegression classifier # -------------------------------------------------------------- -clf = LogisticRegression(solver='liblinear') # liblinear is faster than lbfgs +clf = LogisticRegression(solver="liblinear") # liblinear is faster than lbfgs scaler = StandardScaler() # create a linear model with LogisticRegression @@ -77,7 +78,7 @@ model.fit(X, labels) # Extract and plot spatial filters and spatial patterns -for name, coef in (('patterns', model.patterns_), ('filters', model.filters_)): +for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)): # We fitted the linear model onto Z-scored data. To make the filters # interpretable, we must reverse this normalization step coef = scaler.inverse_transform([coef])[0] @@ -89,7 +90,7 @@ # Plot evoked = EvokedArray(coef, meg_epochs.info, tmin=epochs.tmin) fig = evoked.plot_topomap() - fig.suptitle(f'MEG {name}') + fig.suptitle(f"MEG {name}") # %% # Let's do the same on EEG data using a scikit-learn pipeline @@ -99,22 +100,22 @@ # Define a unique pipeline to sequentially: clf = make_pipeline( - Vectorizer(), # 1) vectorize across time and channels - StandardScaler(), # 2) normalize features across trials - LinearModel( # 3) fits a logistic regression - LogisticRegression(solver='liblinear') - ) + Vectorizer(), # 1) vectorize across time and channels + StandardScaler(), # 2) normalize features across trials + LinearModel( # 3) fits a logistic regression + LogisticRegression(solver="liblinear") + ), ) clf.fit(X, y) # Extract and plot patterns and filters -for name in ('patterns_', 'filters_'): +for name in ("patterns_", "filters_"): # The `inverse_transform` parameter will call this method on any estimator # contained in the pipeline, in reverse order. 
coef = get_coef(clf, name, inverse_transform=True) evoked = EvokedArray(coef, epochs.info, tmin=epochs.tmin) fig = evoked.plot_topomap() - fig.suptitle(f'EEG {name[:-1]}') + fig.suptitle(f"EEG {name[:-1]}") # %% # References diff --git a/examples/decoding/receptive_field_mtrf.py b/examples/decoding/receptive_field_mtrf.py index 4e948613dbb..0773811f1f3 100644 --- a/examples/decoding/receptive_field_mtrf.py +++ b/examples/decoding/receptive_field_mtrf.py @@ -52,25 +52,25 @@ path = mne.datasets.mtrf.data_path() decim = 2 -data = loadmat(join(path, 'speech_data.mat')) -raw = data['EEG'].T -speech = data['envelope'].T -sfreq = float(data['Fs']) +data = loadmat(join(path, "speech_data.mat")) +raw = data["EEG"].T +speech = data["envelope"].T +sfreq = float(data["Fs"]) sfreq /= decim -speech = mne.filter.resample(speech, down=decim, npad='auto') -raw = mne.filter.resample(raw, down=decim, npad='auto') +speech = mne.filter.resample(speech, down=decim, npad="auto") +raw = mne.filter.resample(raw, down=decim, npad="auto") # Read in channel positions and create our MNE objects from the raw data -montage = mne.channels.make_standard_montage('biosemi128') -info = mne.create_info(montage.ch_names, sfreq, 'eeg').set_montage(montage) +montage = mne.channels.make_standard_montage("biosemi128") +info = mne.create_info(montage.ch_names, sfreq, "eeg").set_montage(montage) raw = mne.io.RawArray(raw, info) n_channels = len(raw.ch_names) # Plot a sample of brain and stimulus activity fig, ax = plt.subplots() -lns = ax.plot(scale(raw[:, :800][0].T), color='k', alpha=.1) -ln1 = ax.plot(scale(speech[0, :800]), color='r', lw=2) -ax.legend([lns[0], ln1[0]], ['EEG', 'Speech Envelope'], frameon=False) +lns = ax.plot(scale(raw[:, :800][0].T), color="k", alpha=0.1) +ln1 = ax.plot(scale(speech[0, :800]), color="r", lw=2) +ax.legend([lns[0], ln1[0]], ["EEG", "Speech Envelope"], frameon=False) ax.set(title="Sample activity", xlabel="Time (s)") mne.viz.tight_layout() @@ -83,11 +83,12 @@ # us to make predictions about the response to new stimuli. # Define the delays that we will use in the receptive field -tmin, tmax = -.2, .4 +tmin, tmax = -0.2, 0.4 # Initialize the model -rf = ReceptiveField(tmin, tmax, sfreq, feature_names=['envelope'], - estimator=1., scoring='corrcoef') +rf = ReceptiveField( + tmin, tmax, sfreq, feature_names=["envelope"], estimator=1.0, scoring="corrcoef" +) # We'll have (tmax - tmin) * sfreq delays # and an extra 2 delays since we are inclusive on the beginning / end index n_delays = int((tmax - tmin) * sfreq) + 2 @@ -104,7 +105,7 @@ coefs = np.zeros((n_splits, n_channels, n_delays)) scores = np.zeros((n_splits, n_channels)) for ii, (train, test) in enumerate(cv.split(speech)): - print('split %s / %s' % (ii + 1, n_splits)) + print("split %s / %s" % (ii + 1, n_splits)) rf.fit(speech[train], Y[train]) scores[ii] = rf.score(speech[test], Y[test]) # coef_ is shape (n_outputs, n_features, n_delays). we only have 1 feature @@ -119,7 +120,7 @@ fig, ax = plt.subplots() ix_chs = np.arange(n_channels) ax.plot(ix_chs, mean_scores) -ax.axhline(0, ls='--', color='r') +ax.axhline(0, ls="--", color="r") ax.set(title="Mean prediction score", xlabel="Channel", ylabel="Score ($r$)") mne.viz.tight_layout() @@ -135,20 +136,33 @@ time_plot = 0.180 # For highlighting a specific time. 
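The delay bookkeeping described above is easy to restate numerically. With `tmin, tmax = -0.2, 0.4` and, purely for illustration, an assumed post-decimation rate of 64 Hz (the real `sfreq` comes from the .mat file), the model holds `int((tmax - tmin) * sfreq) + 2` delays, the extra two coming from the inclusive endpoints:

sfreq_demo = 64.0  # hypothetical rate, for illustration only
n_delays_demo = int((0.4 - -0.2) * sfreq_demo) + 2
print(n_delays_demo)  # -> 40 at the assumed 64 Hz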
fig, ax = plt.subplots(figsize=(4, 8)) max_coef = mean_coefs.max() -ax.pcolormesh(times, ix_chs, mean_coefs, cmap='RdBu_r', - vmin=-max_coef, vmax=max_coef, shading='gouraud') -ax.axvline(time_plot, ls='--', color='k', lw=2) -ax.set(xlabel='Delay (s)', ylabel='Channel', title="Mean Model\nCoefficients", - xlim=times[[0, -1]], ylim=[len(ix_chs) - 1, 0], - xticks=np.arange(tmin, tmax + .2, .2)) +ax.pcolormesh( + times, + ix_chs, + mean_coefs, + cmap="RdBu_r", + vmin=-max_coef, + vmax=max_coef, + shading="gouraud", +) +ax.axvline(time_plot, ls="--", color="k", lw=2) +ax.set( + xlabel="Delay (s)", + ylabel="Channel", + title="Mean Model\nCoefficients", + xlim=times[[0, -1]], + ylim=[len(ix_chs) - 1, 0], + xticks=np.arange(tmin, tmax + 0.2, 0.2), +) plt.setp(ax.get_xticklabels(), rotation=45) mne.viz.tight_layout() # Make a topographic map of coefficients for a given delay (see Fig 2C) ix_plot = np.argmin(np.abs(time_plot - times)) fig, ax = plt.subplots() -mne.viz.plot_topomap(mean_coefs[:, ix_plot], pos=info, axes=ax, show=False, - vlim=(-max_coef, max_coef)) +mne.viz.plot_topomap( + mean_coefs[:, ix_plot], pos=info, axes=ax, show=False, vlim=(-max_coef, max_coef) +) ax.set(title="Topomap of model coefficients\nfor delay %s" % time_plot) mne.viz.tight_layout() @@ -174,15 +188,22 @@ # positive lags would index how a unit change in the amplitude of the EEG would # affect later stimulus activity (obviously this should have an amplitude of # zero). -tmin, tmax = -.2, 0. +tmin, tmax = -0.2, 0.0 # Initialize the model. Here the features are the EEG data. We also specify # ``patterns=True`` to compute inverse-transformed coefficients during model # fitting (cf. next section and :footcite:`HaufeEtAl2014`). # We'll use a ridge regression estimator with an alpha value similar to # Crosse et al. -sr = ReceptiveField(tmin, tmax, sfreq, feature_names=raw.ch_names, - estimator=1e4, scoring='corrcoef', patterns=True) +sr = ReceptiveField( + tmin, + tmax, + sfreq, + feature_names=raw.ch_names, + estimator=1e4, + scoring="corrcoef", + patterns=True, +) # We'll have (tmax - tmin) * sfreq delays # and an extra 2 delays since we are inclusive on the beginning / end index n_delays = int((tmax - tmin) * sfreq) + 2 @@ -195,7 +216,7 @@ patterns = coefs.copy() scores = np.zeros((n_splits,)) for ii, (train, test) in enumerate(cv.split(speech)): - print('split %s / %s' % (ii + 1, n_splits)) + print("split %s / %s" % (ii + 1, n_splits)) sr.fit(Y[train], speech[train]) scores[ii] = sr.score(Y[test], speech[test])[0] # coef_ is shape (n_outputs, n_features, n_delays). We have 128 features @@ -218,14 +239,15 @@ # stimulus envelopes side by side. 
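One detail worth flagging before the reconstruction plot below: the predictions are indexed with `sr.valid_samples_`, the mask `ReceptiveField` keeps of the time samples whose full delay window fits inside the data, so edge samples are dropped. A sketch of inspecting how much survives, reusing `sr`, `speech`, and the last `test` split from the loop above:

n_valid = speech[test][sr.valid_samples_].shape[0]
print(f"{n_valid} of {speech[test].shape[0]} test samples kept after edge trimming")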
y_pred = sr.predict(Y[test]) -time = np.linspace(0, 2., 5 * int(sfreq)) +time = np.linspace(0, 2.0, 5 * int(sfreq)) fig, ax = plt.subplots(figsize=(8, 4)) -ax.plot(time, speech[test][sr.valid_samples_][:int(5 * sfreq)], - color='grey', lw=2, ls='--') -ax.plot(time, y_pred[sr.valid_samples_][:int(5 * sfreq)], color='r', lw=2) -ax.legend([lns[0], ln1[0]], ['Envelope', 'Reconstruction'], frameon=False) +ax.plot( + time, speech[test][sr.valid_samples_][: int(5 * sfreq)], color="grey", lw=2, ls="--" +) +ax.plot(time, y_pred[sr.valid_samples_][: int(5 * sfreq)], color="r", lw=2) +ax.legend([lns[0], ln1[0]], ["Envelope", "Reconstruction"], frameon=False) ax.set(title="Stimulus reconstruction") -ax.set_xlabel('Time (s)') +ax.set_xlabel("Time (s)") mne.viz.tight_layout() # %% @@ -243,21 +265,33 @@ # interpretation as their value (and sign) directly relates to the stimulus # signal's strength (and effect direction). -time_plot = (-.140, -.125) # To average between two timepoints. -ix_plot = np.arange(np.argmin(np.abs(time_plot[0] - times)), - np.argmin(np.abs(time_plot[1] - times))) +time_plot = (-0.140, -0.125) # To average between two timepoints. +ix_plot = np.arange( + np.argmin(np.abs(time_plot[0] - times)), np.argmin(np.abs(time_plot[1] - times)) +) fig, ax = plt.subplots(1, 2) -mne.viz.plot_topomap(np.mean(mean_coefs[:, ix_plot], axis=1), - pos=info, axes=ax[0], show=False, - vlim=(-max_coef, max_coef)) -ax[0].set(title="Model coefficients\nbetween delays %s and %s" - % (time_plot[0], time_plot[1])) - -mne.viz.plot_topomap(np.mean(mean_patterns[:, ix_plot], axis=1), - pos=info, axes=ax[1], - show=False, vlim=(-max_patterns, max_patterns)) -ax[1].set(title="Inverse-transformed coefficients\nbetween delays %s and %s" - % (time_plot[0], time_plot[1])) +mne.viz.plot_topomap( + np.mean(mean_coefs[:, ix_plot], axis=1), + pos=info, + axes=ax[0], + show=False, + vlim=(-max_coef, max_coef), +) +ax[0].set( + title="Model coefficients\nbetween delays %s and %s" % (time_plot[0], time_plot[1]) +) + +mne.viz.plot_topomap( + np.mean(mean_patterns[:, ix_plot], axis=1), + pos=info, + axes=ax[1], + show=False, + vlim=(-max_patterns, max_patterns), +) +ax[1].set( + title="Inverse-transformed coefficients\nbetween delays %s and %s" + % (time_plot[0], time_plot[1]) +) mne.viz.tight_layout() # %% diff --git a/examples/decoding/ssd_spatial_filters.py b/examples/decoding/ssd_spatial_filters.py index 723667c1864..6be80b9667c 100644 --- a/examples/decoding/ssd_spatial_filters.py +++ b/examples/decoding/ssd_spatial_filters.py @@ -28,11 +28,11 @@ # %% # Define parameters -fname = data_path() / 'SubjectCMC.ds' +fname = data_path() / "SubjectCMC.ds" # Prepare data raw = mne.io.read_raw_ctf(fname) -raw.crop(50., 110.).load_data() # crop for memory purposes +raw.crop(50.0, 110.0).load_data() # crop for memory purposes raw.resample(sfreq=250) raw.pick_types(meg=True, eeg=False, ref_meg=False) @@ -41,13 +41,23 @@ freqs_noise = 8, 13 -ssd = SSD(info=raw.info, - reg='oas', - sort_by_spectral_ratio=False, # False for purpose of example. - filt_params_signal=dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1], - l_trans_bandwidth=1, h_trans_bandwidth=1), - filt_params_noise=dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1], - l_trans_bandwidth=1, h_trans_bandwidth=1)) +ssd = SSD( + info=raw.info, + reg="oas", + sort_by_spectral_ratio=False, # False for purpose of example. 
+ filt_params_signal=dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ), + filt_params_noise=dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ), +) ssd.fit(X=raw.get_data()) @@ -58,9 +68,8 @@ # (W^{-1}) or by multiplying the noise cov with the filters Eq. (22) (C_n W)^t. # We rely on the inversion approach here. -pattern = mne.EvokedArray(data=ssd.patterns_[:4].T, - info=ssd.info) -pattern.plot_topomap(units=dict(mag='A.U.'), time_format='') +pattern = mne.EvokedArray(data=ssd.patterns_[:4].T, info=ssd.info) +pattern.plot_topomap(units=dict(mag="A.U."), time_format="") # The topographies suggest that we picked up a parietal alpha generator. @@ -69,7 +78,8 @@ # Get psd of SSD-filtered signals. psd, freqs = mne.time_frequency.psd_array_welch( - ssd_sources, sfreq=raw.info['sfreq'], n_fft=4096) + ssd_sources, sfreq=raw.info["sfreq"], n_fft=4096 +) # Get spec_ratio information (already sorted). # Note that this is not necessary if sort_by_spectral_ratio=True (default). @@ -77,12 +87,12 @@ # Plot spectral ratio (see Eq. 24 in Nikulin 2011). fig, ax = plt.subplots(1) -ax.plot(spec_ratio, color='black') -ax.plot(spec_ratio[sorter], color='orange', label='sorted eigenvalues') +ax.plot(spec_ratio, color="black") +ax.plot(spec_ratio[sorter], color="orange", label="sorted eigenvalues") ax.set_xlabel("Eigenvalue Index") ax.set_ylabel(r"Spectral Ratio $\frac{P_f}{P_{sf}}$") ax.legend() -ax.axhline(1, linestyle='--') +ax.axhline(1, linestyle="--") # We can see that the initial sorting based on the eigenvalues # was already quite good. However, when using few components only @@ -96,12 +106,12 @@ # for highlighting the freq. band of interest bandfilt = (freqs_sig[0] <= freqs) & (freqs <= freqs_sig[1]) fig, ax = plt.subplots(1) -ax.loglog(freqs[below50], psd[0, below50], label='max SNR') -ax.loglog(freqs[below50], psd[-1, below50], label='min SNR') -ax.loglog(freqs[below50], psd[:, below50].mean(axis=0), label='mean') -ax.fill_between(freqs[bandfilt], 0, 10000, color='green', alpha=0.15) -ax.set_xlabel('log(frequency)') -ax.set_ylabel('log(power)') +ax.loglog(freqs[below50], psd[0, below50], label="max SNR") +ax.loglog(freqs[below50], psd[-1, below50], label="min SNR") +ax.loglog(freqs[below50], psd[:, below50].mean(axis=0), label="mean") +ax.fill_between(freqs[bandfilt], 0, 10000, color="green", alpha=0.15) +ax.set_xlabel("log(frequency)") +ax.set_ylabel("log(power)") ax.legend() # We can clearly see that the selected component enjoys an SNR that is @@ -117,25 +127,29 @@ events = mne.make_fixed_length_events(raw, id=1, duration=5.0, overlap=0.0) # Epoch length is 5 seconds. -epochs = Epochs(raw, events, tmin=0., tmax=5, - baseline=None, preload=True) - -ssd_epochs = SSD(info=epochs.info, - reg='oas', - filt_params_signal=dict(l_freq=freqs_sig[0], - h_freq=freqs_sig[1], - l_trans_bandwidth=1, - h_trans_bandwidth=1), - filt_params_noise=dict(l_freq=freqs_noise[0], - h_freq=freqs_noise[1], - l_trans_bandwidth=1, - h_trans_bandwidth=1)) +epochs = Epochs(raw, events, tmin=0.0, tmax=5, baseline=None, preload=True) + +ssd_epochs = SSD( + info=epochs.info, + reg="oas", + filt_params_signal=dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ), + filt_params_noise=dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ), +) ssd_epochs.fit(X=epochs.get_data()) # Plot topographies. 
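Before the epoch-based topographies below, note that the fitted `ssd_epochs` object is also a transformer, just like the continuous-data model earlier in this example. A minimal sketch; the expected shape assumes all components are kept, which should be the case here since `n_components` was left unset:

ssd_sources_epochs = ssd_epochs.transform(epochs.get_data())
print(ssd_sources_epochs.shape)  # expected (n_epochs, n_components, n_times)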
-pattern_epochs = mne.EvokedArray(data=ssd_epochs.patterns_[:4].T, - info=ssd_epochs.info) -pattern_epochs.plot_topomap(units=dict(mag='A.U.'), time_format='') +pattern_epochs = mne.EvokedArray(data=ssd_epochs.patterns_[:4].T, info=ssd_epochs.info) +pattern_epochs.plot_topomap(units=dict(mag="A.U."), time_format="") # %% # References # ---------- diff --git a/examples/forward/forward_sensitivity_maps.py b/examples/forward/forward_sensitivity_maps.py index e17e8e38c12..dca41bb9b12 100644 --- a/examples/forward/forward_sensitivity_maps.py +++ b/examples/forward/forward_sensitivity_maps.py @@ -28,80 +28,84 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -subjects_dir = data_path / 'subjects' +meg_path = data_path / "MEG" / "sample" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +subjects_dir = data_path / "subjects" # Read the forward solutions with surface orientation fwd = mne.read_forward_solution(fwd_fname) mne.convert_forward_solution(fwd, surf_ori=True, copy=False) -leadfield = fwd['sol']['data'] +leadfield = fwd["sol"]["data"] print("Leadfield size : %d x %d" % leadfield.shape) # %% # Compute sensitivity maps -grad_map = mne.sensitivity_map(fwd, ch_type='grad', mode='fixed') -mag_map = mne.sensitivity_map(fwd, ch_type='mag', mode='fixed') -eeg_map = mne.sensitivity_map(fwd, ch_type='eeg', mode='fixed') +grad_map = mne.sensitivity_map(fwd, ch_type="grad", mode="fixed") +mag_map = mne.sensitivity_map(fwd, ch_type="mag", mode="fixed") +eeg_map = mne.sensitivity_map(fwd, ch_type="eeg", mode="fixed") # %% # Show gain matrix a.k.a. leadfield matrix with sensitivity map -picks_meg = mne.pick_types(fwd['info'], meg=True, eeg=False) -picks_eeg = mne.pick_types(fwd['info'], meg=False, eeg=True) +picks_meg = mne.pick_types(fwd["info"], meg=True, eeg=False) +picks_eeg = mne.pick_types(fwd["info"], meg=False, eeg=True) fig, axes = plt.subplots(2, 1, figsize=(10, 8), sharex=True) -fig.suptitle('Lead field matrix (500 dipoles only)', fontsize=14) -for ax, picks, ch_type in zip(axes, [picks_meg, picks_eeg], ['meg', 'eeg']): - im = ax.imshow(leadfield[picks, :500], origin='lower', aspect='auto', - cmap='RdBu_r') +fig.suptitle("Lead field matrix (500 dipoles only)", fontsize=14) +for ax, picks, ch_type in zip(axes, [picks_meg, picks_eeg], ["meg", "eeg"]): + im = ax.imshow(leadfield[picks, :500], origin="lower", aspect="auto", cmap="RdBu_r") ax.set_title(ch_type.upper()) - ax.set_xlabel('sources') - ax.set_ylabel('sensors') + ax.set_xlabel("sources") + ax.set_ylabel("sensors") fig.colorbar(im, ax=ax) fig_2, ax = plt.subplots() -ax.hist([grad_map.data.ravel(), mag_map.data.ravel(), eeg_map.data.ravel()], - bins=20, label=['Gradiometers', 'Magnetometers', 'EEG'], - color=['c', 'b', 'k']) +ax.hist( + [grad_map.data.ravel(), mag_map.data.ravel(), eeg_map.data.ravel()], + bins=20, + label=["Gradiometers", "Magnetometers", "EEG"], + color=["c", "b", "k"], +) fig_2.legend() -ax.set(title='Normal orientation sensitivity', - xlabel='sensitivity', ylabel='count') +ax.set(title="Normal orientation sensitivity", xlabel="sensitivity", ylabel="count") # sphinx_gallery_thumbnail_number = 3 brain_sens = grad_map.plot( - subjects_dir=subjects_dir, clim=dict(lims=[0, 50, 100]), figure=1) -brain_sens.add_text(0.1, 0.9, 'Gradiometer sensitivity', 'title', font_size=16) + subjects_dir=subjects_dir, clim=dict(lims=[0, 50, 100]), figure=1 +) +brain_sens.add_text(0.1, 0.9, "Gradiometer sensitivity", "title", 
font_size=16)
 
 # %%
 # Compare sensitivity map with distribution of source depths
 
 # source space with vertices
-src = fwd['src']
+src = fwd["src"]
 
 # Compute minimum Euclidean distances between vertices and MEG sensors
-depths = compute_distance_to_sensors(src=src, info=fwd['info'],
-                                     picks=picks_meg).min(axis=1)
+depths = compute_distance_to_sensors(src=src, info=fwd["info"], picks=picks_meg).min(
+    axis=1
+)
 maxdep = depths.max()  # for scaling
 
-vertices = [src[0]['vertno'], src[1]['vertno']]
+vertices = [src[0]["vertno"], src[1]["vertno"]]
 
-depths_map = SourceEstimate(data=depths, vertices=vertices, tmin=0.,
-                            tstep=1.)
+depths_map = SourceEstimate(data=depths, vertices=vertices, tmin=0.0, tstep=1.0)
 
 brain_dep = depths_map.plot(
-    subject='sample', subjects_dir=subjects_dir,
-    clim=dict(kind='value', lims=[0, maxdep / 2., maxdep]), figure=2)
-brain_dep.add_text(0.1, 0.9, 'Source depth (m)', 'title', font_size=16)
+    subject="sample",
+    subjects_dir=subjects_dir,
+    clim=dict(kind="value", lims=[0, maxdep / 2.0, maxdep]),
+    figure=2,
+)
+brain_dep.add_text(0.1, 0.9, "Source depth (m)", "title", font_size=16)
 
 # %%
 # Sensitivity is likely to co-vary with the distance between sources and
 # sensors. To determine the strength of this relationship, we can compute the
 # correlation between source depth and sensitivity values.
 corr = np.corrcoef(depths, grad_map.data[:, 0])[0, 1]
-print('Correlation between source depth and gradiomter sensitivity values: %f.'
-      % corr)
+print("Correlation between source depth and gradiometer sensitivity values: %f." % corr)
 
 # %%
 # Gradiometer sensitivity is highest close to the sensors, and decreases rapidly
diff --git a/examples/forward/left_cerebellum_volume_source.py b/examples/forward/left_cerebellum_volume_source.py
index c8327100f10..e74b71c6c4f 100644
--- a/examples/forward/left_cerebellum_volume_source.py
+++ b/examples/forward/left_cerebellum_volume_source.py
@@ -23,9 +23,9 @@
 print(__doc__)
 
 data_path = sample.data_path()
-subjects_dir = data_path / 'subjects'
-subject = 'sample'
-aseg_fname = subjects_dir / 'sample' / 'mri' / 'aseg.mgz'
+subjects_dir = data_path / "subjects"
+subject = "sample"
+aseg_fname = subjects_dir / "sample" / "mri" / "aseg.mgz"
 
 # %%
 # Setup the source spaces
@@ -35,11 +35,16 @@
 lh_surf = surf[0]
 
 # setup a volume source space of the left cerebellum cortex
-volume_label = 'Left-Cerebellum-Cortex'
+volume_label = "Left-Cerebellum-Cortex"
 sphere = (0, 0, 0, 0.12)
 lh_cereb = setup_volume_source_space(
-    subject, mri=aseg_fname, sphere=sphere, volume_label=volume_label,
-    subjects_dir=subjects_dir, sphere_units='m')
+    subject,
+    mri=aseg_fname,
+    sphere=sphere,
+    volume_label=volume_label,
+    subjects_dir=subjects_dir,
+    sphere_units="m",
+)
 
 # Combine the source spaces
 src = surf + lh_cereb
@@ -47,11 +52,16 @@
 
 # %%
 # Plot the positions of each source space
-fig = mne.viz.plot_alignment(subject=subject, subjects_dir=subjects_dir,
-                             surfaces='white', coord_frame='mri',
-                             src=src)
-mne.viz.set_3d_view(fig, azimuth=180, elevation=90,
-                    distance=0.30, focalpoint=(-0.03, -0.01, 0.03))
+fig = mne.viz.plot_alignment(
+    subject=subject,
+    subjects_dir=subjects_dir,
+    surfaces="white",
+    coord_frame="mri",
+    src=src,
+)
+mne.viz.set_3d_view(
+    fig, azimuth=180, elevation=90, distance=0.30, focalpoint=(-0.03, -0.01, 0.03)
+)
 
 # %%
 # You can export source positions to a NIfTI file::
diff --git a/examples/forward/source_space_morphing.py b/examples/forward/source_space_morphing.py
index 77688705e97..5085e629615 100644
--- a/examples/forward/source_space_morphing.py
a/examples/forward/source_space_morphing.py +++ b/examples/forward/source_space_morphing.py @@ -24,38 +24,38 @@ import mne data_path = mne.datasets.sample.data_path() -subjects_dir = data_path / 'subjects' -fname_trans = ( - data_path / 'MEG' / 'sample' / 'sample_audvis_raw-trans.fif') -fname_bem = ( - subjects_dir / 'sample' / 'bem' / 'sample-5120-bem-sol.fif') -fname_src_fs = ( - subjects_dir / 'fsaverage' / 'bem' / 'fsaverage-ico-5-src.fif') -raw_fname = data_path / 'MEG' / 'sample' / 'sample_audvis_raw.fif' +subjects_dir = data_path / "subjects" +fname_trans = data_path / "MEG" / "sample" / "sample_audvis_raw-trans.fif" +fname_bem = subjects_dir / "sample" / "bem" / "sample-5120-bem-sol.fif" +fname_src_fs = subjects_dir / "fsaverage" / "bem" / "fsaverage-ico-5-src.fif" +raw_fname = data_path / "MEG" / "sample" / "sample_audvis_raw.fif" # Get relevant channel information info = mne.io.read_info(raw_fname) -info = mne.pick_info(info, mne.pick_types(info, meg=True, eeg=False, - exclude=[])) +info = mne.pick_info(info, mne.pick_types(info, meg=True, eeg=False, exclude=[])) # Morph fsaverage's source space to sample src_fs = mne.read_source_spaces(fname_src_fs) -src_morph = mne.morph_source_spaces(src_fs, subject_to='sample', - subjects_dir=subjects_dir) +src_morph = mne.morph_source_spaces( + src_fs, subject_to="sample", subjects_dir=subjects_dir +) # Compute the forward with our morphed source space -fwd = mne.make_forward_solution(info, trans=fname_trans, - src=src_morph, bem=fname_bem) -mag_map = mne.sensitivity_map(fwd, ch_type='mag') +fwd = mne.make_forward_solution(info, trans=fname_trans, src=src_morph, bem=fname_bem) +mag_map = mne.sensitivity_map(fwd, ch_type="mag") # Return this SourceEstimate (on sample's surfaces) to fsaverage's surfaces mag_map_fs = mag_map.to_original_src(src_fs, subjects_dir=subjects_dir) # Plot the result, which tracks the sulcal-gyral folding # outliers may occur, we'll place the cutoff at 99 percent. -kwargs = dict(clim=dict(kind='percent', lims=[0, 50, 99]), - # no smoothing, let's see the dipoles on the cortex. - smoothing_steps=1, hemi='rh', views=['lat']) +kwargs = dict( + clim=dict(kind="percent", lims=[0, 50, 99]), + # no smoothing, let's see the dipoles on the cortex. + smoothing_steps=1, + hemi="rh", + views=["lat"], +) # Now note that the dipoles on fsaverage are almost equidistant while # morphing will distribute the dipoles unevenly across the given subject's @@ -63,7 +63,9 @@ # Our testing code suggests a correlation of higher than 0.99. 
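 # (An illustrative sketch, not part of this formatting patch: since the
 # morphed and remapped estimates keep a one-to-one vertex correspondence,
 # that correlation could be checked along these lines, assuming numpy is
 # imported as np:
 #     corr = np.corrcoef(mag_map.data.ravel(), mag_map_fs.data.ravel())[0, 1]
 #     assert corr > 0.99
 # )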
 brain_subject = mag_map.plot(  # plot forward in subject source space (morphed)
-    time_label='Morphed', subjects_dir=subjects_dir, **kwargs)
+    time_label="Morphed", subjects_dir=subjects_dir, **kwargs
+)

 brain_fs = mag_map_fs.plot(  # plot forward in original source space (remapped)
-    time_label='Remapped', subjects_dir=subjects_dir, **kwargs)
+    time_label="Remapped", subjects_dir=subjects_dir, **kwargs
+)
diff --git a/examples/inverse/compute_mne_inverse_epochs_in_label.py b/examples/inverse/compute_mne_inverse_epochs_in_label.py
index e78b37c17fe..e779444f6cf 100644
--- a/examples/inverse/compute_mne_inverse_epochs_in_label.py
+++ b/examples/inverse/compute_mne_inverse_epochs_in_label.py
@@ -25,18 +25,18 @@
 print(__doc__)

 data_path = sample.data_path()
-meg_path = data_path / 'MEG' / 'sample'
-fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif'
-fname_raw = meg_path / 'sample_audvis_filt-0-40_raw.fif'
-fname_event = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif'
-label_name = 'Aud-lh'
-fname_label = meg_path / 'labels' / f'{label_name}.label'
+meg_path = data_path / "MEG" / "sample"
+fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif"
+fname_raw = meg_path / "sample_audvis_filt-0-40_raw.fif"
+fname_event = meg_path / "sample_audvis_filt-0-40_raw-eve.fif"
+label_name = "Aud-lh"
+fname_label = meg_path / "labels" / f"{label_name}.label"

 event_id, tmin, tmax = 1, -0.2, 0.5

 # Using the same inverse operator when inspecting single trials Vs. evoked
 snr = 3.0  # Standard assumption for average data but using it for single trial
-lambda2 = 1.0 / snr ** 2
+lambda2 = 1.0 / snr**2
 method = "dSPM"  # use dSPM method (could also be MNE or sLORETA)
@@ -50,15 +50,23 @@
 include = []

 # Add a bad channel
-raw.info['bads'] += ['EEG 053']  # bads + 1 more
+raw.info["bads"] += ["EEG 053"]  # bads + 1 more

 # pick MEG channels
-picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True,
-                       include=include, exclude='bads')
+picks = mne.pick_types(
+    raw.info, meg=True, eeg=False, stim=False, eog=True, include=include, exclude="bads"
+)
 # Read epochs
-epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks,
-                    baseline=(None, 0), reject=dict(mag=4e-12, grad=4000e-13,
-                                                    eog=150e-6))
+epochs = mne.Epochs(
+    raw,
+    events,
+    event_id,
+    tmin,
+    tmax,
+    picks=picks,
+    baseline=(None, 0),
+    reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6),
+)

 # Get evoked data (averaging across trials in sensor space)
 evoked = epochs.average()
@@ -66,21 +74,27 @@
 # Compute inverse solution and stcs for each epoch
 # Use the same inverse operator as with evoked data (i.e., set nave)
 # If you use a different nave, dSPM just scales by a factor sqrt(nave)
-stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, method, label,
-                            pick_ori="normal", nave=evoked.nave)
+stcs = apply_inverse_epochs(
+    epochs,
+    inverse_operator,
+    lambda2,
+    method,
+    label,
+    pick_ori="normal",
+    nave=evoked.nave,
+)

 # Mean across trials but not across vertices in label
 mean_stc = sum(stcs) / len(stcs)

 # compute sign flip to avoid signal cancellation when averaging signed values
-flip = mne.label_sign_flip(label, inverse_operator['src'])
+flip = mne.label_sign_flip(label, inverse_operator["src"])

 label_mean = np.mean(mean_stc.data, axis=0)
 label_mean_flip = np.mean(flip[:, np.newaxis] * mean_stc.data, axis=0)

 # Get inverse solution by inverting evoked data
-stc_evoked = apply_inverse(evoked, inverse_operator, lambda2, method,
-                           pick_ori="normal")
+stc_evoked = apply_inverse(evoked, inverse_operator, lambda2, method, pick_ori="normal")

 # apply_inverse() does whole brain, so sub-select label of interest
 stc_evoked_label = stc_evoked.in_label(label)
@@ -94,13 +108,12 @@
 times = 1e3 * stcs[0].times  # times in ms

 plt.figure()
-h0 = plt.plot(times, mean_stc.data.T, 'k')
-h1, = plt.plot(times, label_mean, 'r', linewidth=3)
-h2, = plt.plot(times, label_mean_flip, 'g', linewidth=3)
-plt.legend((h0[0], h1, h2), ('all dipoles in label', 'mean',
-                             'mean with sign flip'))
-plt.xlabel('time (ms)')
-plt.ylabel('dSPM value')
+h0 = plt.plot(times, mean_stc.data.T, "k")
+(h1,) = plt.plot(times, label_mean, "r", linewidth=3)
+(h2,) = plt.plot(times, label_mean_flip, "g", linewidth=3)
+plt.legend((h0[0], h1, h2), ("all dipoles in label", "mean", "mean with sign flip"))
+plt.xlabel("time (ms)")
+plt.ylabel("dSPM value")
 plt.show()

 # %%
@@ -110,19 +123,21 @@
 # Single trial
 plt.figure()
 for k, stc_trial in enumerate(stcs):
-    plt.plot(times, np.mean(stc_trial.data, axis=0).T, 'k--',
-             label='Single Trials' if k == 0 else '_nolegend_',
-             alpha=0.5)
+    plt.plot(
+        times,
+        np.mean(stc_trial.data, axis=0).T,
+        "k--",
+        label="Single Trials" if k == 0 else "_nolegend_",
+        alpha=0.5,
+    )

 # Single trial inverse then average.. making linewidth large to not be masked
-plt.plot(times, label_mean, 'b', linewidth=6,
-         label='dSPM first, then average')
+plt.plot(times, label_mean, "b", linewidth=6, label="dSPM first, then average")

 # Evoked and then inverse
-plt.plot(times, label_mean_evoked, 'r', linewidth=2,
-         label='Average first, then dSPM')
+plt.plot(times, label_mean_evoked, "r", linewidth=2, label="Average first, then dSPM")

-plt.xlabel('time (ms)')
-plt.ylabel('dSPM value')
+plt.xlabel("time (ms)")
+plt.ylabel("dSPM value")
 plt.legend()
 plt.show()
diff --git a/examples/inverse/compute_mne_inverse_raw_in_label.py b/examples/inverse/compute_mne_inverse_raw_in_label.py
index 1d473f2db1f..5c15563f76a 100644
--- a/examples/inverse/compute_mne_inverse_raw_in_label.py
+++ b/examples/inverse/compute_mne_inverse_raw_in_label.py
@@ -25,14 +25,13 @@
 print(__doc__)

 data_path = sample.data_path()
-fname_inv = (
-    data_path / 'MEG' / 'sample' / 'sample_audvis-meg-oct-6-meg-inv.fif')
-fname_raw = data_path / 'MEG' / 'sample' / 'sample_audvis_raw.fif'
-label_name = 'Aud-lh'
-fname_label = data_path / 'MEG' / 'sample' / 'labels' / f'{label_name}.label'
+fname_inv = data_path / "MEG" / "sample" / "sample_audvis-meg-oct-6-meg-inv.fif"
+fname_raw = data_path / "MEG" / "sample" / "sample_audvis_raw.fif"
+label_name = "Aud-lh"
+fname_label = data_path / "MEG" / "sample" / "labels" / f"{label_name}.label"

 snr = 1.0  # use smaller SNR for raw data
-lambda2 = 1.0 / snr ** 2
+lambda2 = 1.0 / snr**2
 method = "sLORETA"  # use sLORETA method (could also be MNE or dSPM)

 # Load data
@@ -40,19 +39,20 @@
 inverse_operator = read_inverse_operator(fname_inv)
 label = mne.read_label(fname_label)

-raw.set_eeg_reference('average', projection=True)  # set average reference.
+raw.set_eeg_reference("average", projection=True)  # set average reference.
 start, stop = raw.time_as_index([0, 15])  # read the first 15s of data

 # Compute inverse solution
-stc = apply_inverse_raw(raw, inverse_operator, lambda2, method, label,
-                        start, stop, pick_ori=None)
+stc = apply_inverse_raw(
+    raw, inverse_operator, lambda2, method, label, start, stop, pick_ori=None
+)

 # Save result in stc files
-stc.save('mne_%s_raw_inverse_%s' % (method, label_name), overwrite=True)
+stc.save("mne_%s_raw_inverse_%s" % (method, label_name), overwrite=True)

 # %%
 # View activation time-series
 plt.plot(1e3 * stc.times, stc.data[::100, :].T)
-plt.xlabel('time (ms)')
-plt.ylabel('%s value' % method)
+plt.xlabel("time (ms)")
+plt.ylabel("%s value" % method)
 plt.show()
diff --git a/examples/inverse/compute_mne_inverse_volume.py b/examples/inverse/compute_mne_inverse_volume.py
index 215977ca393..7b5193a081b 100644
--- a/examples/inverse/compute_mne_inverse_volume.py
+++ b/examples/inverse/compute_mne_inverse_volume.py
@@ -24,33 +24,36 @@
 print(__doc__)

 data_path = sample.data_path()
-meg_path = data_path / 'MEG' / 'sample'
-fname_inv = meg_path / 'sample_audvis-meg-vol-7-meg-inv.fif'
-fname_evoked = meg_path / 'sample_audvis-ave.fif'
+meg_path = data_path / "MEG" / "sample"
+fname_inv = meg_path / "sample_audvis-meg-vol-7-meg-inv.fif"
+fname_evoked = meg_path / "sample_audvis-ave.fif"

 snr = 3.0
-lambda2 = 1.0 / snr ** 2
+lambda2 = 1.0 / snr**2
 method = "dSPM"  # use dSPM method (could also be MNE or sLORETA)

 # Load data
 evoked = read_evokeds(fname_evoked, condition=0, baseline=(None, 0))
 inverse_operator = read_inverse_operator(fname_inv)
-src = inverse_operator['src']
+src = inverse_operator["src"]

 # Compute inverse solution
 stc = apply_inverse(evoked, inverse_operator, lambda2, method)
 stc.crop(0.0, 0.2)

 # Export result as a 4D nifti object
-img = stc.as_volume(src,
-                    mri_resolution=False)  # set True for full MRI resolution
+img = stc.as_volume(src, mri_resolution=False)  # set True for full MRI resolution

 # Save it as a nifti file
 # nib.save(img, 'mne_%s_inverse.nii.gz' % method)

-t1_fname = data_path / 'subjects' / 'sample' / 'mri' / 'T1.mgz'
+t1_fname = data_path / "subjects" / "sample" / "mri" / "T1.mgz"

 # %%
 # Plot with nilearn:
-plot_stat_map(index_img(img, 61), str(t1_fname), threshold=8.,
-              title='%s (t=%.1f s.)' % (method, stc.times[61]))
+plot_stat_map(
+    index_img(img, 61),
+    str(t1_fname),
+    threshold=8.0,
+    title="%s (t=%.1f s.)" % (method, stc.times[61]),
+)
diff --git a/examples/inverse/custom_inverse_solver.py b/examples/inverse/custom_inverse_solver.py
index 760ef4408e5..3324b1198bb 100644
--- a/examples/inverse/custom_inverse_solver.py
+++ b/examples/inverse/custom_inverse_solver.py
@@ -29,12 +29,12 @@

 data_path = sample.data_path()
-meg_path = data_path / 'MEG' / 'sample'
-fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif'
-ave_fname = meg_path / 'sample_audvis-ave.fif'
-cov_fname = meg_path / 'sample_audvis-shrunk-cov.fif'
-subjects_dir = data_path / 'subjects'
-condition = 'Left Auditory'
+meg_path = data_path / "MEG" / "sample"
+fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif"
+ave_fname = meg_path / "sample_audvis-ave.fif"
+cov_fname = meg_path / "sample_audvis-shrunk-cov.fif"
+subjects_dir = data_path / "subjects"
+condition = "Left Auditory"

 # Read noise covariance matrix
 noise_cov = mne.read_cov(cov_fname)
@@ -50,6 +50,7 @@
 # %%
 # Auxiliary function to run the solver

+
 def apply_solver(solver, evoked, forward, noise_cov, loose=0.2, depth=0.8):
     """Call a custom solver on evoked data.
@@ -93,19 +94,30 @@ def apply_solver(solver, evoked, forward, noise_cov, loose=0.2, depth=0.8):
         The source estimates.
     """
     # Import the necessary private functions
-    from mne.inverse_sparse.mxne_inverse import \
-        (_prepare_gain, is_fixed_orient,
-         _reapply_source_weighting, _make_sparse_stc)
+    from mne.inverse_sparse.mxne_inverse import (
+        _prepare_gain,
+        is_fixed_orient,
+        _reapply_source_weighting,
+        _make_sparse_stc,
+    )

     all_ch_names = evoked.ch_names

     # Handle depth weighting and whitening (here is no weights)
     forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
-        forward, evoked.info, noise_cov, pca=False, depth=depth,
-        loose=loose, weights=None, weights_min=None, rank=None)
+        forward,
+        evoked.info,
+        noise_cov,
+        pca=False,
+        depth=depth,
+        loose=loose,
+        weights=None,
+        weights_min=None,
+        rank=None,
+    )

     # Select channels of interest
-    sel = [all_ch_names.index(name) for name in gain_info['ch_names']]
+    sel = [all_ch_names.index(name) for name in gain_info["ch_names"]]
     M = evoked.data[sel]

     # Whiten data
@@ -115,8 +127,9 @@ def apply_solver(solver, evoked, forward, noise_cov, loose=0.2, depth=0.8):
     X, active_set = solver(M, gain, n_orient)
     X = _reapply_source_weighting(X, source_weighting, active_set)

-    stc = _make_sparse_stc(X, active_set, forward, tmin=evoked.times[0],
-                           tstep=1. / evoked.info['sfreq'])
+    stc = _make_sparse_stc(
+        X, active_set, forward, tmin=evoked.times[0], tstep=1.0 / evoked.info["sfreq"]
+    )

     return stc

@@ -124,6 +137,7 @@ def apply_solver(solver, evoked, forward, noise_cov, loose=0.2, depth=0.8):
 # %%
 # Define your solver

+
 def solver(M, G, n_orient):
     """Run L2 penalized regression and keep 10 strongest locations.
@@ -155,11 +169,11 @@ def solver(M, G, n_orient):
     K /= np.linalg.norm(K, axis=1)[:, None]
     X = np.dot(K, M)

-    indices = np.argsort(np.sum(X ** 2, axis=1))[-10:]
+    indices = np.argsort(np.sum(X**2, axis=1))[-10:]
     active_set = np.zeros(G.shape[1], dtype=bool)
     for idx in indices:
         idx -= idx % n_orient
-        active_set[idx:idx + n_orient] = True
+        active_set[idx : idx + n_orient] = True
     X = X[active_set]
     return X, active_set

@@ -168,10 +182,9 @@ def solver(M, G, n_orient):
 # Apply your custom solver

 # loose, depth = 0.2, 0.8  # corresponds to loose orientation
-loose, depth = 1., 0.  # corresponds to free orientation
+loose, depth = 1.0, 0.0  # corresponds to free orientation

 stc = apply_solver(solver, evoked, forward, noise_cov, loose, depth)

 # %%
 # View in 2D and 3D ("glass" brain like 3D plot)
-plot_sparse_source_estimates(forward['src'], stc, bgcolor=(1, 1, 1),
-                             opacity=0.1)
+plot_sparse_source_estimates(forward["src"], stc, bgcolor=(1, 1, 1), opacity=0.1)
diff --git a/examples/inverse/dics_epochs.py b/examples/inverse/dics_epochs.py
index 8aba68b9e44..dc8a0b7e14c 100644
--- a/examples/inverse/dics_epochs.py
+++ b/examples/inverse/dics_epochs.py
@@ -28,13 +28,13 @@
 # Organize the data that we will use for this example.

 data_path = somato.data_path()
-subject = '01'
-task = 'somato'
-raw_fname = (data_path / f'sub-{subject}' / 'meg' /
-             f'sub-{subject}_task-{task}_meg.fif')
-fname_fwd = (data_path / 'derivatives' / f'sub-{subject}' /
-             f'sub-{subject}_task-{task}-fwd.fif')
-subjects_dir = data_path / 'derivatives' / 'freesurfer' / 'subjects'
+subject = "01"
+task = "somato"
+raw_fname = data_path / f"sub-{subject}" / "meg" / f"sub-{subject}_task-{task}_meg.fif"
+fname_fwd = (
+    data_path / "derivatives" / f"sub-{subject}" / f"sub-{subject}_task-{task}-fwd.fif"
+)
+subjects_dir = data_path / "derivatives" / "freesurfer" / "subjects"

 # %%
 # First, we load the data and compute for each epoch the time-frequency
@@ -43,11 +43,19 @@
 # Load raw data and make epochs.
 raw = mne.io.read_raw_fif(raw_fname)
 events = mne.find_events(raw)
-epochs = mne.Epochs(raw, events, event_id=1, tmin=-1, tmax=2.5,
-                    reject=dict(grad=5000e-13,  # unit: T / m (gradiometers)
-                                mag=5e-12,  # unit: T (magnetometers)
-                                eog=250e-6,  # unit: V (EOG channels)
-                                ), preload=True)
+epochs = mne.Epochs(
+    raw,
+    events,
+    event_id=1,
+    tmin=-1,
+    tmax=2.5,
+    reject=dict(
+        grad=5000e-13,  # unit: T / m (gradiometers)
+        mag=5e-12,  # unit: T (magnetometers)
+        eog=250e-6,  # unit: V (EOG channels)
+    ),
+    preload=True,
+)
 epochs = epochs[:10]  # just for speed of execution for the tutorial

 # We are mostly interested in the beta band since it has been shown to be
@@ -58,8 +66,9 @@
 # decomposition for each epoch. We must pass ``output='complex'`` if we wish to
 # use this TFR later with a DICS beamformer. We also pass ``average=False`` to
 # compute the TFR for each individual epoch.
-epochs_tfr = tfr_morlet(epochs, freqs, n_cycles=5, return_itc=False,
-                        output='complex', average=False)
+epochs_tfr = tfr_morlet(
+    epochs, freqs, n_cycles=5, return_itc=False, output="complex", average=False
+)

 # crop either side to use a buffer to remove edge artifact
 epochs_tfr.crop(tmin=-0.5, tmax=2)
@@ -78,15 +87,21 @@
 fwd = mne.read_forward_solution(fname_fwd)

 # compute scalar DICS beamfomer
-filters = make_dics(epochs.info, fwd, csd, noise_csd=baseline_csd,
-                    pick_ori='max-power', reduce_rank=True, real_filter=True)
+filters = make_dics(
+    epochs.info,
+    fwd,
+    csd,
+    noise_csd=baseline_csd,
+    pick_ori="max-power",
+    reduce_rank=True,
+    real_filter=True,
+)

 # project the TFR for each epoch to source space
-epochs_stcs = apply_dics_tfr_epochs(
-    epochs_tfr, filters, return_generator=True)
+epochs_stcs = apply_dics_tfr_epochs(epochs_tfr, filters, return_generator=True)

 # average across frequencies and epochs
-data = np.zeros((fwd['nsource'], epochs_tfr.times.size))
+data = np.zeros((fwd["nsource"], epochs_tfr.times.size))
 for epoch_stcs in epochs_stcs:
     for stc in epoch_stcs:
         data += (stc.data * np.conj(stc.data)).real
@@ -104,13 +119,17 @@
 fmax = 4500
 brain = stc.plot(
     subjects_dir=subjects_dir,
-    hemi='both',
-    views='dorsal',
+    hemi="both",
+    views="dorsal",
     initial_time=0.55,
     brain_kwargs=dict(show=False),
-    add_data_kwargs=dict(fmin=fmax / 10, fmid=fmax / 2, fmax=fmax,
-                         scale_factor=0.0001,
-                         colorbar_kwargs=dict(label_font_size=10))
+    add_data_kwargs=dict(
+        fmin=fmax / 10,
+        fmid=fmax / 2,
+        fmax=fmax,
+        scale_factor=0.0001,
+        colorbar_kwargs=dict(label_font_size=10),
+    ),
 )

 # You can save a movie like the one on our documentation website with:
diff --git a/examples/inverse/dics_source_power.py b/examples/inverse/dics_source_power.py
index 8a3ee2c1cf6..68925202b17 100644
--- a/examples/inverse/dics_source_power.py
+++ b/examples/inverse/dics_source_power.py
@@ -31,10 +31,9 @@
 # %%
 # Reading the raw data and creating epochs:
 data_path = somato.data_path()
-subject = '01'
-task = 'somato'
-raw_fname = (data_path / f'sub-{subject}' / 'meg' /
-             f'sub-{subject}_task-{task}_meg.fif')
+subject = "01"
+task = "somato"
+raw_fname = data_path / f"sub-{subject}" / "meg" / f"sub-{subject}_task-{task}_meg.fif"

 # Use a shorter segment of raw just for speed here
 raw = mne.io.read_raw_fif(raw_fname)
@@ -47,10 +46,11 @@
 del raw

 # Paths to forward operator and FreeSurfer subject directory
-fname_fwd = (data_path / 'derivatives' / f'sub-{subject}' /
-             f'sub-{subject}_task-{task}-fwd.fif')
+fname_fwd = (
+    data_path / "derivatives" / f"sub-{subject}" / f"sub-{subject}_task-{task}-fwd.fif"
+)

-subjects_dir = data_path / 'derivatives' / 'freesurfer' / 'subjects'
+subjects_dir = data_path / "derivatives" / "freesurfer" / "subjects"

 # %%
 # We are interested in the beta band. Define a range of frequencies, using a
@@ -79,8 +79,15 @@
 # Computing DICS spatial filters using the CSD that was computed on the entire
 # timecourse.
 fwd = mne.read_forward_solution(fname_fwd)
-filters = make_dics(info, fwd, csd, noise_csd=csd_baseline,
-                    pick_ori='max-power', reduce_rank=True, real_filter=True)
+filters = make_dics(
+    info,
+    fwd,
+    csd,
+    noise_csd=csd_baseline,
+    pick_ori="max-power",
+    reduce_rank=True,
+    real_filter=True,
+)
 del fwd

 # %%
@@ -92,9 +99,14 @@
 # %%
 # Visualizing source power during ERS activity relative to the baseline power.
 stc = beta_source_power / baseline_source_power
-message = 'DICS source power in the 12-30 Hz frequency band'
-brain = stc.plot(hemi='both', views='axial', subjects_dir=subjects_dir,
-                 subject=subject, time_label=message)
+message = "DICS source power in the 12-30 Hz frequency band"
+brain = stc.plot(
+    hemi="both",
+    views="axial",
+    subjects_dir=subjects_dir,
+    subject=subject,
+    time_label=message,
+)

 # %%
 # References
diff --git a/examples/inverse/evoked_ers_source_power.py b/examples/inverse/evoked_ers_source_power.py
index b3ccaab5e04..272b0518293 100644
--- a/examples/inverse/evoked_ers_source_power.py
+++ b/examples/inverse/evoked_ers_source_power.py
@@ -22,19 +22,22 @@
 from mne.cov import compute_covariance
 from mne.datasets import somato
 from mne.time_frequency import csd_morlet
-from mne.beamformer import (make_dics, apply_dics_csd, make_lcmv,
-                            apply_lcmv_cov)
-from mne.minimum_norm import (make_inverse_operator, apply_inverse_cov)
+from mne.beamformer import make_dics, apply_dics_csd, make_lcmv, apply_lcmv_cov
+from mne.minimum_norm import make_inverse_operator, apply_inverse_cov

 print(__doc__)

 # %%
 # Reading the raw data and creating epochs:
 data_path = somato.data_path()
-subject = '01'
-task = 'somato'
-raw_fname = (data_path / 'sub-{}'.format(subject) / 'meg' /
-             'sub-{}_task-{}_meg.fif'.format(subject, task))
+subject = "01"
+task = "somato"
+raw_fname = (
+    data_path
+    / "sub-{}".format(subject)
+    / "meg"
+    / "sub-{}_task-{}_meg.fif".format(subject, task)
+)

 # crop to 5 minutes to save memory
 raw = mne.io.read_raw_fif(raw_fname).crop(0, 300)
@@ -44,17 +47,22 @@
 # The DICS beamformer currently only supports a single sensor type.
 # We'll use the gradiometers in this example.
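 # (Illustrative aside, not part of this patch: the analogous magnetometer-only
 # selection would be picks = mne.pick_types(raw.info, meg="mag",
 # exclude="bads"); the point is that exactly one coil type is picked.)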
-picks = mne.pick_types(raw.info, meg='grad', exclude='bads')
+picks = mne.pick_types(raw.info, meg="grad", exclude="bads")

 # Read epochs
 events = mne.find_events(raw)
-epochs = mne.Epochs(raw, events, event_id=1, tmin=-1.5, tmax=2, picks=picks,
-                    preload=True, decim=3)
+epochs = mne.Epochs(
+    raw, events, event_id=1, tmin=-1.5, tmax=2, picks=picks, preload=True, decim=3
+)

 # Read forward operator and point to freesurfer subject directory
-fname_fwd = (data_path / 'derivatives' / 'sub-{}'.format(subject) /
-             'sub-{}_task-{}-fwd.fif'.format(subject, task))
-subjects_dir = data_path / 'derivatives' / 'freesurfer' / 'subjects'
+fname_fwd = (
+    data_path
+    / "derivatives"
+    / "sub-{}".format(subject)
+    / "sub-{}_task-{}-fwd.fif".format(subject, task)
+)
+subjects_dir = data_path / "derivatives" / "freesurfer" / "subjects"

 fwd = mne.read_forward_solution(fname_fwd)
@@ -68,14 +76,25 @@
 # combination with an advanced covariance estimator like "shrunk", the rank
 # will be correctly preserved.

-rank = mne.compute_rank(epochs, tol=1e-6, tol_kind='relative')
+rank = mne.compute_rank(epochs, tol=1e-6, tol_kind="relative")
 active_win = (0.5, 1.5)
 baseline_win = (-1, 0)
-baseline_cov = compute_covariance(epochs, tmin=baseline_win[0],
-                                  tmax=baseline_win[1], method='shrunk',
-                                  rank=rank, verbose=True)
-active_cov = compute_covariance(epochs, tmin=active_win[0], tmax=active_win[1],
-                                method='shrunk', rank=rank, verbose=True)
+baseline_cov = compute_covariance(
+    epochs,
+    tmin=baseline_win[0],
+    tmax=baseline_win[1],
+    method="shrunk",
+    rank=rank,
+    verbose=True,
+)
+active_cov = compute_covariance(
+    epochs,
+    tmin=active_win[0],
+    tmax=active_win[1],
+    method="shrunk",
+    rank=rank,
+    verbose=True,
+)

 # Weighted averaging is already in the addition of covariance objects.
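 # (Hedged sketch of what that addition does: for two mne.Covariance objects,
 # cov_a + cov_b is expected to combine the matrices weighted by their degrees
 # of freedom, roughly
 #     data = (nfree_a * data_a + nfree_b * data_b) / (nfree_a + nfree_b)
 # with the nfree values summed; see mne.Covariance.__add__ for the
 # authoritative behavior.)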
 common_cov = baseline_cov + active_cov
@@ -93,12 +112,21 @@ def _gen_dics(active_win, baseline_win, epochs):
     freqs = np.logspace(np.log10(12), np.log10(30), 9)
     csd = csd_morlet(epochs, freqs, tmin=-1, tmax=1.5, decim=20)
-    csd_baseline = csd_morlet(epochs, freqs, tmin=baseline_win[0],
-                              tmax=baseline_win[1], decim=20)
-    csd_ers = csd_morlet(epochs, freqs, tmin=active_win[0], tmax=active_win[1],
-                         decim=20)
-    filters = make_dics(epochs.info, fwd, csd.mean(), pick_ori='max-power',
-                        reduce_rank=True, real_filter=True, rank=rank)
+    csd_baseline = csd_morlet(
+        epochs, freqs, tmin=baseline_win[0], tmax=baseline_win[1], decim=20
+    )
+    csd_ers = csd_morlet(
+        epochs, freqs, tmin=active_win[0], tmax=active_win[1], decim=20
+    )
+    filters = make_dics(
+        epochs.info,
+        fwd,
+        csd.mean(),
+        pick_ori="max-power",
+        reduce_rank=True,
+        real_filter=True,
+        rank=rank,
+    )
     stc_base, freqs = apply_dics_csd(csd_baseline.mean(), filters)
     stc_act, freqs = apply_dics_csd(csd_ers.mean(), filters)
     stc_act /= stc_base
@@ -107,8 +135,9 @@ def _gen_dics(active_win, baseline_win, epochs):

 # generate lcmv source estimate
 def _gen_lcmv(active_cov, baseline_cov, common_cov):
-    filters = make_lcmv(epochs.info, fwd, common_cov, reg=0.05,
-                        noise_cov=None, pick_ori='max-power')
+    filters = make_lcmv(
+        epochs.info, fwd, common_cov, reg=0.05, noise_cov=None, pick_ori="max-power"
+    )
     stc_base = apply_lcmv_cov(baseline_cov, filters)
     stc_act = apply_lcmv_cov(active_cov, filters)
     stc_act /= stc_base
@@ -116,12 +145,14 @@ def _gen_lcmv(active_cov, baseline_cov, common_cov):

 # generate mne/dSPM source estimate
-def _gen_mne(active_cov, baseline_cov, common_cov, fwd, info, method='dSPM'):
+def _gen_mne(active_cov, baseline_cov, common_cov, fwd, info, method="dSPM"):
     inverse_operator = make_inverse_operator(info, fwd, common_cov)
-    stc_act = apply_inverse_cov(active_cov, info, inverse_operator,
-                                method=method, verbose=True)
-    stc_base = apply_inverse_cov(baseline_cov, info, inverse_operator,
-                                 method=method, verbose=True)
+    stc_act = apply_inverse_cov(
+        active_cov, info, inverse_operator, method=method, verbose=True
+    )
+    stc_base = apply_inverse_cov(
+        baseline_cov, info, inverse_operator, method=method, verbose=True
+    )
     stc_act /= stc_base
     return stc_act

@@ -137,22 +168,31 @@ def _gen_mne(active_cov, baseline_cov, common_cov, fwd, info, method='dSPM'):

 # DICS:
 brain_dics = stc_dics.plot(
-    hemi='rh', subjects_dir=subjects_dir, subject=subject,
-    time_label='DICS source power in the 12-30 Hz frequency band')
+    hemi="rh",
+    subjects_dir=subjects_dir,
+    subject=subject,
+    time_label="DICS source power in the 12-30 Hz frequency band",
+)

 # %%
 # LCMV:
 brain_lcmv = stc_lcmv.plot(
-    hemi='rh', subjects_dir=subjects_dir, subject=subject,
-    time_label='LCMV source power in the 12-30 Hz frequency band')
+    hemi="rh",
+    subjects_dir=subjects_dir,
+    subject=subject,
+    time_label="LCMV source power in the 12-30 Hz frequency band",
+)

 # %%
 # dSPM:
 brain_dspm = stc_dspm.plot(
-    hemi='rh', subjects_dir=subjects_dir, subject=subject,
-    time_label='dSPM source power in the 12-30 Hz frequency band')
+    hemi="rh",
+    subjects_dir=subjects_dir,
+    subject=subject,
+    time_label="dSPM source power in the 12-30 Hz frequency band",
+)

 # %%
 # For more advanced usage, see
diff --git a/examples/inverse/gamma_map_inverse.py b/examples/inverse/gamma_map_inverse.py
index 20a205c3322..f3ff529a331 100644
--- a/examples/inverse/gamma_map_inverse.py
+++ b/examples/inverse/gamma_map_inverse.py
@@ -19,22 +19,24 @@
 import mne
 from mne.datasets import sample
 from mne.inverse_sparse import gamma_map, make_stc_from_dipoles
-from mne.viz import (plot_sparse_source_estimates,
-                     plot_dipole_locations, plot_dipole_amplitudes)
+from mne.viz import (
+    plot_sparse_source_estimates,
+    plot_dipole_locations,
+    plot_dipole_amplitudes,
+)

 print(__doc__)

 data_path = sample.data_path()
-subjects_dir = data_path / 'subjects'
-meg_path = data_path / 'MEG' / 'sample'
-fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif'
-evoked_fname = meg_path / 'sample_audvis-ave.fif'
-cov_fname = meg_path / 'sample_audvis-cov.fif'
+subjects_dir = data_path / "subjects"
+meg_path = data_path / "MEG" / "sample"
+fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif"
+evoked_fname = meg_path / "sample_audvis-ave.fif"
+cov_fname = meg_path / "sample_audvis-cov.fif"

 # Read the evoked response and crop it
-condition = 'Left visual'
-evoked = mne.read_evokeds(evoked_fname, condition=condition,
-                          baseline=(None, 0))
+condition = "Left visual"
+evoked = mne.read_evokeds(evoked_fname, condition=condition, baseline=(None, 0))
 evoked.crop(tmin=-50e-3, tmax=300e-3)

 # Read the forward solution
@@ -47,8 +49,14 @@
 # Run the Gamma-MAP method with dipole output
 alpha = 0.5
 dipoles, residual = gamma_map(
-    evoked, forward, cov, alpha, xyz_same_gamma=True, return_residual=True,
-    return_as_dipoles=True)
+    evoked,
+    forward,
+    cov,
+    alpha,
+    xyz_same_gamma=True,
+    return_residual=True,
+    return_as_dipoles=True,
+)

 # %%
 # Plot dipole activations
@@ -56,9 +64,14 @@
 # Plot dipole location of the strongest dipole with MRI slices
 idx = np.argmax([np.max(np.abs(dip.amplitude)) for dip in dipoles])
-plot_dipole_locations(dipoles[idx], forward['mri_head_t'], 'sample',
-                      subjects_dir=subjects_dir, mode='orthoview',
-                      idx='amplitude')
+plot_dipole_locations(
+    dipoles[idx],
+    forward["mri_head_t"],
+    "sample",
+    subjects_dir=subjects_dir,
+    mode="orthoview",
+    idx="amplitude",
+)

 # # Plot dipole locations of all dipoles with MRI slices
 # for dip in dipoles:
@@ -69,17 +82,22 @@
 # %%
 # Show the evoked response and the residual for gradiometers
 ylim = dict(grad=[-120, 120])
-evoked.pick_types(meg='grad', exclude='bads')
-evoked.plot(titles=dict(grad='Evoked Response Gradiometers'), ylim=ylim,
-            proj=True, time_unit='s')
-
-residual.pick_types(meg='grad', exclude='bads')
-residual.plot(titles=dict(grad='Residuals Gradiometers'), ylim=ylim,
-              proj=True, time_unit='s')
+evoked.pick_types(meg="grad", exclude="bads")
+evoked.plot(
+    titles=dict(grad="Evoked Response Gradiometers"),
+    ylim=ylim,
+    proj=True,
+    time_unit="s",
+)
+
+residual.pick_types(meg="grad", exclude="bads")
+residual.plot(
+    titles=dict(grad="Residuals Gradiometers"), ylim=ylim, proj=True, time_unit="s"
+)

 # %%
 # Generate stc from dipoles
-stc = make_stc_from_dipoles(dipoles, forward['src'])
+stc = make_stc_from_dipoles(dipoles, forward["src"])

 # %%
 # View in 2D and 3D ("glass" brain like 3D plot)
@@ -88,9 +106,14 @@
 scale_factors = 0.5 * (1 + scale_factors / np.max(scale_factors))

 plot_sparse_source_estimates(
-    forward['src'], stc, bgcolor=(1, 1, 1),
-    modes=['sphere'], opacity=0.1, scale_factors=(scale_factors, None),
-    fig_name="Gamma-MAP")
+    forward["src"],
+    stc,
+    bgcolor=(1, 1, 1),
+    modes=["sphere"],
+    opacity=0.1,
+    scale_factors=(scale_factors, None),
+    fig_name="Gamma-MAP",
+)

 # %%
 # References
diff --git a/examples/inverse/label_activation_from_stc.py b/examples/inverse/label_activation_from_stc.py
index 20368b68183..358de19bff2 100644
--- a/examples/inverse/label_activation_from_stc.py
+++ b/examples/inverse/label_activation_from_stc.py
@@ -24,15 +24,15 @@
 print(__doc__)

 data_path = sample.data_path()
-subjects_dir = data_path / 'subjects'
-meg_path = data_path / 'MEG' / 'sample'
+subjects_dir = data_path / "subjects"
+meg_path = data_path / "MEG" / "sample"

 # load the stc
-stc = mne.read_source_estimate(meg_path / 'sample_audvis-meg')
+stc = mne.read_source_estimate(meg_path / "sample_audvis-meg")

 # load the labels
-aud_lh = mne.read_label(meg_path / 'labels' / 'Aud-lh.label')
-aud_rh = mne.read_label(meg_path / 'labels' / 'Aud-rh.label')
+aud_lh = mne.read_label(meg_path / "labels" / "Aud-lh.label")
+aud_rh = mne.read_label(meg_path / "labels" / "Aud-rh.label")

 # extract the time course for different labels from the stc
 stc_lh = stc.in_label(aud_lh)
@@ -40,25 +40,27 @@
 stc_bh = stc.in_label(aud_lh + aud_rh)

 # calculate center of mass and transform to mni coordinates
-vtx, _, t_lh = stc_lh.center_of_mass('sample', subjects_dir=subjects_dir)
-mni_lh = mne.vertex_to_mni(vtx, 0, 'sample', subjects_dir=subjects_dir)[0]
-vtx, _, t_rh = stc_rh.center_of_mass('sample', subjects_dir=subjects_dir)
-mni_rh = mne.vertex_to_mni(vtx, 1, 'sample', subjects_dir=subjects_dir)[0]
+vtx, _, t_lh = stc_lh.center_of_mass("sample", subjects_dir=subjects_dir)
+mni_lh = mne.vertex_to_mni(vtx, 0, "sample", subjects_dir=subjects_dir)[0]
+vtx, _, t_rh = stc_rh.center_of_mass("sample", subjects_dir=subjects_dir)
+mni_rh = mne.vertex_to_mni(vtx, 1, "sample", subjects_dir=subjects_dir)[0]

 # plot the activation
 plt.figure()
-plt.axes([.1, .275, .85, .625])
-hl = plt.plot(stc.times, stc_lh.data.mean(0), 'b')[0]
-hr = plt.plot(stc.times, stc_rh.data.mean(0), 'g')[0]
-hb = plt.plot(stc.times, stc_bh.data.mean(0), 'r')[0]
-plt.xlabel('Time (s)')
-plt.ylabel('Source amplitude (dSPM)')
+plt.axes([0.1, 0.275, 0.85, 0.625])
+hl = plt.plot(stc.times, stc_lh.data.mean(0), "b")[0]
+hr = plt.plot(stc.times, stc_rh.data.mean(0), "g")[0]
+hb = plt.plot(stc.times, stc_bh.data.mean(0), "r")[0]
+plt.xlabel("Time (s)")
+plt.ylabel("Source amplitude (dSPM)")
 plt.xlim(stc.times[0], stc.times[-1])

 # add a legend including center-of-mass mni coordinates to the plot
-labels = ['LH: center of mass = %s' % mni_lh.round(2),
-          'RH: center of mass = %s' % mni_rh.round(2),
-          'Combined LH & RH']
-plt.figlegend([hl, hr, hb], labels, loc='lower center')
-plt.suptitle('Average activation in auditory cortex labels', fontsize=20)
+labels = [
+    "LH: center of mass = %s" % mni_lh.round(2),
+    "RH: center of mass = %s" % mni_rh.round(2),
+    "Combined LH & RH",
+]
+plt.figlegend([hl, hr, hb], labels, loc="lower center")
+plt.suptitle("Average activation in auditory cortex labels", fontsize=20)
 plt.show()
diff --git a/examples/inverse/label_from_stc.py b/examples/inverse/label_from_stc.py
index 3d3abae2a16..39469e8c68b 100644
--- a/examples/inverse/label_from_stc.py
+++ b/examples/inverse/label_from_stc.py
@@ -28,29 +28,27 @@
 print(__doc__)

 data_path = sample.data_path()
-fname_inv = (
-    data_path / 'MEG' / 'sample' / 'sample_audvis-meg-oct-6-meg-inv.fif')
-fname_evoked = data_path / 'MEG' / 'sample' / 'sample_audvis-ave.fif'
-subjects_dir = data_path / 'subjects'
-subject = 'sample'
+fname_inv = data_path / "MEG" / "sample" / "sample_audvis-meg-oct-6-meg-inv.fif"
+fname_evoked = data_path / "MEG" / "sample" / "sample_audvis-ave.fif"
+subjects_dir = data_path / "subjects"
+subject = "sample"

 snr = 3.0
-lambda2 = 1.0 / snr ** 2
+lambda2 = 1.0 / snr**2
 method = "dSPM"  # use dSPM method (could also be MNE or sLORETA)

 # Compute a label/ROI based on the peak power between 80 and 120 ms.
 # The label bankssts-lh is used for the comparison.
-aparc_label_name = 'bankssts-lh'
+aparc_label_name = "bankssts-lh"
 tmin, tmax = 0.080, 0.120

 # Load data
 evoked = mne.read_evokeds(fname_evoked, condition=0, baseline=(None, 0))
 inverse_operator = read_inverse_operator(fname_inv)
-src = inverse_operator['src']  # get the source space
+src = inverse_operator["src"]  # get the source space

 # Compute inverse solution
-stc = apply_inverse(evoked, inverse_operator, lambda2, method,
-                    pick_ori='normal')
+stc = apply_inverse(evoked, inverse_operator, lambda2, method, pick_ori="normal")

 # Make an STC in the time interval of interest and take the mean
 stc_mean = stc.copy().crop(tmin, tmax).mean()
@@ -58,33 +56,38 @@
 # use the stc_mean to generate a functional label
 # region growing is halted at 60% of the peak value within the
 # anatomical label / ROI specified by aparc_label_name
-label = mne.read_labels_from_annot(subject, parc='aparc',
-                                   subjects_dir=subjects_dir,
-                                   regexp=aparc_label_name)[0]
+label = mne.read_labels_from_annot(
+    subject, parc="aparc", subjects_dir=subjects_dir, regexp=aparc_label_name
+)[0]
 stc_mean_label = stc_mean.in_label(label)
 data = np.abs(stc_mean_label.data)
-stc_mean_label.data[data < 0.6 * np.max(data)] = 0.
+stc_mean_label.data[data < 0.6 * np.max(data)] = 0.0

 # 8.5% of original source space vertices were omitted during forward
 # calculation, suppress the warning here with verbose='error'
-func_labels, _ = mne.stc_to_label(stc_mean_label, src=src, smooth=True,
-                                  subjects_dir=subjects_dir, connected=True,
-                                  verbose='error')
+func_labels, _ = mne.stc_to_label(
+    stc_mean_label,
+    src=src,
+    smooth=True,
+    subjects_dir=subjects_dir,
+    connected=True,
+    verbose="error",
+)

 # take first as func_labels are ordered based on maximum values in stc
 func_label = func_labels[0]

 # load the anatomical ROI for comparison
-anat_label = mne.read_labels_from_annot(subject, parc='aparc',
-                                        subjects_dir=subjects_dir,
-                                        regexp=aparc_label_name)[0]
+anat_label = mne.read_labels_from_annot(
+    subject, parc="aparc", subjects_dir=subjects_dir, regexp=aparc_label_name
+)[0]

 # extract the anatomical time course for each label
 stc_anat_label = stc.in_label(anat_label)
-pca_anat = stc.extract_label_time_course(anat_label, src, mode='pca_flip')[0]
+pca_anat = stc.extract_label_time_course(anat_label, src, mode="pca_flip")[0]

 stc_func_label = stc.in_label(func_label)
-pca_func = stc.extract_label_time_course(func_label, src, mode='pca_flip')[0]
+pca_func = stc.extract_label_time_course(func_label, src, mode="pca_flip")[0]

 # flip the pca so that the max power between tmin and tmax is positive
 pca_anat *= np.sign(pca_anat[np.argmax(np.abs(pca_anat))])
@@ -93,18 +96,20 @@

 # %%
 # plot the time courses....
 plt.figure()
-plt.plot(1e3 * stc_anat_label.times, pca_anat, 'k',
-         label='Anatomical %s' % aparc_label_name)
-plt.plot(1e3 * stc_func_label.times, pca_func, 'b',
-         label='Functional %s' % aparc_label_name)
+plt.plot(
+    1e3 * stc_anat_label.times, pca_anat, "k", label="Anatomical %s" % aparc_label_name
+)
+plt.plot(
+    1e3 * stc_func_label.times, pca_func, "b", label="Functional %s" % aparc_label_name
+)
 plt.legend()
 plt.show()

 # %%
 # plot brain in 3D with mne.viz.Brain if available
-brain = stc_mean.plot(hemi='lh', subjects_dir=subjects_dir)
-brain.show_view('lateral')
+brain = stc_mean.plot(hemi="lh", subjects_dir=subjects_dir)
+brain.show_view("lateral")

 # show both labels
-brain.add_label(anat_label, borders=True, color='k')
-brain.add_label(func_label, borders=True, color='b')
+brain.add_label(anat_label, borders=True, color="k")
+brain.add_label(func_label, borders=True, color="b")
diff --git a/examples/inverse/label_source_activations.py b/examples/inverse/label_source_activations.py
index 30a55970d81..599fff4c2f8 100644
--- a/examples/inverse/label_source_activations.py
+++ b/examples/inverse/label_source_activations.py
@@ -28,32 +28,31 @@
 print(__doc__)

 data_path = sample.data_path()
-label = 'Aud-lh'
-meg_path = data_path / 'MEG' / 'sample'
-label_fname = meg_path / 'labels' / f'{label}.label'
-fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif'
-fname_evoked = meg_path / 'sample_audvis-ave.fif'
+label = "Aud-lh"
+meg_path = data_path / "MEG" / "sample"
+label_fname = meg_path / "labels" / f"{label}.label"
+fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif"
+fname_evoked = meg_path / "sample_audvis-ave.fif"

 snr = 3.0
-lambda2 = 1.0 / snr ** 2
+lambda2 = 1.0 / snr**2
 method = "dSPM"  # use dSPM method (could also be MNE or sLORETA)

 # Load data
 evoked = mne.read_evokeds(fname_evoked, condition=0, baseline=(None, 0))
 inverse_operator = read_inverse_operator(fname_inv)
-src = inverse_operator['src']
+src = inverse_operator["src"]

 # %%
 # Compute inverse solution
 # ------------------------
 pick_ori = "normal"  # Get signed values to see the effect of sign flip
-stc = apply_inverse(evoked, inverse_operator, lambda2, method,
-                    pick_ori=pick_ori)
+stc = apply_inverse(evoked, inverse_operator, lambda2, method, pick_ori=pick_ori)

 label = mne.read_label(label_fname)
 stc_label = stc.in_label(label)
-modes = ('mean', 'mean_flip', 'pca_flip')
+modes = ("mean", "mean_flip", "pca_flip")
 tcs = dict()
 for mode in modes:
     tcs[mode] = stc.extract_label_time_course(label, src, mode=mode)
@@ -65,17 +64,23 @@

 fig, ax = plt.subplots(1)
 t = 1e3 * stc_label.times
-ax.plot(t, stc_label.data.T, 'k', linewidth=0.5, alpha=0.5)
-pe = [path_effects.Stroke(linewidth=5, foreground='w', alpha=0.5),
-      path_effects.Normal()]
+ax.plot(t, stc_label.data.T, "k", linewidth=0.5, alpha=0.5)
+pe = [
+    path_effects.Stroke(linewidth=5, foreground="w", alpha=0.5),
+    path_effects.Normal(),
+]
 for mode, tc in tcs.items():
     ax.plot(t, tc[0], linewidth=3, label=str(mode), path_effects=pe)
 xlim = t[[0, -1]]
 ylim = [-27, 22]
-ax.legend(loc='upper right')
-ax.set(xlabel='Time (ms)', ylabel='Source amplitude',
-       title='Activations in Label %r' % (label.name),
-       xlim=xlim, ylim=ylim)
+ax.legend(loc="upper right")
+ax.set(
+    xlabel="Time (ms)",
+    ylabel="Source amplitude",
+    title="Activations in Label %r" % (label.name),
+    xlim=xlim,
+    ylim=ylim,
+)
 mne.viz.tight_layout()

 # %%
@@ -84,21 +89,32 @@
 # It's also possible to compute label time courses for a
 # :class:`mne.VectorSourceEstimate`, but only with ``mode='mean'``.
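 # (Hedged illustration, not part of this patch: spelling that mode out
 # explicitly would read
 #     data = stc_vec.extract_label_time_course(label, src, mode="mean")
 # since the sign-flip modes rely on a normal-orientation sign convention
 # that vector estimates do not have.)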
-pick_ori = 'vector'
-stc_vec = apply_inverse(evoked, inverse_operator, lambda2, method,
-                        pick_ori=pick_ori)
+pick_ori = "vector"
+stc_vec = apply_inverse(evoked, inverse_operator, lambda2, method, pick_ori=pick_ori)
 data = stc_vec.extract_label_time_course(label, src)
 fig, ax = plt.subplots(1)
 stc_vec_label = stc_vec.in_label(label)
-colors = ['#EE6677', '#228833', '#4477AA']
-for ii, name in enumerate('XYZ'):
+colors = ["#EE6677", "#228833", "#4477AA"]
+for ii, name in enumerate("XYZ"):
     color = colors[ii]
-    ax.plot(t, stc_vec_label.data[:, ii].T, color=color, lw=0.5, alpha=0.5,
-            zorder=5 - ii)
-    ax.plot(t, data[0, ii], lw=3, color=color, label='+' + name, zorder=8 - ii,
-            path_effects=pe)
-ax.legend(loc='upper right')
-ax.set(xlabel='Time (ms)', ylabel='Source amplitude',
-       title='Mean vector activations in Label %r' % (label.name,),
-       xlim=xlim, ylim=ylim)
+    ax.plot(
+        t, stc_vec_label.data[:, ii].T, color=color, lw=0.5, alpha=0.5, zorder=5 - ii
+    )
+    ax.plot(
+        t,
+        data[0, ii],
+        lw=3,
+        color=color,
+        label="+" + name,
+        zorder=8 - ii,
+        path_effects=pe,
+    )
+ax.legend(loc="upper right")
+ax.set(
+    xlabel="Time (ms)",
+    ylabel="Source amplitude",
+    title="Mean vector activations in Label %r" % (label.name,),
+    xlim=xlim,
+    ylim=ylim,
+)
 mne.viz.tight_layout()
diff --git a/examples/inverse/mixed_norm_inverse.py b/examples/inverse/mixed_norm_inverse.py
index 56b64e744a1..ce8a1e74a69 100644
--- a/examples/inverse/mixed_norm_inverse.py
+++ b/examples/inverse/mixed_norm_inverse.py
@@ -25,22 +25,25 @@
 from mne.datasets import sample
 from mne.inverse_sparse import mixed_norm, make_stc_from_dipoles
 from mne.minimum_norm import make_inverse_operator, apply_inverse
-from mne.viz import (plot_sparse_source_estimates,
-                     plot_dipole_locations, plot_dipole_amplitudes)
+from mne.viz import (
+    plot_sparse_source_estimates,
+    plot_dipole_locations,
+    plot_dipole_amplitudes,
+)

 print(__doc__)

 data_path = sample.data_path()
-meg_path = data_path / 'MEG' / 'sample'
-fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif'
-ave_fname = meg_path / 'sample_audvis-ave.fif'
-cov_fname = meg_path / 'sample_audvis-shrunk-cov.fif'
-subjects_dir = data_path / 'subjects'
+meg_path = data_path / "MEG" / "sample"
+fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif"
+ave_fname = meg_path / "sample_audvis-ave.fif"
+cov_fname = meg_path / "sample_audvis-shrunk-cov.fif"
+subjects_dir = data_path / "subjects"

 # Read noise covariance matrix
 cov = mne.read_cov(cov_fname)

 # Handling average file
-condition = 'Left Auditory'
+condition = "Left Auditory"
 evoked = mne.read_evokeds(ave_fname, condition=condition, baseline=(None, 0))
 evoked.crop(tmin=0, tmax=0.3)

 # Handling forward solution
@@ -54,18 +57,30 @@
 # if n_mxne_iter > 1 dSPM weighting can be avoided.

 # Compute dSPM solution to be used as weights in MxNE
-inverse_operator = make_inverse_operator(evoked.info, forward, cov,
-                                         depth=depth, fixed=True,
-                                         use_cps=True)
-stc_dspm = apply_inverse(evoked, inverse_operator, lambda2=1. / 9.,
-                         method='dSPM')
+inverse_operator = make_inverse_operator(
+    evoked.info, forward, cov, depth=depth, fixed=True, use_cps=True
+)
+stc_dspm = apply_inverse(evoked, inverse_operator, lambda2=1.0 / 9.0, method="dSPM")

 # Compute (ir)MxNE inverse solution with dipole output
 dipoles, residual = mixed_norm(
-    evoked, forward, cov, alpha, loose=loose, depth=depth, maxit=3000,
-    tol=1e-4, active_set_size=10, debias=False, weights=stc_dspm,
-    weights_min=8., n_mxne_iter=n_mxne_iter, return_residual=True,
-    return_as_dipoles=True, verbose=True, random_state=0,
+    evoked,
+    forward,
+    cov,
+    alpha,
+    loose=loose,
+    depth=depth,
+    maxit=3000,
+    tol=1e-4,
+    active_set_size=10,
+    debias=False,
+    weights=stc_dspm,
+    weights_min=8.0,
+    n_mxne_iter=n_mxne_iter,
+    return_residual=True,
+    return_as_dipoles=True,
+    verbose=True,
+    random_state=0,
     # for this dataset we know we should use a high alpha, so avoid some
     # of the slower (lower) alpha values
     sure_alpha_grid=np.linspace(100, 40, 10),
@@ -74,8 +89,7 @@
 t = 0.083
 tidx = evoked.time_as_index(t)
 for di, dip in enumerate(dipoles, 1):
-    print(f'Dipole #{di} GOF at {1000 * t:0.1f} ms: '
-          f'{float(dip.gof[tidx]):0.1f}%')
+    print(f"Dipole #{di} GOF at {1000 * t:0.1f} ms: " f"{float(dip.gof[tidx]):0.1f}%")

 # %%
 # Plot dipole activations
@@ -83,48 +97,70 @@

 # Plot dipole location of the strongest dipole with MRI slices
 idx = np.argmax([np.max(np.abs(dip.amplitude)) for dip in dipoles])
-plot_dipole_locations(dipoles[idx], forward['mri_head_t'], 'sample',
-                      subjects_dir=subjects_dir, mode='orthoview',
-                      idx='amplitude')
+plot_dipole_locations(
+    dipoles[idx],
+    forward["mri_head_t"],
+    "sample",
+    subjects_dir=subjects_dir,
+    mode="orthoview",
+    idx="amplitude",
+)

 # Plot dipole locations of all dipoles with MRI slices
 for dip in dipoles:
-    plot_dipole_locations(dip, forward['mri_head_t'], 'sample',
-                          subjects_dir=subjects_dir, mode='orthoview',
-                          idx='amplitude')
+    plot_dipole_locations(
+        dip,
+        forward["mri_head_t"],
+        "sample",
+        subjects_dir=subjects_dir,
+        mode="orthoview",
+        idx="amplitude",
+    )

 # %%
 # Plot residual
 ylim = dict(eeg=[-10, 10], grad=[-400, 400], mag=[-600, 600])
-evoked.pick_types(meg=True, eeg=True, exclude='bads')
-evoked.plot(ylim=ylim, proj=True, time_unit='s')
-residual.pick_types(meg=True, eeg=True, exclude='bads')
-residual.plot(ylim=ylim, proj=True, time_unit='s')
+evoked.pick_types(meg=True, eeg=True, exclude="bads")
+evoked.plot(ylim=ylim, proj=True, time_unit="s")
+residual.pick_types(meg=True, eeg=True, exclude="bads")
+residual.plot(ylim=ylim, proj=True, time_unit="s")

 # %%
 # Generate stc from dipoles
-stc = make_stc_from_dipoles(dipoles, forward['src'])
+stc = make_stc_from_dipoles(dipoles, forward["src"])

 # %%
 # View in 2D and 3D ("glass" brain like 3D plot)
 solver = "MxNE" if n_mxne_iter == 1 else "irMxNE"
-plot_sparse_source_estimates(forward['src'], stc, bgcolor=(1, 1, 1),
-                             fig_name="%s (cond %s)" % (solver, condition),
-                             opacity=0.1)
+plot_sparse_source_estimates(
+    forward["src"],
+    stc,
+    bgcolor=(1, 1, 1),
+    fig_name="%s (cond %s)" % (solver, condition),
+    opacity=0.1,
+)

 # %%
 # Morph onto fsaverage brain and view
-morph = mne.compute_source_morph(stc, subject_from='sample',
-                                 subject_to='fsaverage', spacing=None,
-                                 sparse=True, subjects_dir=subjects_dir)
+morph = mne.compute_source_morph(
+    stc,
+    subject_from="sample",
+    subject_to="fsaverage",
+    spacing=None,
+    sparse=True,
+    subjects_dir=subjects_dir,
+)
 stc_fsaverage = morph.apply(stc)
-src_fsaverage_fname = (
-    subjects_dir / 'fsaverage' / 'bem' / 'fsaverage-ico-5-src.fif')
+src_fsaverage_fname = subjects_dir / "fsaverage" / "bem" / "fsaverage-ico-5-src.fif"
 src_fsaverage = mne.read_source_spaces(src_fsaverage_fname)

-plot_sparse_source_estimates(src_fsaverage, stc_fsaverage, bgcolor=(1, 1, 1),
-                             fig_name="Morphed %s (cond %s)" % (solver,
-                                                                condition), opacity=0.1)
+plot_sparse_source_estimates(
+    src_fsaverage,
+    stc_fsaverage,
+    bgcolor=(1, 1, 1),
+    fig_name="Morphed %s (cond %s)" % (solver, condition),
+    opacity=0.1,
+)

 # %%
 # References
diff --git a/examples/inverse/mixed_source_space_inverse.py b/examples/inverse/mixed_source_space_inverse.py
index f732178ea9f..9baac7da379 100644
--- a/examples/inverse/mixed_source_space_inverse.py
+++ b/examples/inverse/mixed_source_space_inverse.py
@@ -23,22 +23,22 @@

 # Set dir
 data_path = mne.datasets.sample.data_path()
-subject = 'sample'
-data_dir = data_path / 'MEG' / subject
-subjects_dir = data_path / 'subjects'
-bem_dir = subjects_dir / subject / 'bem'
+subject = "sample"
+data_dir = data_path / "MEG" / subject
+subjects_dir = data_path / "subjects"
+bem_dir = subjects_dir / subject / "bem"

 # Set file names
-fname_mixed_src = bem_dir / f'{subject}-oct-6-mixed-src.fif'
-fname_aseg = subjects_dir / subject / 'mri' / 'aseg.mgz'
+fname_mixed_src = bem_dir / f"{subject}-oct-6-mixed-src.fif"
+fname_aseg = subjects_dir / subject / "mri" / "aseg.mgz"

-fname_model = bem_dir / f'{subject}-5120-bem.fif'
-fname_bem = bem_dir / f'{subject}-5120-bem-sol.fif'
+fname_model = bem_dir / f"{subject}-5120-bem.fif"
+fname_bem = bem_dir / f"{subject}-5120-bem-sol.fif"

-fname_evoked = data_dir / f'{subject}_audvis-ave.fif'
-fname_trans = data_dir / f'{subject}_audvis_raw-trans.fif'
-fname_fwd = data_dir / f'{subject}_audvis-meg-oct-6-mixed-fwd.fif'
-fname_cov = data_dir / f'{subject}_audvis-shrunk-cov.fif'
+fname_evoked = data_dir / f"{subject}_audvis-ave.fif"
+fname_trans = data_dir / f"{subject}_audvis_raw-trans.fif"
+fname_fwd = data_dir / f"{subject}_audvis-meg-oct-6-mixed-fwd.fif"
+fname_cov = data_dir / f"{subject}_audvis-shrunk-cov.fif"

 # %%
 # Set up our source space
@@ -46,19 +46,22 @@

 # List substructures we are interested in. We select only the
 # sub structures we want to include in the source space:
-labels_vol = ['Left-Amygdala',
-              'Left-Thalamus-Proper',
-              'Left-Cerebellum-Cortex',
-              'Brain-Stem',
-              'Right-Amygdala',
-              'Right-Thalamus-Proper',
-              'Right-Cerebellum-Cortex']
+labels_vol = [
+    "Left-Amygdala",
+    "Left-Thalamus-Proper",
+    "Left-Cerebellum-Cortex",
+    "Brain-Stem",
+    "Right-Amygdala",
+    "Right-Thalamus-Proper",
+    "Right-Cerebellum-Cortex",
+]

 # %%
 # Get a surface-based source space, here with few source points for speed
 # in this demonstration, in general you should use oct6 spacing!
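 # (Sketch of that recommendation, illustrative rather than part of this
 # formatting patch:
 #     src = mne.setup_source_space(subject, spacing="oct6",
 #                                  subjects_dir=subjects_dir)
 # )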
-src = mne.setup_source_space(subject, spacing='oct5', - add_dist=False, subjects_dir=subjects_dir) +src = mne.setup_source_space( + subject, spacing="oct5", add_dist=False, subjects_dir=subjects_dir +) # %% # Now we create a mixed src space by adding the volume regions specified in the @@ -67,15 +70,22 @@ # we recommend something smaller like 5.0 in actual analyses): vol_src = mne.setup_volume_source_space( - subject, mri=fname_aseg, pos=10.0, bem=fname_model, - volume_label=labels_vol, subjects_dir=subjects_dir, + subject, + mri=fname_aseg, + pos=10.0, + bem=fname_model, + volume_label=labels_vol, + subjects_dir=subjects_dir, add_interpolator=False, # just for speed, usually this should be True - verbose=True) + verbose=True, +) # Generate the mixed source space src += vol_src -print(f"The source space contains {len(src)} spaces and " - f"{sum(s['nuse'] for s in src)} vertices") +print( + f"The source space contains {len(src)} spaces and " + f"{sum(s['nuse'] for s in src)} vertices" +) # %% # View the source space @@ -90,47 +100,54 @@ # # We can also export source positions to NIfTI file and visualize it again: -nii_fname = bem_dir / f'{subject}-mixed-src.nii' +nii_fname = bem_dir / f"{subject}-mixed-src.nii" src.export_volume(nii_fname, mri_resolution=True, overwrite=True) -plotting.plot_img(str(nii_fname), cmap='nipy_spectral') +plotting.plot_img(str(nii_fname), cmap="nipy_spectral") # %% # Compute the fwd matrix # ---------------------- fwd = mne.make_forward_solution( - fname_evoked, fname_trans, src, fname_bem, + fname_evoked, + fname_trans, + src, + fname_bem, mindist=5.0, # ignore sources<=5mm from innerskull - meg=True, eeg=False, n_jobs=None) + meg=True, + eeg=False, + n_jobs=None, +) del src # save memory -leadfield = fwd['sol']['data'] +leadfield = fwd["sol"]["data"] print("Leadfield size : %d sensors x %d dipoles" % leadfield.shape) -print(f"The fwd source space contains {len(fwd['src'])} spaces and " - f"{sum(s['nuse'] for s in fwd['src'])} vertices") +print( + f"The fwd source space contains {len(fwd['src'])} spaces and " + f"{sum(s['nuse'] for s in fwd['src'])} vertices" +) # Load data -condition = 'Left Auditory' -evoked = mne.read_evokeds(fname_evoked, condition=condition, - baseline=(None, 0)) +condition = "Left Auditory" +evoked = mne.read_evokeds(fname_evoked, condition=condition, baseline=(None, 0)) noise_cov = mne.read_cov(fname_cov) # %% # Compute inverse solution # ------------------------ -snr = 3.0 # use smaller SNR for raw data -inv_method = 'dSPM' # sLORETA, MNE, dSPM -parc = 'aparc' # the parcellation to use, e.g., 'aparc' 'aparc.a2009s' -loose = dict(surface=0.2, volume=1.) 
+snr = 3.0 # use smaller SNR for raw data +inv_method = "dSPM" # sLORETA, MNE, dSPM +parc = "aparc" # the parcellation to use, e.g., 'aparc' 'aparc.a2009s' +loose = dict(surface=0.2, volume=1.0) -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 inverse_operator = make_inverse_operator( - evoked.info, fwd, noise_cov, depth=None, loose=loose, verbose=True) + evoked.info, fwd, noise_cov, depth=None, loose=loose, verbose=True +) del fwd -stc = apply_inverse(evoked, inverse_operator, lambda2, inv_method, - pick_ori=None) -src = inverse_operator['src'] +stc = apply_inverse(evoked, inverse_operator, lambda2, inv_method, pick_ori=None) +src = inverse_operator["src"] # %% # Plot the mixed source estimate @@ -138,24 +155,30 @@ # sphinx_gallery_thumbnail_number = 3 initial_time = 0.1 -stc_vec = apply_inverse(evoked, inverse_operator, lambda2, inv_method, - pick_ori='vector') +stc_vec = apply_inverse( + evoked, inverse_operator, lambda2, inv_method, pick_ori="vector" +) brain = stc_vec.plot( - hemi='both', src=inverse_operator['src'], views='coronal', - initial_time=initial_time, subjects_dir=subjects_dir, - brain_kwargs=dict(silhouette=True), smoothing_steps=7) + hemi="both", + src=inverse_operator["src"], + views="coronal", + initial_time=initial_time, + subjects_dir=subjects_dir, + brain_kwargs=dict(silhouette=True), + smoothing_steps=7, +) # %% # Plot the surface # ---------------- -brain = stc.surface().plot(initial_time=initial_time, - subjects_dir=subjects_dir, smoothing_steps=7) +brain = stc.surface().plot( + initial_time=initial_time, subjects_dir=subjects_dir, smoothing_steps=7 +) # %% # Plot the volume # --------------- -fig = stc.volume().plot(initial_time=initial_time, src=src, - subjects_dir=subjects_dir) +fig = stc.volume().plot(initial_time=initial_time, src=src, subjects_dir=subjects_dir) # %% # Process labels @@ -164,16 +187,16 @@ # and each sub structure contained in the src space # Get labels for FreeSurfer 'aparc' cortical parcellation with 34 labels/hemi -labels_parc = mne.read_labels_from_annot( - subject, parc=parc, subjects_dir=subjects_dir) +labels_parc = mne.read_labels_from_annot(subject, parc=parc, subjects_dir=subjects_dir) label_ts = mne.extract_label_time_course( - [stc], labels_parc, src, mode='mean', allow_empty=True) + [stc], labels_parc, src, mode="mean", allow_empty=True +) # plot the times series of 2 labels fig, axes = plt.subplots(1) -axes.plot(1e3 * stc.times, label_ts[0][0, :], 'k', label='bankssts-lh') -axes.plot(1e3 * stc.times, label_ts[0][-1, :].T, 'r', label='Brain-stem') -axes.set(xlabel='Time (ms)', ylabel='MNE current (nAm)') +axes.plot(1e3 * stc.times, label_ts[0][0, :], "k", label="bankssts-lh") +axes.plot(1e3 * stc.times, label_ts[0][-1, :].T, "r", label="Brain-stem") +axes.set(xlabel="Time (ms)", ylabel="MNE current (nAm)") axes.legend() mne.viz.tight_layout() diff --git a/examples/inverse/mne_cov_power.py b/examples/inverse/mne_cov_power.py index 91fc47bc577..592664a72ef 100644 --- a/examples/inverse/mne_cov_power.py +++ b/examples/inverse/mne_cov_power.py @@ -30,9 +30,9 @@ from mne.minimum_norm import make_inverse_operator, apply_inverse_cov data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" raw = mne.io.read_raw_fif(raw_fname) # %% @@ -41,28 +41,39 @@ # First we compute an empty-room 
covariance, which captures noise from the # sensors and environment. -raw_empty_room_fname = data_path / 'MEG' / 'sample' / 'ernoise_raw.fif' +raw_empty_room_fname = data_path / "MEG" / "sample" / "ernoise_raw.fif" raw_empty_room = mne.io.read_raw_fif(raw_empty_room_fname) raw_empty_room.crop(0, 30) # cropped just for speed -raw_empty_room.info['bads'] = ['MEG 2443'] -raw_empty_room.add_proj(raw.info['projs']) -noise_cov = mne.compute_raw_covariance(raw_empty_room, method='shrunk') +raw_empty_room.info["bads"] = ["MEG 2443"] +raw_empty_room.add_proj(raw.info["projs"]) +noise_cov = mne.compute_raw_covariance(raw_empty_room, method="shrunk") del raw_empty_room # %% # Epoch the data # -------------- -raw.pick(['meg', 'stim', 'eog']).load_data().filter(4, 12) -raw.info['bads'] = ['MEG 2443'] -events = mne.find_events(raw, stim_channel='STI 014') +raw.pick(["meg", "stim", "eog"]).load_data().filter(4, 12) +raw.info["bads"] = ["MEG 2443"] +events = mne.find_events(raw, stim_channel="STI 014") event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4) tmin, tmax = -0.2, 0.5 baseline = (None, 0) # means from the first instant to t = 0 reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6) -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, - proj=True, picks=('meg', 'eog'), baseline=None, - reject=reject, preload=True, decim=5, verbose='error') +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + proj=True, + picks=("meg", "eog"), + baseline=None, + reject=reject, + preload=True, + decim=5, + verbose="error", +) del raw # %% @@ -78,9 +89,11 @@ # to noise sources). base_cov = mne.compute_covariance( - epochs, tmin=-0.2, tmax=0, method='shrunk', verbose=True) + epochs, tmin=-0.2, tmax=0, method="shrunk", verbose=True +) data_cov = mne.compute_covariance( - epochs, tmin=0., tmax=0.2, method='shrunk', verbose=True) + epochs, tmin=0.0, tmax=0.2, method="shrunk", verbose=True +) fig_noise_cov = mne.viz.plot_cov(noise_cov, epochs.info, show_svd=False) fig_base_cov = mne.viz.plot_cov(base_cov, epochs.info, show_svd=False) @@ -91,16 +104,18 @@ # baseline and data covariances, followed by the data covariance whitened # by the baseline covariance: -evoked = epochs.average().pick('meg') -evoked.drop_channels(evoked.info['bads']) -evoked.plot(time_unit='s') -evoked.plot_topomap(times=np.linspace(0.05, 0.15, 5), ch_type='mag') +evoked = epochs.average().pick("meg") +evoked.drop_channels(evoked.info["bads"]) +evoked.plot(time_unit="s") +evoked.plot_topomap(times=np.linspace(0.05, 0.15, 5), ch_type="mag") -loop = {'Noise': (noise_cov, dict()), - 'Data': (data_cov, dict()), - 'Whitened data': (data_cov, dict(noise_cov=noise_cov))} +loop = { + "Noise": (noise_cov, dict()), + "Data": (data_cov, dict()), + "Whitened data": (data_cov, dict(noise_cov=noise_cov)), +} for title, (_cov, _kw) in loop.items(): - fig = _cov.plot_topomap(evoked.info, 'grad', **_kw) + fig = _cov.plot_topomap(evoked.info, "grad", **_kw) fig.suptitle(title) # %% @@ -109,20 +124,31 @@ # Finally, we can construct an inverse using the empty-room noise covariance: # Read the forward solution and compute the inverse operator -fname_fwd = meg_path / 'sample_audvis-meg-oct-6-fwd.fif' +fname_fwd = meg_path / "sample_audvis-meg-oct-6-fwd.fif" fwd = mne.read_forward_solution(fname_fwd) # make an MEG inverse operator info = evoked.info -inverse_operator = make_inverse_operator(info, fwd, noise_cov, - loose=0.2, depth=0.8) +inverse_operator = make_inverse_operator(info, fwd, noise_cov, loose=0.2, depth=0.8) # %% # Project our data and baseline 
covariance to source space: -stc_data = apply_inverse_cov(data_cov, evoked.info, inverse_operator, - nave=len(epochs), method='dSPM', verbose=True) -stc_base = apply_inverse_cov(base_cov, evoked.info, inverse_operator, - nave=len(epochs), method='dSPM', verbose=True) +stc_data = apply_inverse_cov( + data_cov, + evoked.info, + inverse_operator, + nave=len(epochs), + method="dSPM", + verbose=True, +) +stc_base = apply_inverse_cov( + base_cov, + evoked.info, + inverse_operator, + nave=len(epochs), + method="dSPM", + verbose=True, +) # %% # And visualize power is relative to the baseline: @@ -130,6 +156,9 @@ # sphinx_gallery_thumbnail_number = 9 stc_data /= stc_base -brain = stc_data.plot(subject='sample', subjects_dir=subjects_dir, - clim=dict(kind='percent', lims=(50, 90, 98)), - smoothing_steps=7) +brain = stc_data.plot( + subject="sample", + subjects_dir=subjects_dir, + clim=dict(kind="percent", lims=(50, 90, 98)), + smoothing_steps=7, +) diff --git a/examples/inverse/morph_surface_stc.py b/examples/inverse/morph_surface_stc.py index 80a35c87ed8..0417a8d807a 100644 --- a/examples/inverse/morph_surface_stc.py +++ b/examples/inverse/morph_surface_stc.py @@ -37,19 +37,18 @@ # Setup paths data_path = sample.data_path() -sample_dir = data_path / 'MEG' / 'sample' -subjects_dir = data_path / 'subjects' -fname_src = subjects_dir / 'sample' / 'bem' / 'sample-oct-6-src.fif' -fname_fwd = sample_dir / 'sample_audvis-meg-oct-6-fwd.fif' -fname_fsaverage_src = (subjects_dir / 'fsaverage' / 'bem' / - 'fsaverage-ico-5-src.fif') -fname_stc = sample_dir / 'sample_audvis-meg' +sample_dir = data_path / "MEG" / "sample" +subjects_dir = data_path / "subjects" +fname_src = subjects_dir / "sample" / "bem" / "sample-oct-6-src.fif" +fname_fwd = sample_dir / "sample_audvis-meg-oct-6-fwd.fif" +fname_fsaverage_src = subjects_dir / "fsaverage" / "bem" / "fsaverage-ico-5-src.fif" +fname_stc = sample_dir / "sample_audvis-meg" # %% # Load example data # Read stc from file -stc = mne.read_source_estimate(fname_stc, subject='sample') +stc = mne.read_source_estimate(fname_stc, subject="sample") # %% # Setting up SourceMorph for SourceEstimate @@ -66,7 +65,7 @@ src_orig = mne.read_source_spaces(fname_src) print(src_orig) # n_used=4098, 4098 fwd = mne.read_forward_solution(fname_fwd) -print(fwd['src']) # n_used=3732, 3766 +print(fwd["src"]) # n_used=3732, 3766 print([len(v) for v in stc.vertices]) # %% @@ -86,10 +85,14 @@ # Initialize SourceMorph for SourceEstimate src_to = mne.read_source_spaces(fname_fsaverage_src) -print(src_to[0]['vertno']) # special, np.arange(10242) -morph = mne.compute_source_morph(stc, subject_from='sample', - subject_to='fsaverage', src_to=src_to, - subjects_dir=subjects_dir) +print(src_to[0]["vertno"]) # special, np.arange(10242) +morph = mne.compute_source_morph( + stc, + subject_from="sample", + subject_to="fsaverage", + src_to=src_to, + subjects_dir=subjects_dir, +) # %% # Apply morph to (Vector) SourceEstimate @@ -106,25 +109,28 @@ # Define plotting parameters surfer_kwargs = dict( - hemi='lh', subjects_dir=subjects_dir, - clim=dict(kind='value', lims=[8, 12, 15]), views='lateral', - initial_time=0.09, time_unit='s', size=(800, 800), - smoothing_steps=5) + hemi="lh", + subjects_dir=subjects_dir, + clim=dict(kind="value", lims=[8, 12, 15]), + views="lateral", + initial_time=0.09, + time_unit="s", + size=(800, 800), + smoothing_steps=5, +) # As spherical surface -brain = stc_fsaverage.plot(surface='sphere', **surfer_kwargs) +brain = stc_fsaverage.plot(surface="sphere", **surfer_kwargs) # Add title 
-brain.add_text(0.1, 0.9, 'Morphed to fsaverage (spherical)', 'title', - font_size=16) +brain.add_text(0.1, 0.9, "Morphed to fsaverage (spherical)", "title", font_size=16) # %% # As inflated surface -brain_inf = stc_fsaverage.plot(surface='inflated', **surfer_kwargs) +brain_inf = stc_fsaverage.plot(surface="inflated", **surfer_kwargs) # Add title -brain_inf.add_text(0.1, 0.9, 'Morphed to fsaverage (inflated)', 'title', - font_size=16) +brain_inf.add_text(0.1, 0.9, "Morphed to fsaverage (inflated)", "title", font_size=16) # %% # Reading and writing SourceMorph from and to disk @@ -153,8 +159,7 @@ # easily chained into a handy one-liner. Taking this together the shortest # possible way to morph data directly would be: -stc_fsaverage = mne.compute_source_morph(stc, - subjects_dir=subjects_dir).apply(stc) +stc_fsaverage = mne.compute_source_morph(stc, subjects_dir=subjects_dir).apply(stc) # %% # For more examples, check out :ref:`examples using SourceMorph.apply diff --git a/examples/inverse/morph_volume_stc.py b/examples/inverse/morph_volume_stc.py index 1494b7b30c8..adf20db7905 100644 --- a/examples/inverse/morph_volume_stc.py +++ b/examples/inverse/morph_volume_stc.py @@ -38,16 +38,15 @@ # %% # Setup paths sample_dir_raw = sample.data_path() -sample_dir = os.path.join(sample_dir_raw, 'MEG', 'sample') -subjects_dir = os.path.join(sample_dir_raw, 'subjects') +sample_dir = os.path.join(sample_dir_raw, "MEG", "sample") +subjects_dir = os.path.join(sample_dir_raw, "subjects") -fname_evoked = os.path.join(sample_dir, 'sample_audvis-ave.fif') -fname_inv = os.path.join(sample_dir, 'sample_audvis-meg-vol-7-meg-inv.fif') +fname_evoked = os.path.join(sample_dir, "sample_audvis-ave.fif") +fname_inv = os.path.join(sample_dir, "sample_audvis-meg-vol-7-meg-inv.fif") -fname_t1_fsaverage = os.path.join(subjects_dir, 'fsaverage', 'mri', - 'brain.mgz') +fname_t1_fsaverage = os.path.join(subjects_dir, "fsaverage", "mri", "brain.mgz") fetch_fsaverage(subjects_dir) # ensure fsaverage src exists -fname_src_fsaverage = subjects_dir + '/fsaverage/bem/fsaverage-vol-5-src.fif' +fname_src_fsaverage = subjects_dir + "/fsaverage/bem/fsaverage-vol-5-src.fif" # %% # Compute example data. For reference see :ref:`ex-inverse-volume`. @@ -57,7 +56,7 @@ inverse_operator = read_inverse_operator(fname_inv) # Apply inverse operator -stc = apply_inverse(evoked, inverse_operator, 1.0 / 3.0 ** 2, "dSPM") +stc = apply_inverse(evoked, inverse_operator, 1.0 / 3.0**2, "dSPM") # To save time stc.crop(0.09, 0.09) @@ -84,9 +83,14 @@ src_fs = mne.read_source_spaces(fname_src_fsaverage) morph = mne.compute_source_morph( - inverse_operator['src'], subject_from='sample', subjects_dir=subjects_dir, - niter_affine=[10, 10, 5], niter_sdr=[10, 10, 5], # just for speed - src_to=src_fs, verbose=True) + inverse_operator["src"], + subject_from="sample", + subjects_dir=subjects_dir, + niter_affine=[10, 10, 5], + niter_sdr=[10, 10, 5], # just for speed + src_to=src_fs, + verbose=True, +) # %% # Apply morph to VolSourceEstimate @@ -119,7 +123,7 @@ # :meth:`morph.apply(..., output='nifti1') `. 
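
# %%
# Because the affine/SDR registration is the slow part (note the reduced
# ``niter_affine``/``niter_sdr`` above, used just for speed), it can be worth
# persisting the morph before applying it. A minimal sketch; the filename is
# only illustrative:

morph.save("sample_vol-morph.h5", overwrite=True)  # stored as HDF5
morph = mne.read_source_morph("sample_vol-morph.h5")  # load it back later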
# Create mri-resolution volume of results -img_fsaverage = morph.apply(stc, mri_resolution=2, output='nifti1') +img_fsaverage = morph.apply(stc, mri_resolution=2, output="nifti1") # %% # Plot results @@ -129,10 +133,9 @@ t1_fsaverage = nib.load(fname_t1_fsaverage) # Plot glass brain (change to plot_anat to display an overlaid anatomical T1) -display = plot_glass_brain(t1_fsaverage, - title='subject results to fsaverage', - draw_cross=False, - annotate=True) +display = plot_glass_brain( + t1_fsaverage, title="subject results to fsaverage", draw_cross=False, annotate=True +) # Add functional data as overlay display.add_overlay(img_fsaverage, alpha=0.75) diff --git a/examples/inverse/multi_dipole_model.py b/examples/inverse/multi_dipole_model.py index afed2d738df..40bbf60c919 100644 --- a/examples/inverse/multi_dipole_model.py +++ b/examples/inverse/multi_dipole_model.py @@ -33,17 +33,16 @@ import mne from mne.datasets import sample from mne.channels import read_vectorview_selection -from mne.minimum_norm import (make_inverse_operator, apply_inverse, - apply_inverse_epochs) +from mne.minimum_norm import make_inverse_operator, apply_inverse, apply_inverse_epochs import matplotlib.pyplot as plt import numpy as np data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_raw.fif' -cov_fname = meg_path / 'sample_audvis-shrunk-cov.fif' -bem_dir = data_path / 'subjects' / 'sample' / 'bem' -bem_fname = bem_dir / 'sample-5120-5120-5120-bem-sol.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_raw.fif" +cov_fname = meg_path / "sample_audvis-shrunk-cov.fif" +bem_dir = data_path / "subjects" / "sample" / "bem" +bem_fname = bem_dir / "sample-5120-5120-5120-bem-sol.fif" ############################################################################### # Read the MEG data from the audvis experiment. Make epochs and evokeds for the @@ -55,13 +54,19 @@ # Create epochs for auditory events events = mne.find_events(raw) event_id = dict(right=1, left=2) -epochs = mne.Epochs(raw, events, event_id, - tmin=-0.1, tmax=0.3, baseline=(None, 0), - reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin=-0.1, + tmax=0.3, + baseline=(None, 0), + reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6), +) # Create evokeds for left and right auditory stimulation -evoked_left = epochs['left'].average() -evoked_right = epochs['right'].average() +evoked_left = epochs["left"].average() +evoked_right = epochs["right"].average() ############################################################################### # Guided dipole modeling, meaning fitting dipoles to a manually selected subset @@ -76,12 +81,12 @@ # Fit two dipoles at t=80ms. The first dipole is fitted using only the sensors # on the left side of the helmet. The second dipole is fitted using only the # sensors on the right side of the helmet. 
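
# (``read_vectorview_selection`` returns a plain list of channel names for the
# requested helmet region, so the subset can be sanity-checked before fitting.
# A minimal, purely illustrative sketch; ``picks_demo`` is a throwaway name:)

picks_demo = read_vectorview_selection("Left", info=info)
print(len(picks_demo), picks_demo[:3])  # e.g. names like 'MEG 0111'

# Select the left- and right-helmet subsets and fit: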
-picks_left = read_vectorview_selection('Left', info=info) +picks_left = read_vectorview_selection("Left", info=info) evoked_fit_left = evoked_left.copy().crop(0.08, 0.08) evoked_fit_left.pick_channels(picks_left, ordered=False) cov_fit_left = cov.copy().pick_channels(picks_left, ordered=False) -picks_right = read_vectorview_selection('Right', info=info) +picks_right = read_vectorview_selection("Right", info=info) evoked_fit_right = evoked_right.copy().crop(0.08, 0.08) evoked_fit_right.pick_channels(picks_right, ordered=False) cov_fit_right = cov.copy().pick_channels(picks_right, ordered=False) @@ -90,8 +95,8 @@ # after picking channels. evoked_fit_left.info.normalize_proj() evoked_fit_right.info.normalize_proj() -cov_fit_left['projs'] = evoked_fit_left.info['projs'] -cov_fit_right['projs'] = evoked_fit_right.info['projs'] +cov_fit_left["projs"] = evoked_fit_left.info["projs"] +cov_fit_right["projs"] = evoked_fit_right.info["projs"] # Fit the dipoles with the subset of sensors. dip_left, _ = mne.fit_dipole(evoked_fit_left, cov_fit_left, bem) @@ -107,27 +112,25 @@ # Apply MNE inverse inv = make_inverse_operator(info, fwd, cov, fixed=True, depth=0) -stc_left = apply_inverse(evoked_left, inv, method='MNE', lambda2=1E-6) -stc_right = apply_inverse(evoked_right, inv, method='MNE', lambda2=1E-6) +stc_left = apply_inverse(evoked_left, inv, method="MNE", lambda2=1e-6) +stc_right = apply_inverse(evoked_right, inv, method="MNE", lambda2=1e-6) # Plot the timecourses of the resulting source estimate fig, axes = plt.subplots(nrows=2, sharex=True, sharey=True) axes[0].plot(stc_left.times, stc_left.data.T) -axes[0].set_title('Left auditory stimulation') -axes[0].legend(['Dipole 1', 'Dipole 2']) +axes[0].set_title("Left auditory stimulation") +axes[0].legend(["Dipole 1", "Dipole 2"]) axes[1].plot(stc_right.times, stc_right.data.T) -axes[1].set_title('Right auditory stimulation') -axes[1].set_xlabel('Time (s)') -fig.supylabel('Dipole amplitude') +axes[1].set_title("Right auditory stimulation") +axes[1].set_xlabel("Time (s)") +fig.supylabel("Dipole amplitude") ############################################################################### # We can also fit the timecourses to single epochs. Here, we do it for each # experimental condition separately. 
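
#
# ``apply_inverse_epochs`` used below returns one source estimate per epoch
# rather than a single averaged one. Note also what the tiny regularization
# means: since lambda2 = 1 / SNR**2, lambda2 = 1e-6 amounts to assuming an SNR
# of 1000, i.e. an essentially unregularized fit; a reasonable choice here,
# where the source space consists of just the two fitted dipoles. A quick,
# purely illustrative check of that arithmetic:

assumed_snr = (1.0 / 1e-6) ** 0.5
print(assumed_snr)  # 1000.0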
-stcs_left = apply_inverse_epochs(epochs['left'], inv, lambda2=1E-6, - method='MNE') -stcs_right = apply_inverse_epochs(epochs['right'], inv, lambda2=1E-6, - method='MNE') +stcs_left = apply_inverse_epochs(epochs["left"], inv, lambda2=1e-6, method="MNE") +stcs_right = apply_inverse_epochs(epochs["right"], inv, lambda2=1e-6, method="MNE") ############################################################################### # To summarize and visualize the single-epoch dipole amplitudes, we will create @@ -151,17 +154,17 @@ mean_right = np.mean(amplitudes_right, axis=0) fig, ax = plt.subplots(figsize=(8, 4)) -ax.scatter(np.arange(n), amplitudes[:, 0], label='Dipole 1') -ax.scatter(np.arange(n), amplitudes[:, 1], label='Dipole 2') +ax.scatter(np.arange(n), amplitudes[:, 0], label="Dipole 1") +ax.scatter(np.arange(n), amplitudes[:, 1], label="Dipole 2") transition_point = n_left - 0.5 -ax.plot([0, transition_point], [mean_left[0], mean_left[0]], color='C0') -ax.plot([0, transition_point], [mean_left[1], mean_left[1]], color='C1') -ax.plot([transition_point, n], [mean_right[0], mean_right[0]], color='C0') -ax.plot([transition_point, n], [mean_right[1], mean_right[1]], color='C1') -ax.axvline(transition_point, color='black') -ax.set_xlabel('Epochs') -ax.set_ylabel('Dipole amplitude') +ax.plot([0, transition_point], [mean_left[0], mean_left[0]], color="C0") +ax.plot([0, transition_point], [mean_left[1], mean_left[1]], color="C1") +ax.plot([transition_point, n], [mean_right[0], mean_right[0]], color="C0") +ax.plot([transition_point, n], [mean_right[1], mean_right[1]], color="C1") +ax.axvline(transition_point, color="black") +ax.set_xlabel("Epochs") +ax.set_ylabel("Dipole amplitude") ax.legend() -fig.suptitle('Single epoch dipole amplitudes') -fig.text(0.30, 0.9, 'Left auditory stimulation', ha='center') -fig.text(0.70, 0.9, 'Right auditory stimulation', ha='center') +fig.suptitle("Single epoch dipole amplitudes") +fig.text(0.30, 0.9, "Left auditory stimulation", ha="center") +fig.text(0.70, 0.9, "Right auditory stimulation", ha="center") diff --git a/examples/inverse/multidict_reweighted_tfmxne.py b/examples/inverse/multidict_reweighted_tfmxne.py index 58aa0fefb09..f12903d1754 100644 --- a/examples/inverse/multidict_reweighted_tfmxne.py +++ b/examples/inverse/multidict_reweighted_tfmxne.py @@ -38,28 +38,29 @@ # Load somatosensory MEG data data_path = somato.data_path() -subject = '01' -task = 'somato' -raw_fname = (data_path / f'sub-{subject}' / 'meg' / - f'sub-{subject}_task-{task}_meg.fif') -fwd_fname = (data_path / 'derivatives' / f'sub-{subject}' / - f'sub-{subject}_task-{task}-fwd.fif') +subject = "01" +task = "somato" +raw_fname = data_path / f"sub-{subject}" / "meg" / f"sub-{subject}_task-{task}_meg.fif" +fwd_fname = ( + data_path / "derivatives" / f"sub-{subject}" / f"sub-{subject}_task-{task}-fwd.fif" +) # Read evoked raw = mne.io.read_raw_fif(raw_fname) raw.pick_types(meg=True, eog=True, stim=True) -events = mne.find_events(raw, stim_channel='STI 014') +events = mne.find_events(raw, stim_channel="STI 014") reject = dict(grad=4000e-13, eog=350e-6) event_id, tmin, tmax = dict(unknown=1), -0.5, 0.5 -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, reject=reject, - baseline=(None, 0)) +epochs = mne.Epochs( + raw, events, event_id, tmin, tmax, reject=reject, baseline=(None, 0) +) evoked = epochs.average() evoked.crop(tmin=0.0, tmax=0.2) # Compute noise covariance matrix -cov = mne.compute_covariance(epochs, rank='info', tmax=0.) 
+cov = mne.compute_covariance(epochs, rank="info", tmax=0.0) del epochs, raw # Handling forward solution @@ -69,7 +70,7 @@ # Run iterative reweighted multidict TF-MxNE solver alpha, l1_ratio = 20, 0.05 -loose, depth = 0.9, 1. +loose, depth = 0.9, 1.0 # Use a multiscale time-frequency dictionary wsize, tstep = [4, 16], [2, 4] @@ -77,27 +78,42 @@ n_tfmxne_iter = 10 # Compute TF-MxNE inverse solution with dipole output dipoles, residual = tf_mixed_norm( - evoked, forward, cov, alpha=alpha, l1_ratio=l1_ratio, - n_tfmxne_iter=n_tfmxne_iter, loose=loose, - depth=depth, tol=1e-3, - wsize=wsize, tstep=tstep, return_as_dipoles=True, - return_residual=True) + evoked, + forward, + cov, + alpha=alpha, + l1_ratio=l1_ratio, + n_tfmxne_iter=n_tfmxne_iter, + loose=loose, + depth=depth, + tol=1e-3, + wsize=wsize, + tstep=tstep, + return_as_dipoles=True, + return_residual=True, +) # %% # Generate stc from dipoles -stc = make_stc_from_dipoles(dipoles, forward['src']) +stc = make_stc_from_dipoles(dipoles, forward["src"]) plot_sparse_source_estimates( - forward['src'], stc, bgcolor=(1, 1, 1), opacity=0.1, - fig_name=f"irTF-MxNE (cond {evoked.comment})") + forward["src"], + stc, + bgcolor=(1, 1, 1), + opacity=0.1, + fig_name=f"irTF-MxNE (cond {evoked.comment})", +) # %% # Show the evoked response and the residual for gradiometers ylim = dict(grad=[-300, 300]) -evoked.copy().pick_types(meg='grad').plot( - titles=dict(grad='Evoked Response: Gradiometers'), ylim=ylim) -residual.copy().pick_types(meg='grad').plot( - titles=dict(grad='Residuals: Gradiometers'), ylim=ylim) +evoked.copy().pick_types(meg="grad").plot( + titles=dict(grad="Evoked Response: Gradiometers"), ylim=ylim +) +residual.copy().pick_types(meg="grad").plot( + titles=dict(grad="Residuals: Gradiometers"), ylim=ylim +) # %% # References diff --git a/examples/inverse/psf_ctf_label_leakage.py b/examples/inverse/psf_ctf_label_leakage.py index 5975584c391..d74663d369a 100644 --- a/examples/inverse/psf_ctf_label_leakage.py +++ b/examples/inverse/psf_ctf_label_leakage.py @@ -25,9 +25,11 @@ import mne from mne.datasets import sample -from mne.minimum_norm import (read_inverse_operator, - make_inverse_resolution_matrix, - get_point_spread) +from mne.minimum_norm import ( + read_inverse_operator, + make_inverse_resolution_matrix, + get_point_spread, +) from mne.viz import circular_layout from mne_connectivity.viz import plot_connectivity_circle @@ -43,20 +45,20 @@ # resolution matrices for different methods. data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname_fwd = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-fixed-inv.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname_fwd = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-fixed-inv.fif" forward = mne.read_forward_solution(fname_fwd) # Convert forward solution to fixed source orientations -mne.convert_forward_solution( - forward, surf_ori=True, force_fixed=True, copy=False) +mne.convert_forward_solution(forward, surf_ori=True, force_fixed=True, copy=False) inverse_operator = read_inverse_operator(fname_inv) # Compute resolution matrices for MNE -rm_mne = make_inverse_resolution_matrix(forward, inverse_operator, - method='MNE', lambda2=1. 
/ 3.**2) -src = inverse_operator['src'] +rm_mne = make_inverse_resolution_matrix( + forward, inverse_operator, method="MNE", lambda2=1.0 / 3.0**2 +) +src = inverse_operator["src"] del forward, inverse_operator # save memory # %% @@ -64,13 +66,12 @@ # -------------------------------------------------- # # Get labels for FreeSurfer 'aparc' cortical parcellation with 34 labels/hemi -labels = mne.read_labels_from_annot('sample', parc='aparc', - subjects_dir=subjects_dir) +labels = mne.read_labels_from_annot("sample", parc="aparc", subjects_dir=subjects_dir) n_labels = len(labels) label_colors = [label.color for label in labels] # First, we reorder the labels based on their location in the left hemi label_names = [label.name for label in labels] -lh_labels = [name for name in label_names if name.endswith('lh')] +lh_labels = [name for name in label_names if name.endswith("lh")] # Get the y-location of the label label_ypos = list() @@ -83,7 +84,7 @@ lh_labels = [label for (yp, label) in sorted(zip(label_ypos, lh_labels))] # For the right hemi -rh_labels = [label[:-2] + 'rh' for label in lh_labels] +rh_labels = [label[:-2] + "rh" for label in lh_labels] # %% # Compute point-spread function summaries (PCA) for all labels @@ -97,8 +98,8 @@ # spatial extents of labels. n_comp = 5 stcs_psf_mne, pca_vars_mne = get_point_spread( - rm_mne, src, labels, mode='pca', n_comp=n_comp, norm=None, - return_pca_vars=True) + rm_mne, src, labels, mode="pca", n_comp=n_comp, norm=None, return_pca_vars=True +) n_verts = rm_mne.shape[0] del rm_mne @@ -109,7 +110,7 @@ with np.printoptions(precision=1): for [name, var] in zip(label_names, pca_vars_mne): - print(f'{name}: {var.sum():.1f}% {var}') + print(f"{name}: {var.sum():.1f}% {var}") # %% # The output shows the summed variance explained by the first five principal @@ -132,15 +133,23 @@ # Save the plot order and create a circular layout node_order = lh_labels[::-1] + rh_labels # mirror label order across hemis -node_angles = circular_layout(label_names, node_order, start_pos=90, - group_boundaries=[0, len(label_names) / 2]) +node_angles = circular_layout( + label_names, node_order, start_pos=90, group_boundaries=[0, len(label_names) / 2] +) # Plot the graph using node colors from the FreeSurfer parcellation. We only # show the 200 strongest connections. fig, ax = plt.subplots( - figsize=(8, 8), facecolor='black', subplot_kw=dict(projection='polar')) -plot_connectivity_circle(leakage_mne, label_names, n_lines=200, - node_angles=node_angles, node_colors=label_colors, - title='MNE Leakage', ax=ax) + figsize=(8, 8), facecolor="black", subplot_kw=dict(projection="polar") +) +plot_connectivity_circle( + leakage_mne, + label_names, + n_lines=200, + node_angles=node_angles, + node_colors=label_colors, + title="MNE Leakage", + ax=ax, +) # %% # Most leakage occurs for neighbouring regions, but also for deeper regions @@ -175,20 +184,26 @@ # %% # Point-spread function for the lateral occipital label in the left hemisphere -brain_lh = stc_lh.plot(subjects_dir=subjects_dir, subject='sample', - hemi='both', views='caudal', - clim=dict(kind='value', - pos_lims=(0, max_val / 2., max_val))) -brain_lh.add_text(0.1, 0.9, label_names[idx[0]], 'title', font_size=16) +brain_lh = stc_lh.plot( + subjects_dir=subjects_dir, + subject="sample", + hemi="both", + views="caudal", + clim=dict(kind="value", pos_lims=(0, max_val / 2.0, max_val)), +) +brain_lh.add_text(0.1, 0.9, label_names[idx[0]], "title", font_size=16) # %% # and in the right hemisphere. 
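
# %%
# (Before plotting, a one-line, purely illustrative check that ``idx[1]``
# really is the right-hemisphere twin of the label shown above, using the same
# naming convention as ``rh_labels`` earlier:)

assert label_names[idx[1]] == label_names[idx[0]][:-2] + "rh"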
-brain_rh = stc_rh.plot(subjects_dir=subjects_dir, subject='sample', - hemi='both', views='caudal', - clim=dict(kind='value', - pos_lims=(0, max_val / 2., max_val))) -brain_rh.add_text(0.1, 0.9, label_names[idx[1]], 'title', font_size=16) +brain_rh = stc_rh.plot( + subjects_dir=subjects_dir, + subject="sample", + hemi="both", + views="caudal", + clim=dict(kind="value", pos_lims=(0, max_val / 2.0, max_val)), +) +brain_rh.add_text(0.1, 0.9, label_names[idx[1]], "title", font_size=16) # %% # Both summary PSFs are confined to their respective hemispheres, indicating diff --git a/examples/inverse/psf_ctf_vertices.py b/examples/inverse/psf_ctf_vertices.py index a365991ffa1..0ec01a865dc 100644 --- a/examples/inverse/psf_ctf_vertices.py +++ b/examples/inverse/psf_ctf_vertices.py @@ -16,23 +16,25 @@ import mne from mne.datasets import sample -from mne.minimum_norm import (make_inverse_resolution_matrix, get_cross_talk, - get_point_spread) +from mne.minimum_norm import ( + make_inverse_resolution_matrix, + get_cross_talk, + get_point_spread, +) print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname_fwd = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fname_cov = meg_path / 'sample_audvis-cov.fif' -fname_evo = meg_path / 'sample_audvis-ave.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname_fwd = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +fname_cov = meg_path / "sample_audvis-cov.fif" +fname_evo = meg_path / "sample_audvis-ave.fif" # read forward solution forward = mne.read_forward_solution(fname_fwd) # forward operator with fixed source orientations -mne.convert_forward_solution(forward, surf_ori=True, - force_fixed=True, copy=False) +mne.convert_forward_solution(forward, surf_ori=True, force_fixed=True, copy=False) # noise covariance matrix noise_cov = mne.read_cov(fname_cov) @@ -43,23 +45,24 @@ # make inverse operator from forward solution # free source orientation inverse_operator = mne.minimum_norm.make_inverse_operator( - info=evoked.info, forward=forward, noise_cov=noise_cov, loose=0., - depth=None) + info=evoked.info, forward=forward, noise_cov=noise_cov, loose=0.0, depth=None +) # regularisation parameter snr = 3.0 -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 # compute resolution matrix for sLORETA -rm_lor = make_inverse_resolution_matrix(forward, inverse_operator, - method='sLORETA', lambda2=lambda2) +rm_lor = make_inverse_resolution_matrix( + forward, inverse_operator, method="sLORETA", lambda2=lambda2 +) # get PSF and CTF for sLORETA at one vertex sources = [1000] -stc_psf = get_point_spread(rm_lor, forward['src'], sources, norm=True) +stc_psf = get_point_spread(rm_lor, forward["src"], sources, norm=True) -stc_ctf = get_cross_talk(rm_lor, forward['src'], sources, norm=True) +stc_ctf = get_cross_talk(rm_lor, forward["src"], sources, norm=True) del rm_lor ############################################################################## @@ -68,37 +71,41 @@ # PSF: # Which vertex corresponds to selected source -vertno_lh = forward['src'][0]['vertno'] +vertno_lh = forward["src"][0]["vertno"] verttrue = [vertno_lh[sources[0]]] # just one vertex # find vertices with maxima in PSF and CTF vert_max_psf = vertno_lh[stc_psf.data.argmax()] vert_max_ctf = vertno_lh[stc_ctf.data.argmax()] -brain_psf = stc_psf.plot('sample', 'inflated', 'lh', subjects_dir=subjects_dir) -brain_psf.show_view('ventral') -brain_psf.add_text(0.1, 0.9, 'sLORETA PSF', 'title', font_size=16) 
+brain_psf = stc_psf.plot("sample", "inflated", "lh", subjects_dir=subjects_dir) +brain_psf.show_view("ventral") +brain_psf.add_text(0.1, 0.9, "sLORETA PSF", "title", font_size=16) # True source location for PSF -brain_psf.add_foci(verttrue, coords_as_verts=True, scale_factor=1., hemi='lh', - color='green') +brain_psf.add_foci( + verttrue, coords_as_verts=True, scale_factor=1.0, hemi="lh", color="green" +) # Maximum of PSF -brain_psf.add_foci(vert_max_psf, coords_as_verts=True, scale_factor=1., - hemi='lh', color='black') +brain_psf.add_foci( + vert_max_psf, coords_as_verts=True, scale_factor=1.0, hemi="lh", color="black" +) # %% # CTF: -brain_ctf = stc_ctf.plot('sample', 'inflated', 'lh', subjects_dir=subjects_dir) -brain_ctf.add_text(0.1, 0.9, 'sLORETA CTF', 'title', font_size=16) -brain_ctf.show_view('ventral') -brain_ctf.add_foci(verttrue, coords_as_verts=True, scale_factor=1., hemi='lh', - color='green') +brain_ctf = stc_ctf.plot("sample", "inflated", "lh", subjects_dir=subjects_dir) +brain_ctf.add_text(0.1, 0.9, "sLORETA CTF", "title", font_size=16) +brain_ctf.show_view("ventral") +brain_ctf.add_foci( + verttrue, coords_as_verts=True, scale_factor=1.0, hemi="lh", color="green" +) # Maximum of CTF -brain_ctf.add_foci(vert_max_ctf, coords_as_verts=True, scale_factor=1., - hemi='lh', color='black') +brain_ctf.add_foci( + vert_max_ctf, coords_as_verts=True, scale_factor=1.0, hemi="lh", color="black" +) # %% diff --git a/examples/inverse/psf_ctf_vertices_lcmv.py b/examples/inverse/psf_ctf_vertices_lcmv.py index de774c2149e..7f3d2a4207e 100644 --- a/examples/inverse/psf_ctf_vertices_lcmv.py +++ b/examples/inverse/psf_ctf_vertices_lcmv.py @@ -23,75 +23,91 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname_fwd = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fname_cov = meg_path / 'sample_audvis-cov.fif' -fname_evo = meg_path / 'sample_audvis-ave.fif' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname_fwd = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +fname_cov = meg_path / "sample_audvis-cov.fif" +fname_evo = meg_path / "sample_audvis-ave.fif" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" # Read raw data raw = mne.io.read_raw_fif(raw_fname) # only pick good EEG/MEG sensors -raw.info['bads'] += ['EEG 053'] # bads + 1 more -picks = mne.pick_types(raw.info, meg=True, eeg=True, exclude='bads') +raw.info["bads"] += ["EEG 053"] # bads + 1 more +picks = mne.pick_types(raw.info, meg=True, eeg=True, exclude="bads") # Find events events = mne.find_events(raw) # event_id = {'aud/l': 1, 'aud/r': 2, 'vis/l': 3, 'vis/r': 4} -event_id = {'vis/l': 3, 'vis/r': 4} - -tmin, tmax = -.2, .25 # epoch duration -epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax, - picks=picks, baseline=(-.2, 0.), preload=True) +event_id = {"vis/l": 3, "vis/r": 4} + +tmin, tmax = -0.2, 0.25 # epoch duration +epochs = mne.Epochs( + raw, + events, + event_id=event_id, + tmin=tmin, + tmax=tmax, + picks=picks, + baseline=(-0.2, 0.0), + preload=True, +) del raw # covariance matrix for pre-stimulus interval -tmin, tmax = -.2, 0. 
-cov_pre = mne.compute_covariance(epochs, tmin=tmin, tmax=tmax, - method='empirical') +tmin, tmax = -0.2, 0.0 +cov_pre = mne.compute_covariance(epochs, tmin=tmin, tmax=tmax, method="empirical") # covariance matrix for post-stimulus interval (around main evoked responses) -tmin, tmax = 0.05, .25 -cov_post = mne.compute_covariance(epochs, tmin=tmin, tmax=tmax, - method='empirical') +tmin, tmax = 0.05, 0.25 +cov_post = mne.compute_covariance(epochs, tmin=tmin, tmax=tmax, method="empirical") info = epochs.info del epochs # read forward solution forward = mne.read_forward_solution(fname_fwd) # use forward operator with fixed source orientations -mne.convert_forward_solution(forward, surf_ori=True, - force_fixed=True, copy=False) +mne.convert_forward_solution(forward, surf_ori=True, force_fixed=True, copy=False) # read noise covariance matrix noise_cov = mne.read_cov(fname_cov) # regularize noise covariance (we used 'empirical' above) -noise_cov = mne.cov.regularize(noise_cov, info, mag=0.1, grad=0.1, - eeg=0.1, rank='info') +noise_cov = mne.cov.regularize(noise_cov, info, mag=0.1, grad=0.1, eeg=0.1, rank="info") ############################################################################## # Compute LCMV filters with different data covariance matrices # ------------------------------------------------------------ # compute LCMV beamformer filters for pre-stimulus interval -filters_pre = make_lcmv(info, forward, cov_pre, reg=0.05, - noise_cov=noise_cov, - pick_ori=None, rank=None, - weight_norm=None, - reduce_rank=False, - verbose=False) +filters_pre = make_lcmv( + info, + forward, + cov_pre, + reg=0.05, + noise_cov=noise_cov, + pick_ori=None, + rank=None, + weight_norm=None, + reduce_rank=False, + verbose=False, +) # compute LCMV beamformer filters for post-stimulus interval -filters_post = make_lcmv(info, forward, cov_post, reg=0.05, - noise_cov=noise_cov, - pick_ori=None, rank=None, - weight_norm=None, - reduce_rank=False, - verbose=False) +filters_post = make_lcmv( + info, + forward, + cov_post, + reg=0.05, + noise_cov=noise_cov, + pick_ori=None, + rank=None, + weight_norm=None, + reduce_rank=False, + verbose=False, +) ############################################################################## # Compute resolution matrices for the two LCMV beamformers @@ -99,14 +115,14 @@ # compute cross-talk functions (CTFs) for one target vertex sources = [3000] -verttrue = [forward['src'][0]['vertno'][sources[0]]] # pick one vertex +verttrue = [forward["src"][0]["vertno"][sources[0]]] # pick one vertex rm_pre = make_lcmv_resolution_matrix(filters_pre, forward, info) -stc_pre = get_cross_talk(rm_pre, forward['src'], sources, norm=True) +stc_pre = get_cross_talk(rm_pre, forward["src"], sources, norm=True) del rm_pre ############################################################################## rm_post = make_lcmv_resolution_matrix(filters_post, forward, info) -stc_post = get_cross_talk(rm_post, forward['src'], sources, norm=True) +stc_post = get_cross_talk(rm_post, forward["src"], sources, norm=True) del rm_post ############################################################################## @@ -114,28 +130,51 @@ # --------- # Pre: -brain_pre = stc_pre.plot('sample', 'inflated', 'lh', subjects_dir=subjects_dir, - figure=1, clim=dict(kind='value', lims=(0, .2, .4))) - -brain_pre.add_text(0.1, 0.9, 'LCMV beamformer with pre-stimulus\ndata ' - 'covariance matrix', 'title', font_size=16) +brain_pre = stc_pre.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=1, + 
clim=dict(kind="value", lims=(0, 0.2, 0.4)), +) + +brain_pre.add_text( + 0.1, + 0.9, + "LCMV beamformer with pre-stimulus\ndata " "covariance matrix", + "title", + font_size=16, +) # mark true source location for CTFs -brain_pre.add_foci(verttrue, coords_as_verts=True, scale_factor=1., hemi='lh', - color='green') +brain_pre.add_foci( + verttrue, coords_as_verts=True, scale_factor=1.0, hemi="lh", color="green" +) # %% # Post: -brain_post = stc_post.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, - figure=2, clim=dict(kind='value', lims=(0, .2, .4))) - -brain_post.add_text(0.1, 0.9, 'LCMV beamformer with post-stimulus\ndata ' - 'covariance matrix', 'title', font_size=16) - -brain_post.add_foci(verttrue, coords_as_verts=True, scale_factor=1., - hemi='lh', color='green') +brain_post = stc_post.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=2, + clim=dict(kind="value", lims=(0, 0.2, 0.4)), +) + +brain_post.add_text( + 0.1, + 0.9, + "LCMV beamformer with post-stimulus\ndata " "covariance matrix", + "title", + font_size=16, +) + +brain_post.add_foci( + verttrue, coords_as_verts=True, scale_factor=1.0, hemi="lh", color="green" +) # %% # The pre-stimulus beamformer's CTF has lower values in parietal regions diff --git a/examples/inverse/psf_volume.py b/examples/inverse/psf_volume.py index 7cfd0675cd8..f2e465c1b20 100644 --- a/examples/inverse/psf_volume.py +++ b/examples/inverse/psf_volume.py @@ -24,13 +24,12 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname_cov = meg_path / 'sample_audvis-cov.fif' -fname_evo = meg_path / 'sample_audvis-ave.fif' -fname_trans = meg_path / 'sample_audvis_raw-trans.fif' -fname_bem = ( - subjects_dir / 'sample' / 'bem' / 'sample-5120-bem-sol.fif') +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname_cov = meg_path / "sample_audvis-cov.fif" +fname_evo = meg_path / "sample_audvis-ave.fif" +fname_trans = meg_path / "sample_audvis_raw-trans.fif" +fname_bem = subjects_dir / "sample" / "bem" / "sample-5120-bem-sol.fif" # %% # For the volume, create a coarse source space for speed (don't do this in @@ -42,27 +41,29 @@ # create a coarse source space src_vol = mne.setup_volume_source_space( # this is a very course resolution! - 'sample', pos=15., subjects_dir=subjects_dir, - add_interpolator=False) # usually you want True, this is just for speed + "sample", pos=15.0, subjects_dir=subjects_dir, add_interpolator=False +) # usually you want True, this is just for speed # compute the forward forward_vol = mne.make_forward_solution( # MEG-only for speed - evoked.info, fname_trans, src_vol, fname_bem, eeg=False) + evoked.info, fname_trans, src_vol, fname_bem, eeg=False +) del src_vol # %% # Now make an inverse operator and compute the PSF at a source. inverse_operator_vol = mne.minimum_norm.make_inverse_operator( - info=evoked.info, forward=forward_vol, noise_cov=noise_cov) + info=evoked.info, forward=forward_vol, noise_cov=noise_cov +) # compute resolution matrix for sLORETA rm_lor_vol = make_inverse_resolution_matrix( - forward_vol, inverse_operator_vol, method='sLORETA', lambda2=1. / 9.) 
+ forward_vol, inverse_operator_vol, method="sLORETA", lambda2=1.0 / 9.0 +) # get PSF and CTF for sLORETA at one vertex sources_vol = [100] -stc_psf_vol = get_point_spread( - rm_lor_vol, forward_vol['src'], sources_vol, norm=True) +stc_psf_vol = get_point_spread(rm_lor_vol, forward_vol["src"], sources_vol, norm=True) del rm_lor_vol ############################################################################## @@ -71,23 +72,30 @@ # PSF: # Which vertex corresponds to selected source -src_vol = forward_vol['src'] -verttrue_vol = src_vol[0]['vertno'][sources_vol] +src_vol = forward_vol["src"] +verttrue_vol = src_vol[0]["vertno"][sources_vol] # find vertex with maximum in PSF -max_vert_idx, _ = np.unravel_index( - stc_psf_vol.data.argmax(), stc_psf_vol.data.shape) -vert_max_ctf_vol = src_vol[0]['vertno'][[max_vert_idx]] +max_vert_idx, _ = np.unravel_index(stc_psf_vol.data.argmax(), stc_psf_vol.data.shape) +vert_max_ctf_vol = src_vol[0]["vertno"][[max_vert_idx]] # plot them brain_psf_vol = stc_psf_vol.plot_3d( - 'sample', src=forward_vol['src'], views='ven', subjects_dir=subjects_dir, - volume_options=dict(alpha=0.5)) -brain_psf_vol.add_text( - 0.1, 0.9, 'Volumetric sLORETA PSF', 'title', font_size=16) + "sample", + src=forward_vol["src"], + views="ven", + subjects_dir=subjects_dir, + volume_options=dict(alpha=0.5), +) +brain_psf_vol.add_text(0.1, 0.9, "Volumetric sLORETA PSF", "title", font_size=16) brain_psf_vol.add_foci( - verttrue_vol, coords_as_verts=True, - scale_factor=1, hemi='vol', color='green') + verttrue_vol, coords_as_verts=True, scale_factor=1, hemi="vol", color="green" +) brain_psf_vol.add_foci( - vert_max_ctf_vol, coords_as_verts=True, - scale_factor=1.25, hemi='vol', color='black', alpha=0.3) + vert_max_ctf_vol, + coords_as_verts=True, + scale_factor=1.25, + hemi="vol", + color="black", + alpha=0.3, +) diff --git a/examples/inverse/rap_music.py b/examples/inverse/rap_music.py index 937351b96dd..787c6d3b8c7 100644 --- a/examples/inverse/rap_music.py +++ b/examples/inverse/rap_music.py @@ -24,16 +24,15 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -evoked_fname = meg_path / 'sample_audvis-ave.fif' -cov_fname = meg_path / 'sample_audvis-cov.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +evoked_fname = meg_path / "sample_audvis-ave.fif" +cov_fname = meg_path / "sample_audvis-cov.fif" # Read the evoked response and crop it -condition = 'Right Auditory' -evoked = mne.read_evokeds(evoked_fname, condition=condition, - baseline=(None, 0)) +condition = "Right Auditory" +evoked = mne.read_evokeds(evoked_fname, condition=condition, baseline=(None, 0)) # select N100 evoked.crop(tmin=0.05, tmax=0.15) @@ -45,17 +44,16 @@ # Read noise covariance matrix noise_cov = mne.read_cov(cov_fname) -dipoles, residual = rap_music(evoked, forward, noise_cov, n_dipoles=2, - return_residual=True, verbose=True) -trans = forward['mri_head_t'] -plot_dipole_locations(dipoles, trans, 'sample', subjects_dir=subjects_dir) +dipoles, residual = rap_music( + evoked, forward, noise_cov, n_dipoles=2, return_residual=True, verbose=True +) +trans = forward["mri_head_t"] +plot_dipole_locations(dipoles, trans, "sample", subjects_dir=subjects_dir) plot_dipole_amplitudes(dipoles) # Plot the evoked data and the residual. 
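
# %%
# (Before the plots, a quick and purely illustrative goodness-of-fit number:
# comparing residual to evoked variance on a single homogeneous channel type
# shows how much of the gradiometer signal the two dipoles explain.)

evo_grad = evoked.copy().pick_types(meg="grad")
res_grad = residual.copy().pick_types(meg="grad")
explained = 1.0 - res_grad.data.var() / evo_grad.data.var()
print(f"Gradiometer variance explained: {100 * explained:.1f}%")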
-evoked.plot(ylim=dict(grad=[-300, 300], mag=[-800, 800], eeg=[-6, 8]), - time_unit='s') -residual.plot(ylim=dict(grad=[-300, 300], mag=[-800, 800], eeg=[-6, 8]), - time_unit='s') +evoked.plot(ylim=dict(grad=[-300, 300], mag=[-800, 800], eeg=[-6, 8]), time_unit="s") +residual.plot(ylim=dict(grad=[-300, 300], mag=[-800, 800], eeg=[-6, 8]), time_unit="s") # %% # References diff --git a/examples/inverse/read_inverse.py b/examples/inverse/read_inverse.py index fd604b08f35..a0fe1774252 100644 --- a/examples/inverse/read_inverse.py +++ b/examples/inverse/read_inverse.py @@ -21,30 +21,35 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname_trans = meg_path / 'sample_audvis_raw-trans.fif' -inv_fname = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname_trans = meg_path / "sample_audvis_raw-trans.fif" +inv_fname = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" inv = read_inverse_operator(inv_fname) -print("Method: %s" % inv['methods']) -print("fMRI prior: %s" % inv['fmri_prior']) -print("Number of sources: %s" % inv['nsource']) -print("Number of channels: %s" % inv['nchan']) +print("Method: %s" % inv["methods"]) +print("fMRI prior: %s" % inv["fmri_prior"]) +print("Number of sources: %s" % inv["nsource"]) +print("Number of channels: %s" % inv["nchan"]) -src = inv['src'] # get the source space +src = inv["src"] # get the source space # Get access to the triangulation of the cortex -print("Number of vertices on the left hemisphere: %d" % len(src[0]['rr'])) -print("Number of triangles on left hemisphere: %d" % len(src[0]['use_tris'])) -print("Number of vertices on the right hemisphere: %d" % len(src[1]['rr'])) -print("Number of triangles on right hemisphere: %d" % len(src[1]['use_tris'])) +print("Number of vertices on the left hemisphere: %d" % len(src[0]["rr"])) +print("Number of triangles on left hemisphere: %d" % len(src[0]["use_tris"])) +print("Number of vertices on the right hemisphere: %d" % len(src[1]["rr"])) +print("Number of triangles on right hemisphere: %d" % len(src[1]["use_tris"])) # %% # Show the 3D source space -fig = mne.viz.plot_alignment(subject='sample', subjects_dir=subjects_dir, - trans=fname_trans, surfaces='white', src=src) -set_3d_view(fig, focalpoint=(0., 0., 0.06)) +fig = mne.viz.plot_alignment( + subject="sample", + subjects_dir=subjects_dir, + trans=fname_trans, + surfaces="white", + src=src, +) +set_3d_view(fig, focalpoint=(0.0, 0.0, 0.06)) diff --git a/examples/inverse/read_stc.py b/examples/inverse/read_stc.py index 3ae91bfc799..d98ba170400 100644 --- a/examples/inverse/read_stc.py +++ b/examples/inverse/read_stc.py @@ -22,17 +22,18 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname = meg_path / 'sample_audvis-meg' +meg_path = data_path / "MEG" / "sample" +fname = meg_path / "sample_audvis-meg" stc = mne.read_source_estimate(fname) n_vertices, n_samples = stc.data.shape -print("stc data size: %s (nb of vertices) x %s (nb of samples)" - % (n_vertices, n_samples)) +print( + "stc data size: %s (nb of vertices) x %s (nb of samples)" % (n_vertices, n_samples) +) # View source activations plt.plot(stc.times, stc.data[::100, :].T) -plt.xlabel('time (ms)') -plt.ylabel('Source amplitude') +plt.xlabel("time (ms)") +plt.ylabel("Source amplitude") plt.show() diff --git a/examples/inverse/resolution_metrics.py b/examples/inverse/resolution_metrics.py index 
10d3e03944c..e3e98827bea 100644 --- a/examples/inverse/resolution_metrics.py +++ b/examples/inverse/resolution_metrics.py @@ -25,17 +25,16 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname_fwd = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fname_cov = meg_path / 'sample_audvis-cov.fif' -fname_evo = meg_path / 'sample_audvis-ave.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname_fwd = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +fname_cov = meg_path / "sample_audvis-cov.fif" +fname_evo = meg_path / "sample_audvis-ave.fif" # read forward solution forward = mne.read_forward_solution(fname_fwd) # forward operator with fixed source orientations -mne.convert_forward_solution(forward, surf_ori=True, - force_fixed=True, copy=False) +mne.convert_forward_solution(forward, surf_ori=True, force_fixed=True, copy=False) # noise covariance matrix noise_cov = mne.read_cov(fname_cov) @@ -46,12 +45,12 @@ # make inverse operator from forward solution # free source orientation inverse_operator = mne.minimum_norm.make_inverse_operator( - info=evoked.info, forward=forward, noise_cov=noise_cov, loose=0., - depth=None) + info=evoked.info, forward=forward, noise_cov=noise_cov, loose=0.0, depth=None +) # regularisation parameter snr = 3.0 -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 # %% # MNE @@ -59,12 +58,15 @@ # Compute resolution matrices, peak localisation error (PLE) for point spread # functions (PSFs), spatial deviation (SD) for PSFs: -rm_mne = make_inverse_resolution_matrix(forward, inverse_operator, - method='MNE', lambda2=lambda2) -ple_mne_psf = resolution_metrics(rm_mne, inverse_operator['src'], - function='psf', metric='peak_err') -sd_mne_psf = resolution_metrics(rm_mne, inverse_operator['src'], - function='psf', metric='sd_ext') +rm_mne = make_inverse_resolution_matrix( + forward, inverse_operator, method="MNE", lambda2=lambda2 +) +ple_mne_psf = resolution_metrics( + rm_mne, inverse_operator["src"], function="psf", metric="peak_err" +) +sd_mne_psf = resolution_metrics( + rm_mne, inverse_operator["src"], function="psf", metric="sd_ext" +) del rm_mne # %% @@ -72,39 +74,57 @@ # ---- # Do the same for dSPM: -rm_dspm = make_inverse_resolution_matrix(forward, inverse_operator, - method='dSPM', lambda2=lambda2) -ple_dspm_psf = resolution_metrics(rm_dspm, inverse_operator['src'], - function='psf', metric='peak_err') -sd_dspm_psf = resolution_metrics(rm_dspm, inverse_operator['src'], - function='psf', metric='sd_ext') +rm_dspm = make_inverse_resolution_matrix( + forward, inverse_operator, method="dSPM", lambda2=lambda2 +) +ple_dspm_psf = resolution_metrics( + rm_dspm, inverse_operator["src"], function="psf", metric="peak_err" +) +sd_dspm_psf = resolution_metrics( + rm_dspm, inverse_operator["src"], function="psf", metric="sd_ext" +) del rm_dspm, forward # %% # Visualize results # ----------------- # Visualise peak localisation error (PLE) across the whole cortex for MNE PSF: -brain_ple_mne = ple_mne_psf.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=1, - clim=dict(kind='value', lims=(0, 2, 4))) -brain_ple_mne.add_text(0.1, 0.9, 'PLE MNE', 'title', font_size=16) +brain_ple_mne = ple_mne_psf.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=1, + clim=dict(kind="value", lims=(0, 2, 4)), +) +brain_ple_mne.add_text(0.1, 0.9, "PLE MNE", "title", font_size=16) # %% # And dSPM: -brain_ple_dspm = ple_dspm_psf.plot('sample', 
'inflated', 'lh', - subjects_dir=subjects_dir, figure=2, - clim=dict(kind='value', lims=(0, 2, 4))) -brain_ple_dspm.add_text(0.1, 0.9, 'PLE dSPM', 'title', font_size=16) +brain_ple_dspm = ple_dspm_psf.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=2, + clim=dict(kind="value", lims=(0, 2, 4)), +) +brain_ple_dspm.add_text(0.1, 0.9, "PLE dSPM", "title", font_size=16) # %% # Subtract the two distributions and plot this difference diff_ple = ple_mne_psf - ple_dspm_psf -brain_ple_diff = diff_ple.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=3, - clim=dict(kind='value', pos_lims=(0., 1., 2.))) -brain_ple_diff.add_text(0.1, 0.9, 'PLE MNE-dSPM', 'title', font_size=16) +brain_ple_diff = diff_ple.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=3, + clim=dict(kind="value", pos_lims=(0.0, 1.0, 2.0)), +) +brain_ple_diff.add_text(0.1, 0.9, "PLE MNE-dSPM", "title", font_size=16) # %% # These plots show that dSPM has generally lower peak localization error (red @@ -114,28 +134,43 @@ # Next we'll visualise spatial deviation (SD) across the whole cortex for MNE # PSF: -brain_sd_mne = sd_mne_psf.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=4, - clim=dict(kind='value', lims=(0, 2, 4))) -brain_sd_mne.add_text(0.1, 0.9, 'SD MNE', 'title', font_size=16) +brain_sd_mne = sd_mne_psf.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=4, + clim=dict(kind="value", lims=(0, 2, 4)), +) +brain_sd_mne.add_text(0.1, 0.9, "SD MNE", "title", font_size=16) # %% # And dSPM: -brain_sd_dspm = sd_dspm_psf.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=5, - clim=dict(kind='value', lims=(0, 2, 4))) -brain_sd_dspm.add_text(0.1, 0.9, 'SD dSPM', 'title', font_size=16) +brain_sd_dspm = sd_dspm_psf.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=5, + clim=dict(kind="value", lims=(0, 2, 4)), +) +brain_sd_dspm.add_text(0.1, 0.9, "SD dSPM", "title", font_size=16) # %% # Subtract the two distributions and plot this difference: diff_sd = sd_mne_psf - sd_dspm_psf -brain_sd_diff = diff_sd.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=6, - clim=dict(kind='value', pos_lims=(0., 1., 2.))) -brain_sd_diff.add_text(0.1, 0.9, 'SD MNE-dSPM', 'title', font_size=16) +brain_sd_diff = diff_sd.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=6, + clim=dict(kind="value", pos_lims=(0.0, 1.0, 2.0)), +) +brain_sd_diff.add_text(0.1, 0.9, "SD MNE-dSPM", "title", font_size=16) # %% # These plots show that dSPM has generally higher spatial deviation than MNE diff --git a/examples/inverse/resolution_metrics_eegmeg.py b/examples/inverse/resolution_metrics_eegmeg.py index 06268178058..d570cb42baa 100644 --- a/examples/inverse/resolution_metrics_eegmeg.py +++ b/examples/inverse/resolution_metrics_eegmeg.py @@ -27,17 +27,18 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects/' -meg_path = data_path / 'MEG' / 'sample' -fname_fwd_emeg = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fname_cov = meg_path / 'sample_audvis-cov.fif' -fname_evo = meg_path / 'sample_audvis-ave.fif' +subjects_dir = data_path / "subjects/" +meg_path = data_path / "MEG" / "sample" +fname_fwd_emeg = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +fname_cov = meg_path / "sample_audvis-cov.fif" +fname_evo = meg_path / "sample_audvis-ave.fif" # read forward solution with EEG and MEG forward_emeg = 
mne.read_forward_solution(fname_fwd_emeg) # forward operator with fixed source orientations -forward_emeg = mne.convert_forward_solution(forward_emeg, surf_ori=True, - force_fixed=True) +forward_emeg = mne.convert_forward_solution( + forward_emeg, surf_ori=True, force_fixed=True +) # create a forward solution with MEG only forward_meg = mne.pick_types_forward(forward_emeg, meg=True, eeg=False) @@ -50,16 +51,16 @@ # make inverse operator from forward solution for MEG and EEGMEG inv_emeg = mne.minimum_norm.make_inverse_operator( - info=evoked.info, forward=forward_emeg, noise_cov=noise_cov, loose=0., - depth=None) + info=evoked.info, forward=forward_emeg, noise_cov=noise_cov, loose=0.0, depth=None +) inv_meg = mne.minimum_norm.make_inverse_operator( - info=evoked.info, forward=forward_meg, noise_cov=noise_cov, loose=0., - depth=None) + info=evoked.info, forward=forward_meg, noise_cov=noise_cov, loose=0.0, depth=None +) # regularisation parameter snr = 3.0 -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 # %% # EEGMEG @@ -67,12 +68,15 @@ # Compute resolution matrices, localization error, and spatial deviations # for MNE: -rm_emeg = make_inverse_resolution_matrix(forward_emeg, inv_emeg, - method='MNE', lambda2=lambda2) -ple_psf_emeg = resolution_metrics(rm_emeg, inv_emeg['src'], - function='psf', metric='peak_err') -sd_psf_emeg = resolution_metrics(rm_emeg, inv_emeg['src'], - function='psf', metric='sd_ext') +rm_emeg = make_inverse_resolution_matrix( + forward_emeg, inv_emeg, method="MNE", lambda2=lambda2 +) +ple_psf_emeg = resolution_metrics( + rm_emeg, inv_emeg["src"], function="psf", metric="peak_err" +) +sd_psf_emeg = resolution_metrics( + rm_emeg, inv_emeg["src"], function="psf", metric="sd_ext" +) del rm_emeg # %% @@ -80,12 +84,13 @@ # --- # Do the same for MEG: -rm_meg = make_inverse_resolution_matrix(forward_meg, inv_meg, - method='MNE', lambda2=lambda2) -ple_psf_meg = resolution_metrics(rm_meg, inv_meg['src'], - function='psf', metric='peak_err') -sd_psf_meg = resolution_metrics(rm_meg, inv_meg['src'], - function='psf', metric='sd_ext') +rm_meg = make_inverse_resolution_matrix( + forward_meg, inv_meg, method="MNE", lambda2=lambda2 +) +ple_psf_meg = resolution_metrics( + rm_meg, inv_meg["src"], function="psf", metric="peak_err" +) +sd_psf_meg = resolution_metrics(rm_meg, inv_meg["src"], function="psf", metric="sd_ext") del rm_meg # %% @@ -93,64 +98,94 @@ # ------------- # Look at peak localisation error (PLE) across the whole cortex for PSF: -brain_ple_emeg = ple_psf_emeg.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=1, - clim=dict(kind='value', lims=(0, 2, 4))) +brain_ple_emeg = ple_psf_emeg.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=1, + clim=dict(kind="value", lims=(0, 2, 4)), +) -brain_ple_emeg.add_text(0.1, 0.9, 'PLE PSF EMEG', 'title', font_size=16) +brain_ple_emeg.add_text(0.1, 0.9, "PLE PSF EMEG", "title", font_size=16) # %% # For MEG only: -brain_ple_meg = ple_psf_meg.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=2, - clim=dict(kind='value', lims=(0, 2, 4))) +brain_ple_meg = ple_psf_meg.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=2, + clim=dict(kind="value", lims=(0, 2, 4)), +) -brain_ple_meg.add_text(0.1, 0.9, 'PLE PSF MEG', 'title', font_size=16) +brain_ple_meg.add_text(0.1, 0.9, "PLE PSF MEG", "title", font_size=16) # %% # Subtract the two distributions and plot this difference: diff_ple = ple_psf_emeg - ple_psf_meg -brain_ple_diff = diff_ple.plot('sample', 
'inflated', 'lh', - subjects_dir=subjects_dir, figure=3, - clim=dict(kind='value', pos_lims=(0., .5, 1.)), - smoothing_steps=20) +brain_ple_diff = diff_ple.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=3, + clim=dict(kind="value", pos_lims=(0.0, 0.5, 1.0)), + smoothing_steps=20, +) -brain_ple_diff.add_text(0.1, 0.9, 'PLE EMEG-MEG', 'title', font_size=16) +brain_ple_diff.add_text(0.1, 0.9, "PLE EMEG-MEG", "title", font_size=16) # %% # These plots show that with respect to peak localization error, adding EEG to # MEG does not bring much benefit. Next let's visualise spatial deviation (SD) # across the whole cortex for PSF: -brain_sd_emeg = sd_psf_emeg.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=4, - clim=dict(kind='value', lims=(0, 2, 4))) +brain_sd_emeg = sd_psf_emeg.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=4, + clim=dict(kind="value", lims=(0, 2, 4)), +) -brain_sd_emeg.add_text(0.1, 0.9, 'SD PSF EMEG', 'title', font_size=16) +brain_sd_emeg.add_text(0.1, 0.9, "SD PSF EMEG", "title", font_size=16) # %% # For MEG only: -brain_sd_meg = sd_psf_meg.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=5, - clim=dict(kind='value', lims=(0, 2, 4))) +brain_sd_meg = sd_psf_meg.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=5, + clim=dict(kind="value", lims=(0, 2, 4)), +) -brain_sd_meg.add_text(0.1, 0.9, 'SD PSF MEG', 'title', font_size=16) +brain_sd_meg.add_text(0.1, 0.9, "SD PSF MEG", "title", font_size=16) # %% # Subtract the two distributions and plot this difference: diff_sd = sd_psf_emeg - sd_psf_meg -brain_sd_diff = diff_sd.plot('sample', 'inflated', 'lh', - subjects_dir=subjects_dir, figure=6, - clim=dict(kind='value', pos_lims=(0., .5, 1.)), - smoothing_steps=20) - -brain_sd_diff.add_text(0.1, 0.9, 'SD EMEG-MEG', 'title', font_size=16) +brain_sd_diff = diff_sd.plot( + "sample", + "inflated", + "lh", + subjects_dir=subjects_dir, + figure=6, + clim=dict(kind="value", pos_lims=(0.0, 0.5, 1.0)), + smoothing_steps=20, +) + +brain_sd_diff.add_text(0.1, 0.9, "SD EMEG-MEG", "title", font_size=16) # %% # Adding EEG to MEG decreases the spatial extent of point-spread diff --git a/examples/inverse/snr_estimate.py b/examples/inverse/snr_estimate.py index 956f3cbe643..4a88a9d13c4 100644 --- a/examples/inverse/snr_estimate.py +++ b/examples/inverse/snr_estimate.py @@ -21,9 +21,9 @@ print(__doc__) -data_dir = data_path() / 'MEG' / 'sample' -fname_inv = data_dir / 'sample_audvis-meg-oct-6-meg-inv.fif' -fname_evoked = data_dir / 'sample_audvis-ave.fif' +data_dir = data_path() / "MEG" / "sample" +fname_inv = data_dir / "sample_audvis-meg-oct-6-meg-inv.fif" +fname_evoked = data_dir / "sample_audvis-ave.fif" inv = read_inverse_operator(fname_inv) evoked = read_evokeds(fname_evoked, baseline=(None, 0))[0] diff --git a/examples/inverse/source_space_snr.py b/examples/inverse/source_space_snr.py index 0dd14e71722..c5599a5d331 100644 --- a/examples/inverse/source_space_snr.py +++ b/examples/inverse/source_space_snr.py @@ -26,15 +26,14 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' +subjects_dir = data_path / "subjects" # Read data -meg_path = data_path / 'MEG' / 'sample' -fname_evoked = meg_path / 'sample_audvis-ave.fif' -evoked = mne.read_evokeds(fname_evoked, condition='Left Auditory', - baseline=(None, 0)) -fname_fwd = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fname_cov = meg_path / 'sample_audvis-cov.fif' +meg_path = data_path 
/ "MEG" / "sample" +fname_evoked = meg_path / "sample_audvis-ave.fif" +evoked = mne.read_evokeds(fname_evoked, condition="Left Auditory", baseline=(None, 0)) +fname_fwd = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +fname_cov = meg_path / "sample_audvis-cov.fif" fwd = mne.read_forward_solution(fname_fwd) cov = mne.read_cov(fname_cov) @@ -43,8 +42,8 @@ # Calculate MNE: snr = 3.0 -lambda2 = 1.0 / snr ** 2 -stc = apply_inverse(evoked, inv_op, lambda2, 'MNE', verbose=True) +lambda2 = 1.0 / snr**2 +stc = apply_inverse(evoked, inv_op, lambda2, "MNE", verbose=True) # Calculate SNR in source space: snr_stc = stc.estimate_snr(evoked.info, fwd, cov) @@ -54,17 +53,23 @@ fig, ax = plt.subplots() ax.plot(evoked.times, ave) -ax.set(xlabel='Time (s)', ylabel='SNR MEG-EEG') +ax.set(xlabel="Time (s)", ylabel="SNR MEG-EEG") fig.tight_layout() # Find time point of maximum SNR maxidx = np.argmax(ave) # Plot SNR on source space at the time point of maximum SNR: -kwargs = dict(initial_time=evoked.times[maxidx], hemi='split', - views=['lat', 'med'], subjects_dir=subjects_dir, size=(600, 600), - clim=dict(kind='value', lims=(-100, -70, -40)), - transparent=True, colormap='viridis') +kwargs = dict( + initial_time=evoked.times[maxidx], + hemi="split", + views=["lat", "med"], + subjects_dir=subjects_dir, + size=(600, 600), + clim=dict(kind="value", lims=(-100, -70, -40)), + transparent=True, + colormap="viridis", +) brain = snr_stc.plot(**kwargs) # %% @@ -73,9 +78,8 @@ # Next we do the same for EEG and plot the result on the cortex: evoked_eeg = evoked.copy().pick_types(eeg=True, meg=False) -inv_op_eeg = make_inverse_operator(evoked_eeg.info, fwd, cov, fixed=True, - verbose=True) -stc_eeg = apply_inverse(evoked_eeg, inv_op_eeg, lambda2, 'MNE', verbose=True) +inv_op_eeg = make_inverse_operator(evoked_eeg.info, fwd, cov, fixed=True, verbose=True) +stc_eeg = apply_inverse(evoked_eeg, inv_op_eeg, lambda2, "MNE", verbose=True) snr_stc_eeg = stc_eeg.estimate_snr(evoked_eeg.info, fwd, cov) brain = snr_stc_eeg.plot(**kwargs) diff --git a/examples/inverse/time_frequency_mixed_norm_inverse.py b/examples/inverse/time_frequency_mixed_norm_inverse.py index 2271c58f24c..d69d4769058 100644 --- a/examples/inverse/time_frequency_mixed_norm_inverse.py +++ b/examples/inverse/time_frequency_mixed_norm_inverse.py @@ -32,23 +32,26 @@ from mne.datasets import sample from mne.minimum_norm import make_inverse_operator, apply_inverse from mne.inverse_sparse import tf_mixed_norm, make_stc_from_dipoles -from mne.viz import (plot_sparse_source_estimates, - plot_dipole_locations, plot_dipole_amplitudes) +from mne.viz import ( + plot_sparse_source_estimates, + plot_dipole_locations, + plot_dipole_amplitudes, +) print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -ave_fname = meg_path / 'sample_audvis-no-filter-ave.fif' -cov_fname = meg_path / 'sample_audvis-shrunk-cov.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +ave_fname = meg_path / "sample_audvis-no-filter-ave.fif" +cov_fname = meg_path / "sample_audvis-shrunk-cov.fif" # Read noise covariance matrix cov = mne.read_cov(cov_fname) # Handling average file -condition = 'Left visual' +condition = "Left visual" evoked = mne.read_evokeds(ave_fname, condition=condition, baseline=(None, 0)) # We make the window slightly larger than what you'll eventually be 
interested # in ([-0.05, 0.3]) to avoid edge effects. @@ -61,7 +64,7 @@ # Run solver # alpha parameter is between 0 and 100 (100 gives 0 active source) -alpha = 40. # general regularization parameter +alpha = 40.0 # general regularization parameter # l1_ratio parameter between 0 and 1 promotes temporal smoothness # (0 means no temporal regularization) l1_ratio = 0.03 # temporal regularization parameter @@ -69,17 +72,31 @@ loose, depth = 0.2, 0.9 # loose orientation & depth weighting # Compute dSPM solution to be used as weights in MxNE -inverse_operator = make_inverse_operator(evoked.info, forward, cov, - loose=loose, depth=depth) -stc_dspm = apply_inverse(evoked, inverse_operator, lambda2=1. / 9., - method='dSPM') +inverse_operator = make_inverse_operator( + evoked.info, forward, cov, loose=loose, depth=depth +) +stc_dspm = apply_inverse(evoked, inverse_operator, lambda2=1.0 / 9.0, method="dSPM") # Compute TF-MxNE inverse solution with dipole output dipoles, residual = tf_mixed_norm( - evoked, forward, cov, alpha=alpha, l1_ratio=l1_ratio, loose=loose, - depth=depth, maxit=200, tol=1e-6, weights=stc_dspm, weights_min=8., - debias=True, wsize=16, tstep=4, window=0.05, return_as_dipoles=True, - return_residual=True) + evoked, + forward, + cov, + alpha=alpha, + l1_ratio=l1_ratio, + loose=loose, + depth=depth, + maxit=200, + tol=1e-6, + weights=stc_dspm, + weights_min=8.0, + debias=True, + wsize=16, + tstep=4, + window=0.05, + return_as_dipoles=True, + return_residual=True, +) # Crop to remove edges for dip in dipoles: @@ -94,9 +111,14 @@ # %% # Plot location of the strongest dipole with MRI slices idx = np.argmax([np.max(np.abs(dip.amplitude)) for dip in dipoles]) -plot_dipole_locations(dipoles[idx], forward['mri_head_t'], 'sample', - subjects_dir=subjects_dir, mode='orthoview', - idx='amplitude') +plot_dipole_locations( + dipoles[idx], + forward["mri_head_t"], + "sample", + subjects_dir=subjects_dir, + mode="orthoview", + idx="amplitude", +) # # Plot dipole locations of all dipoles with MRI slices: # for dip in dipoles: @@ -107,31 +129,51 @@ # %% # Show the evoked response and the residual for gradiometers ylim = dict(grad=[-120, 120]) -evoked.pick_types(meg='grad', exclude='bads') -evoked.plot(titles=dict(grad='Evoked Response: Gradiometers'), ylim=ylim, - proj=True, time_unit='s') - -residual.pick_types(meg='grad', exclude='bads') -residual.plot(titles=dict(grad='Residuals: Gradiometers'), ylim=ylim, - proj=True, time_unit='s') +evoked.pick_types(meg="grad", exclude="bads") +evoked.plot( + titles=dict(grad="Evoked Response: Gradiometers"), + ylim=ylim, + proj=True, + time_unit="s", +) + +residual.pick_types(meg="grad", exclude="bads") +residual.plot( + titles=dict(grad="Residuals: Gradiometers"), ylim=ylim, proj=True, time_unit="s" +) # %% # Generate stc from dipoles -stc = make_stc_from_dipoles(dipoles, forward['src']) +stc = make_stc_from_dipoles(dipoles, forward["src"]) # %% # View in 2D and 3D ("glass" brain like 3D plot) -plot_sparse_source_estimates(forward['src'], stc, bgcolor=(1, 1, 1), - opacity=0.1, fig_name="TF-MxNE (cond %s)" - % condition, modes=['sphere'], scale_factors=[1.]) - -time_label = 'TF-MxNE time=%0.2f ms' -clim = dict(kind='value', lims=[10e-9, 15e-9, 20e-9]) -brain = stc.plot('sample', 'inflated', 'rh', views='medial', - clim=clim, time_label=time_label, smoothing_steps=5, - subjects_dir=subjects_dir, initial_time=150, time_unit='ms') -brain.add_label("V1", color="yellow", scalar_thresh=.5, borders=True) -brain.add_label("V2", color="red", scalar_thresh=.5, 
borders=True) +plot_sparse_source_estimates( + forward["src"], + stc, + bgcolor=(1, 1, 1), + opacity=0.1, + fig_name="TF-MxNE (cond %s)" % condition, + modes=["sphere"], + scale_factors=[1.0], +) + +time_label = "TF-MxNE time=%0.2f ms" +clim = dict(kind="value", lims=[10e-9, 15e-9, 20e-9]) +brain = stc.plot( + "sample", + "inflated", + "rh", + views="medial", + clim=clim, + time_label=time_label, + smoothing_steps=5, + subjects_dir=subjects_dir, + initial_time=150, + time_unit="ms", +) +brain.add_label("V1", color="yellow", scalar_thresh=0.5, borders=True) +brain.add_label("V2", color="red", scalar_thresh=0.5, borders=True) # %% # References diff --git a/examples/inverse/vector_mne_solution.py b/examples/inverse/vector_mne_solution.py index caba3a46201..2733f40acd1 100644 --- a/examples/inverse/vector_mne_solution.py +++ b/examples/inverse/vector_mne_solution.py @@ -33,34 +33,37 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' +subjects_dir = data_path / "subjects" smoothing_steps = 7 # Read evoked data -meg_path = data_path / 'MEG' / 'sample' -fname_evoked = meg_path / 'sample_audvis-ave.fif' +meg_path = data_path / "MEG" / "sample" +fname_evoked = meg_path / "sample_audvis-ave.fif" evoked = mne.read_evokeds(fname_evoked, condition=0, baseline=(None, 0)) # Read inverse solution -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" inv = read_inverse_operator(fname_inv) # Apply inverse solution, set pick_ori='vector' to obtain a # :class:`mne.VectorSourceEstimate` object snr = 3.0 -lambda2 = 1.0 / snr ** 2 -stc = apply_inverse(evoked, inv, lambda2, 'dSPM', pick_ori='vector') +lambda2 = 1.0 / snr**2 +stc = apply_inverse(evoked, inv, lambda2, "dSPM", pick_ori="vector") # Use peak getter to move visualization to the time point of the peak magnitude -_, peak_time = stc.magnitude().get_peak(hemi='lh') +_, peak_time = stc.magnitude().get_peak(hemi="lh") # %% # Plot the source estimate: # sphinx_gallery_thumbnail_number = 2 brain = stc.plot( - initial_time=peak_time, hemi='lh', subjects_dir=subjects_dir, - smoothing_steps=smoothing_steps) + initial_time=peak_time, + hemi="lh", + subjects_dir=subjects_dir, + smoothing_steps=smoothing_steps, +) # You can save a brain movie with: # brain.save_movie(time_dilation=20, tmin=0.05, tmax=0.16, framerate=10, @@ -69,32 +72,43 @@ # %% # Plot the activation in the direction of maximal power for this data: -stc_max, directions = stc.project('pca', src=inv['src']) +stc_max, directions = stc.project("pca", src=inv["src"]) # These directions must by design be close to the normals because this # inverse was computed with loose=0.2 -print('Absolute cosine similarity between source normals and directions: ' - f'{np.abs(np.sum(directions * inv["source_nn"][2::3], axis=-1)).mean()}') +print( + "Absolute cosine similarity between source normals and directions: " + f'{np.abs(np.sum(directions * inv["source_nn"][2::3], axis=-1)).mean()}' +) brain_max = stc_max.plot( - initial_time=peak_time, hemi='lh', subjects_dir=subjects_dir, - time_label='Max power', smoothing_steps=smoothing_steps) + initial_time=peak_time, + hemi="lh", + subjects_dir=subjects_dir, + time_label="Max power", + smoothing_steps=smoothing_steps, +) # %% # The normal is very similar: -brain_normal = stc.project('normal', inv['src'])[0].plot( - initial_time=peak_time, hemi='lh', subjects_dir=subjects_dir, - time_label='Normal', smoothing_steps=smoothing_steps) +brain_normal = stc.project("normal", 
inv["src"])[0].plot( + initial_time=peak_time, + hemi="lh", + subjects_dir=subjects_dir, + time_label="Normal", + smoothing_steps=smoothing_steps, +) # %% # You can also do this with a fixed-orientation inverse. It looks a lot like # the result above because the ``loose=0.2`` orientation constraint keeps # sources close to fixed orientation: -fname_inv_fixed = ( - meg_path / 'sample_audvis-meg-oct-6-meg-fixed-inv.fif') +fname_inv_fixed = meg_path / "sample_audvis-meg-oct-6-meg-fixed-inv.fif" inv_fixed = read_inverse_operator(fname_inv_fixed) -stc_fixed = apply_inverse( - evoked, inv_fixed, lambda2, 'dSPM', pick_ori='vector') +stc_fixed = apply_inverse(evoked, inv_fixed, lambda2, "dSPM", pick_ori="vector") brain_fixed = stc_fixed.plot( - initial_time=peak_time, hemi='lh', subjects_dir=subjects_dir, - smoothing_steps=smoothing_steps) + initial_time=peak_time, + hemi="lh", + subjects_dir=subjects_dir, + smoothing_steps=smoothing_steps, +) diff --git a/examples/io/elekta_epochs.py b/examples/io/elekta_epochs.py index 8c24902d209..125a1e2c028 100644 --- a/examples/io/elekta_epochs.py +++ b/examples/io/elekta_epochs.py @@ -20,7 +20,7 @@ import os from mne.datasets import multimodal -fname_raw = os.path.join(multimodal.data_path(), 'multimodal_raw.fif') +fname_raw = os.path.join(multimodal.data_path(), "multimodal_raw.fif") print(__doc__) @@ -35,9 +35,9 @@ # %% # Extract epochs corresponding to a category -cond = raw.acqparser.get_condition(raw, 'Auditory right') +cond = raw.acqparser.get_condition(raw, "Auditory right") epochs = mne.Epochs(raw, **cond) -epochs.average().plot_topo(background_color='w') +epochs.average().plot_topo(background_color="w") # %% # Get epochs from all conditions, average @@ -45,10 +45,11 @@ for cat in raw.acqparser.categories: cond = raw.acqparser.get_condition(raw, cat) # copy (supported) rejection parameters from DACQ settings - epochs = mne.Epochs(raw, reject=raw.acqparser.reject, - flat=raw.acqparser.flat, **cond) + epochs = mne.Epochs( + raw, reject=raw.acqparser.reject, flat=raw.acqparser.flat, **cond + ) evoked = epochs.average() - evoked.comment = cat['comment'] + evoked.comment = cat["comment"] evokeds.append(evoked) # save all averages to an evoked fiff file # fname_out = 'multimodal-ave.fif' @@ -57,16 +58,15 @@ # %% # Make a new averaging category newcat = dict() -newcat['comment'] = 'Visual lower left, longer epochs' -newcat['event'] = 3 # reference event -newcat['start'] = -.2 # epoch start rel. to ref. event (in seconds) -newcat['end'] = .7 # epoch end -newcat['reqevent'] = 0 # additional required event; 0 if none -newcat['reqwithin'] = .5 # ...required within .5 s (before or after) -newcat['reqwhen'] = 2 # ...required before (1) or after (2) ref. event -newcat['index'] = 9 # can be set freely +newcat["comment"] = "Visual lower left, longer epochs" +newcat["event"] = 3 # reference event +newcat["start"] = -0.2 # epoch start rel. to ref. event (in seconds) +newcat["end"] = 0.7 # epoch end +newcat["reqevent"] = 0 # additional required event; 0 if none +newcat["reqwithin"] = 0.5 # ...required within .5 s (before or after) +newcat["reqwhen"] = 2 # ...required before (1) or after (2) ref. 
event +newcat["index"] = 9 # can be set freely cond = raw.acqparser.get_condition(raw, newcat) -epochs = mne.Epochs(raw, reject=raw.acqparser.reject, - flat=raw.acqparser.flat, **cond) -epochs.average().plot(time_unit='s') +epochs = mne.Epochs(raw, reject=raw.acqparser.reject, flat=raw.acqparser.flat, **cond) +epochs.average().plot(time_unit="s") diff --git a/examples/io/read_neo_format.py b/examples/io/read_neo_format.py index 43b8a98f876..7847e23dcfa 100644 --- a/examples/io/read_neo_format.py +++ b/examples/io/read_neo_format.py @@ -22,15 +22,15 @@ # demonstrate the steps to using NEO data. For actual data and different file # formats, consult the NEO documentation. -reader = neo.io.ExampleIO('fakedata.nof') +reader = neo.io.ExampleIO("fakedata.nof") block = reader.read(lazy=False)[0] # get the first block -segment = block.segments[0] # get data from first (and only) segment +segment = block.segments[0] # get data from first (and only) segment signals = segment.analogsignals[0] # get first (multichannel) signal -data = signals.rescale('V').magnitude.T +data = signals.rescale("V").magnitude.T sfreq = signals.sampling_rate.magnitude -ch_names = [f'Neo {(idx + 1):02}' for idx in range(signals.shape[1])] -ch_types = ['eeg'] * len(ch_names) # if not specified, type 'misc' is assumed +ch_names = [f"Neo {(idx + 1):02}" for idx in range(signals.shape[1])] +ch_types = ["eeg"] * len(ch_names) # if not specified, type 'misc' is assumed info = mne.create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) raw = mne.io.RawArray(data, info) diff --git a/examples/io/read_noise_covariance_matrix.py b/examples/io/read_noise_covariance_matrix.py index 57b0d314e25..ba9e126a4ea 100644 --- a/examples/io/read_noise_covariance_matrix.py +++ b/examples/io/read_noise_covariance_matrix.py @@ -17,8 +17,8 @@ from mne.datasets import sample data_path = sample.data_path() -fname_cov = data_path / 'MEG' / 'sample' / 'sample_audvis-cov.fif' -fname_evo = data_path / 'MEG' / 'sample' / 'sample_audvis-ave.fif' +fname_cov = data_path / "MEG" / "sample" / "sample_audvis-cov.fif" +fname_evo = data_path / "MEG" / "sample" / "sample_audvis-ave.fif" cov = mne.read_cov(fname_cov) print(cov) @@ -27,4 +27,4 @@ # %% # Plot covariance -cov.plot(ev_info, exclude='bads', show_svd=False) +cov.plot(ev_info, exclude="bads", show_svd=False) diff --git a/examples/io/read_xdf.py b/examples/io/read_xdf.py index d65784d85ad..1edc8faf2e6 100644 --- a/examples/io/read_xdf.py +++ b/examples/io/read_xdf.py @@ -22,15 +22,13 @@ import mne from mne.datasets import misc -fname = ( - misc.data_path() / 'xdf' / - 'sub-P001_ses-S004_task-Default_run-001_eeg_a2.xdf') +fname = misc.data_path() / "xdf" / "sub-P001_ses-S004_task-Default_run-001_eeg_a2.xdf" streams, header = pyxdf.load_xdf(fname) data = streams[0]["time_series"].T assert data.shape[0] == 5 # four raw EEG plus one stim channel data[:4:2] -= data[1:4:2] # subtract (rereference) to get two bipolar EEG data = data[::2] # subselect -data[:2] *= (1e-6 / 50 / 2) # uV -> V and preamp gain +data[:2] *= 1e-6 / 50 / 2 # uV -> V and preamp gain sfreq = float(streams[0]["info"]["nominal_srate"][0]) info = mne.create_info(3, sfreq, ["eeg", "eeg", "stim"]) raw = mne.io.RawArray(data, info) diff --git a/examples/preprocessing/contralateral_referencing.py b/examples/preprocessing/contralateral_referencing.py index 2c04ccc7c8f..c3aff2afe16 100644 --- a/examples/preprocessing/contralateral_referencing.py +++ b/examples/preprocessing/contralateral_referencing.py @@ -15,28 +15,24 @@ import mne ssvep_folder 
= mne.datasets.ssvep.data_path() -ssvep_data_raw_path = (ssvep_folder / 'sub-02' / 'ses-01' / 'eeg' / - 'sub-02_ses-01_task-ssvep_eeg.vhdr') +ssvep_data_raw_path = ( + ssvep_folder / "sub-02" / "ses-01" / "eeg" / "sub-02_ses-01_task-ssvep_eeg.vhdr" +) raw = mne.io.read_raw(ssvep_data_raw_path, preload=True) -_ = raw.set_montage('easycap-M1') +_ = raw.set_montage("easycap-M1") # %% # The electrodes TP9 and TP10 are near the mastoids so we'll use them as our # contralateral reference channels. Then we'll create our hemisphere groups. -raw.rename_channels({ - 'TP9': 'M1', - 'TP10': 'M2' -}) +raw.rename_channels({"TP9": "M1", "TP10": "M2"}) # this splits electrodes into 3 groups; left, midline, and right -ch_names = mne.channels.make_1020_channel_selections( - raw.info, return_ch_names=True -) +ch_names = mne.channels.make_1020_channel_selections(raw.info, return_ch_names=True) # remove the ref channels from the lists of to-be-rereferenced channels -ch_names['Left'].remove('M1') -ch_names['Right'].remove('M2') +ch_names["Left"].remove("M1") +ch_names["Right"].remove("M2") # %% # Finally we do the referencing. For the midline channels we'll reference them @@ -44,25 +40,23 @@ # reference to the single contralateral mastoid channel. # midline referencing to mean of mastoids: -mastoids = ['M1', 'M2'] -rereferenced_midline_chs = (raw.copy() - .pick(mastoids + ch_names['Midline']) - .set_eeg_reference(mastoids) - .drop_channels(mastoids) - ) +mastoids = ["M1", "M2"] +rereferenced_midline_chs = ( + raw.copy() + .pick(mastoids + ch_names["Midline"]) + .set_eeg_reference(mastoids) + .drop_channels(mastoids) +) # contralateral referencing (alters channels in `raw` in-place): -for ref, hemi in dict(M2=ch_names['Left'], M1=ch_names['Right']).items(): - mne.set_bipolar_reference( - raw, anode=hemi, cathode=[ref] * len(hemi), copy=False - ) +for ref, hemi in dict(M2=ch_names["Left"], M1=ch_names["Right"]).items(): + mne.set_bipolar_reference(raw, anode=hemi, cathode=[ref] * len(hemi), copy=False) # strip off '-M1' and '-M2' suffixes added to each bipolar-referenced channel -raw.rename_channels(lambda ch_name: ch_name.split('-')[0]) +raw.rename_channels(lambda ch_name: ch_name.split("-")[0]) # replace unreferenced midline with rereferenced midline -_ = (raw.drop_channels(ch_names['Midline']) - .add_channels([rereferenced_midline_chs])) +_ = raw.drop_channels(ch_names["Midline"]).add_channels([rereferenced_midline_chs]) # %% # Make sure the channel locations still look right: -fig = raw.plot_sensors(show_names=True, sphere='eeglab') +fig = raw.plot_sensors(show_names=True, sphere="eeglab") diff --git a/examples/preprocessing/css.py b/examples/preprocessing/css.py index 2631dc54d23..73e86c1b389 100644 --- a/examples/preprocessing/css.py +++ b/examples/preprocessing/css.py @@ -27,26 +27,25 @@ ############################################################################### # Load sample subject data data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -ave_fname = meg_path / 'sample_audvis-no-filter-ave.fif' -cov_fname = meg_path / 'sample_audvis-cov.fif' -trans_fname = meg_path / 'sample_audvis_raw-trans.fif' -bem_fname = subjects_dir / 'sample' / 'bem' / '/sample-5120-bem-sol.fif' - -raw = mne.io.read_raw_fif(meg_path / 'sample_audvis_raw.fif') +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" 
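
# An aside on the bem_fname assignment a few lines below: black is a pure
# formatter, so it preserves the pre-existing leading slash in
# "/sample-5120-bem-sol.fif". With pathlib, an absolute segment discards
# everything joined before it, so bem_fname resolves to
# /sample-5120-bem-sol.fif rather than a path under subjects_dir.
# A minimal demonstration (POSIX paths assumed):
from pathlib import Path

assert Path("bem") / "/sample-5120-bem.fif" == Path("/sample-5120-bem.fif")
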
+ave_fname = meg_path / "sample_audvis-no-filter-ave.fif" +cov_fname = meg_path / "sample_audvis-cov.fif" +trans_fname = meg_path / "sample_audvis_raw-trans.fif" +bem_fname = subjects_dir / "sample" / "bem" / "/sample-5120-bem-sol.fif" + +raw = mne.io.read_raw_fif(meg_path / "sample_audvis_raw.fif") fwd = mne.read_forward_solution(fwd_fname) fwd = mne.convert_forward_solution(fwd, force_fixed=True, surf_ori=True) -fwd = mne.pick_types_forward(fwd, meg=True, eeg=True, exclude=raw.info['bads']) +fwd = mne.pick_types_forward(fwd, meg=True, eeg=True, exclude=raw.info["bads"]) cov = mne.read_cov(cov_fname) ############################################################################### # Find patches (labels) to activate -all_labels = mne.read_labels_from_annot(subject='sample', - subjects_dir=subjects_dir) +all_labels = mne.read_labels_from_annot(subject="sample", subjects_dir=subjects_dir) labels = [] -for select_label in ['parahippocampal-lh', 'postcentral-rh']: +for select_label in ["parahippocampal-lh", "postcentral-rh"]: labels.append([lab for lab in all_labels if lab.name in select_label][0]) hiplab, postcenlab = labels @@ -64,32 +63,38 @@ def subcortical_waveform(times): return 10e-9 * np.cos(times * 2 * np.pi * 239) -times = np.linspace(0, 0.5, int(0.5 * raw.info['sfreq'])) -stc = simulate_sparse_stc(fwd['src'], n_dipoles=2, times=times, - location='center', subjects_dir=subjects_dir, - labels=[postcenlab, hiplab], - data_fun=cortical_waveform) -stc.data[np.where(np.isin(stc.vertices[0], hiplab.vertices))[0], :] = \ - subcortical_waveform(times) +times = np.linspace(0, 0.5, int(0.5 * raw.info["sfreq"])) +stc = simulate_sparse_stc( + fwd["src"], + n_dipoles=2, + times=times, + location="center", + subjects_dir=subjects_dir, + labels=[postcenlab, hiplab], + data_fun=cortical_waveform, +) +stc.data[ + np.where(np.isin(stc.vertices[0], hiplab.vertices))[0], : +] = subcortical_waveform(times) evoked = simulate_evoked(fwd, stc, raw.info, cov, nave=15) ############################################################################### # Process with CSS and plot PSD of EEG data before and after processing -evoked_subcortical = mne.preprocessing.cortical_signal_suppression(evoked, - n_proj=6) +evoked_subcortical = mne.preprocessing.cortical_signal_suppression(evoked, n_proj=6) chs = mne.pick_types(evoked.info, meg=False, eeg=True) -psd = np.mean(np.abs(np.fft.rfft(evoked.data))**2, axis=0) -psd_proc = np.mean(np.abs(np.fft.rfft(evoked_subcortical.data))**2, axis=0) -freq = np.arange(0, stop=int(evoked.info['sfreq'] / 2), - step=evoked.info['sfreq'] / (2 * len(psd))) +psd = np.mean(np.abs(np.fft.rfft(evoked.data)) ** 2, axis=0) +psd_proc = np.mean(np.abs(np.fft.rfft(evoked_subcortical.data)) ** 2, axis=0) +freq = np.arange( + 0, stop=int(evoked.info["sfreq"] / 2), step=evoked.info["sfreq"] / (2 * len(psd)) +) fig, ax = plt.subplots() -ax.plot(freq, psd, label='raw') -ax.plot(freq, psd_proc, label='processed') -ax.text(.2, .7, 'cortical', transform=ax.transAxes) -ax.text(.8, .25, 'subcortical', transform=ax.transAxes) -ax.set(ylabel='EEG Power spectral density', xlabel='Frequency (Hz)') +ax.plot(freq, psd, label="raw") +ax.plot(freq, psd_proc, label="processed") +ax.text(0.2, 0.7, "cortical", transform=ax.transAxes) +ax.text(0.8, 0.25, "subcortical", transform=ax.transAxes) +ax.set(ylabel="EEG Power spectral density", xlabel="Frequency (Hz)") ax.legend() # References diff --git a/examples/preprocessing/define_target_events.py b/examples/preprocessing/define_target_events.py index 
f35b16743d9..51e0fdbb960 100644 --- a/examples/preprocessing/define_target_events.py +++ b/examples/preprocessing/define_target_events.py @@ -33,9 +33,9 @@ # %% # Set parameters -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" # Setup for reading the raw data raw = io.read_raw_fif(raw_fname) @@ -43,25 +43,33 @@ # Set up pick list: EEG + STI 014 - bad channels (modify to your needs) include = [] # or stim channels ['STI 014'] -raw.info['bads'] += ['EEG 053'] # bads +raw.info["bads"] += ["EEG 053"] # bads # pick MEG channels -picks = mne.pick_types(raw.info, meg='mag', eeg=False, stim=False, eog=True, - include=include, exclude='bads') +picks = mne.pick_types( + raw.info, + meg="mag", + eeg=False, + stim=False, + eog=True, + include=include, + exclude="bads", +) # %% # Find stimulus event followed by quick button presses reference_id = 5 # presentation of a smiley face target_id = 32 # button press -sfreq = raw.info['sfreq'] # sampling rate +sfreq = raw.info["sfreq"] # sampling rate tmin = 0.1 # trials leading to very early responses will be rejected tmax = 0.59 # ignore face stimuli followed by button press later than 590 ms new_id = 42 # the new event id for a hit. If None, reference_id is used. fill_na = 99 # the fill value for misses -events_, lag = define_target_events(events, reference_id, target_id, - sfreq, tmin, tmax, new_id, fill_na) +events_, lag = define_target_events( + events, reference_id, target_id, sfreq, tmin, tmax, new_id, fill_na +) print(events_) # The 99 indicates missing or too late button presses @@ -77,9 +85,16 @@ tmax_ = 0.4 event_id = dict(early=new_id, late=fill_na) -epochs = mne.Epochs(raw, events_, event_id, tmin_, - tmax_, picks=picks, baseline=(None, 0), - reject=dict(mag=4e-12)) +epochs = mne.Epochs( + raw, + events_, + event_id, + tmin_, + tmax_, + picks=picks, + baseline=(None, 0), + reject=dict(mag=4e-12), +) # average epochs and get an Evoked dataset. @@ -89,11 +104,11 @@ # View evoked response times = 1e3 * epochs.times # time in milliseconds -title = 'Evoked response followed by %s button press' +title = "Evoked response followed by %s button press" fig, axes = plt.subplots(2, 1) -early.plot(axes=axes[0], time_unit='s') -axes[0].set(title=title % 'late', ylabel='Evoked field (fT)') -late.plot(axes=axes[1], time_unit='s') -axes[1].set(title=title % 'early', ylabel='Evoked field (fT)') +early.plot(axes=axes[0], time_unit="s") +axes[0].set(title=title % "late", ylabel="Evoked field (fT)") +late.plot(axes=axes[1], time_unit="s") +axes[1].set(title=title % "early", ylabel="Evoked field (fT)") plt.show() diff --git a/examples/preprocessing/eeg_bridging.py b/examples/preprocessing/eeg_bridging.py index fa94e752c71..09319a2cdea 100644 --- a/examples/preprocessing/eeg_bridging.py +++ b/examples/preprocessing/eeg_bridging.py @@ -60,11 +60,11 @@ # bridging so using the last segment of the data will # give the most conservative estimate. 
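
# A note on the quote churn in the hunks that follow: black rewrites string
# literals to double quotes unless switching would force extra backslash
# escapes. A minimal sketch (values are illustrative only):
name_before = 'standard_1005'  # the pre-black spelling
name_after = "standard_1005"   # what black emits
assert name_before == name_after  # only the spelling changes, not the value
keep = 'a "quoted" word'  # left single-quoted: double quotes would add escapes
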
-montage = mne.channels.make_standard_montage('standard_1005') +montage = mne.channels.make_standard_montage("standard_1005") ed_data = dict() # electrical distance/bridging data raw_data = dict() # store infos for electrode positions for sub in range(1, 11): - print(f'Computing electrode bridges for subject {sub}') + print(f"Computing electrode bridges for subject {sub}") raw_fname = mne.datasets.eegbci.load_data(subject=sub, runs=(1,))[0] raw = mne.io.read_raw(raw_fname, preload=True, verbose=False) mne.datasets.eegbci.standardize(raw) # set channel names @@ -89,7 +89,7 @@ bridged_idx, ed_matrix = ed_data[6] fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4)) -fig.suptitle('Subject 6 Electrical Distance Matrix') +fig.suptitle("Subject 6 Electrical Distance Matrix") # take median across epochs, only use upper triangular, lower is NaNs ed_plot = np.zeros(ed_matrix.shape[1:]) * np.nan @@ -98,17 +98,17 @@ ed_plot[idx0, idx1] = np.nanmedian(ed_matrix[:, idx0, idx1]) # plot full distribution color range -im1 = ax1.imshow(ed_plot, aspect='auto') +im1 = ax1.imshow(ed_plot, aspect="auto") cax1 = fig.colorbar(im1, ax=ax1) -cax1.set_label(r'Electrical Distance ($\mu$$V^2$)') +cax1.set_label(r"Electrical Distance ($\mu$$V^2$)") # plot zoomed in colors -im2 = ax2.imshow(ed_plot, aspect='auto', vmax=5) +im2 = ax2.imshow(ed_plot, aspect="auto", vmax=5) cax2 = fig.colorbar(im2, ax=ax2) -cax2.set_label(r'Electrical Distance ($\mu$$V^2$)') +cax2.set_label(r"Electrical Distance ($\mu$$V^2$)") for ax in (ax1, ax2): - ax.set_xlabel('Channel Index') - ax.set_ylabel('Channel Index') + ax.set_xlabel("Channel Index") + ax.set_ylabel("Channel Index") fig.tight_layout() @@ -125,10 +125,10 @@ # without bridged electrodes do not have a peak near zero. fig, ax = plt.subplots(figsize=(5, 5)) -fig.suptitle('Subject 6 Electrical Distance Matrix Distribution') +fig.suptitle("Subject 6 Electrical Distance Matrix Distribution") ax.hist(ed_matrix[~np.isnan(ed_matrix)], bins=np.linspace(0, 500, 51)) -ax.set_xlabel(r'Electrical Distance ($\mu$$V^2$)') -ax.set_ylabel('Count (channel pairs for all epochs)') +ax.set_xlabel(r"Electrical Distance ($\mu$$V^2$)") +ax.set_ylabel("Count (channel pairs for all epochs)") # %% # Plot Electrical Distances on a Topomap @@ -145,8 +145,12 @@ # may have inserted the gel syringe tip in too far). mne.viz.plot_bridged_electrodes( - raw_data[6].info, bridged_idx, ed_matrix, - title='Subject 6 Bridged Electrodes', topomap_args=dict(vlim=(None, 5))) + raw_data[6].info, + bridged_idx, + ed_matrix, + title="Subject 6 Bridged Electrodes", + topomap_args=dict(vlim=(None, 5)), +) # %% # Plot the Raw Voltage Time Series for Bridged Electrodes @@ -160,18 +164,30 @@ # pairs, meaning that it is unlikely that all four of these electrodes are # bridged. 
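
# The add_channels() rewrites below show black's "magic trailing comma": a
# call or literal that cannot fit in the default 88 columns is split one
# element per line and given a trailing comma, and a trailing comma already
# present in the source forces the exploded layout on every later run.
# A minimal sketch with hypothetical values:
bipolar_pairs = [
    ("F2", "F4"),
    ("FC2", "FC4"),  # this trailing comma keeps the list exploded
]
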
-raw = raw_data[6].copy().pick_channels(['FC2', 'FC4', 'F2', 'F4']) -raw.add_channels([mne.io.RawArray( - raw.get_data(ch1) - raw.get_data(ch2), - mne.create_info([f'{ch1}-{ch2}'], raw.info['sfreq'], 'eeg'), - raw.first_samp) for ch1, ch2 in [('F2', 'F4'), ('FC2', 'FC4')]]) +raw = raw_data[6].copy().pick_channels(["FC2", "FC4", "F2", "F4"]) +raw.add_channels( + [ + mne.io.RawArray( + raw.get_data(ch1) - raw.get_data(ch2), + mne.create_info([f"{ch1}-{ch2}"], raw.info["sfreq"], "eeg"), + raw.first_samp, + ) + for ch1, ch2 in [("F2", "F4"), ("FC2", "FC4")] + ] +) raw.plot(duration=20, scalings=dict(eeg=2e-4)) -raw = raw_data[1].copy().pick_channels(['FC2', 'FC4', 'F2', 'F4']) -raw.add_channels([mne.io.RawArray( - raw.get_data(ch1) - raw.get_data(ch2), - mne.create_info([f'{ch1}-{ch2}'], raw.info['sfreq'], 'eeg'), - raw.first_samp) for ch1, ch2 in [('F2', 'F4'), ('FC2', 'FC4')]]) +raw = raw_data[1].copy().pick_channels(["FC2", "FC4", "F2", "F4"]) +raw.add_channels( + [ + mne.io.RawArray( + raw.get_data(ch1) - raw.get_data(ch2), + mne.create_info([f"{ch1}-{ch2}"], raw.info["sfreq"], "eeg"), + raw.first_samp, + ) + for ch1, ch2 in [("F2", "F4"), ("FC2", "FC4")] + ] +) raw.plot(duration=20, scalings=dict(eeg=2e-4)) # %% @@ -193,23 +209,25 @@ # distance from the sensors to the brain). fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4)) -fig.suptitle('Electrical Distance Distribution for EEGBCI Subjects') +fig.suptitle("Electrical Distance Distribution for EEGBCI Subjects") for ax in (ax1, ax2): - ax.set_ylabel('Count') - ax.set_xlabel(r'Electrical Distance ($\mu$$V^2$)') + ax.set_ylabel("Count") + ax.set_xlabel(r"Electrical Distance ($\mu$$V^2$)") for sub, (bridged_idx, ed_matrix) in ed_data.items(): # ed_matrix is upper triangular so exclude bottom half of NaNs - hist, edges = np.histogram(ed_matrix[~np.isnan(ed_matrix)].flatten(), - bins=np.linspace(0, 1000, 101)) + hist, edges = np.histogram( + ed_matrix[~np.isnan(ed_matrix)].flatten(), bins=np.linspace(0, 1000, 101) + ) centers = (edges[1:] + edges[:-1]) / 2 ax1.plot(centers, hist) - hist, edges = np.histogram(ed_matrix[~np.isnan(ed_matrix)].flatten(), - bins=np.linspace(0, 30, 21)) + hist, edges = np.histogram( + ed_matrix[~np.isnan(ed_matrix)].flatten(), bins=np.linspace(0, 30, 21) + ) centers = (edges[1:] + edges[:-1]) / 2 - ax2.plot(centers, hist, label=f'Sub {sub} #={len(bridged_idx)}') + ax2.plot(centers, hist, label=f"Sub {sub} #={len(bridged_idx)}") -ax1.axvspan(0, 30, color='r', alpha=0.5) +ax1.axvspan(0, 30, color="r", alpha=0.5) ax2.legend(loc=(1.04, 0)) fig.subplots_adjust(right=0.725, bottom=0.15, wspace=0.4) @@ -223,9 +241,12 @@ for sub, (bridged_idx, ed_matrix) in ed_data.items(): mne.viz.plot_bridged_electrodes( - raw_data[sub].info, bridged_idx, ed_matrix, - title=f'Subject {sub} Bridged Electrodes', - topomap_args=dict(vlim=(None, 5))) + raw_data[sub].info, + bridged_idx, + ed_matrix, + title=f"Subject {sub} Bridged Electrodes", + topomap_args=dict(vlim=(None, 5)), + ) # %% # For subjects with many bridged channels like Subject 6 shown in the example @@ -242,7 +263,8 @@ # use subject 2, only one bridged electrode pair bridged_idx = ed_data[2][0] raw = mne.preprocessing.interpolate_bridged_electrodes( - raw_data[2].copy(), bridged_idx=bridged_idx) + raw_data[2].copy(), bridged_idx=bridged_idx +) # %% # Let's make sure that our virtual channel aided the interpolation. 
We can do @@ -274,41 +296,73 @@ bridged_data[0] += 1e-7 * rng.normal(size=raw.times.size) bridged_data[1] += 1e-7 * rng.normal(size=raw.times.size) # add back simulated data -raw_sim = raw_sim.add_channels([mne.io.RawArray( - bridged_data, mne.create_info([ch0, ch1], raw.info['sfreq'], 'eeg'), - raw.first_samp)]) +raw_sim = raw_sim.add_channels( + [ + mne.io.RawArray( + bridged_data, + mne.create_info([ch0, ch1], raw.info["sfreq"], "eeg"), + raw.first_samp, + ) + ] +) raw_sim.set_montage(montage) # add back channel positions # use virtual channel method raw_virtual = mne.preprocessing.interpolate_bridged_electrodes( - raw_sim.copy(), bridged_idx=bridged_idx_simulated) + raw_sim.copy(), bridged_idx=bridged_idx_simulated +) data_virtual = raw_virtual.get_data(picks=(idx0, idx1)) # set bads to be bridged electrodes to interpolate without a virtual channel raw_comp = raw_sim.copy() -raw_comp.info['bads'] = [raw_sim.ch_names[idx0], raw_sim.ch_names[idx1]] +raw_comp.info["bads"] = [raw_sim.ch_names[idx0], raw_sim.ch_names[idx1]] raw_comp.interpolate_bads() data_comp = raw_comp.get_data(picks=(idx0, idx1)) # compute variance of residuals -print('Variance of residual (interpolated data - original data)\n\n' - 'With adding virtual channel: {}\n' - 'Compared to interpolation only using other channels: {}' - ''.format(np.mean(np.var(data_virtual - data_orig, axis=1)), - np.mean(np.var(data_comp - data_orig, axis=1)))) +print( + "Variance of residual (interpolated data - original data)\n\n" + "With adding virtual channel: {}\n" + "Compared to interpolation only using other channels: {}" + "".format( + np.mean(np.var(data_virtual - data_orig, axis=1)), + np.mean(np.var(data_comp - data_orig, axis=1)), + ) +) # plot results raw = raw.pick_channels([ch0, ch1]) -raw = raw.add_channels([mne.io.RawArray( - np.concatenate([data_virtual, data_virtual - data_orig]), - mne.create_info([f'{ch0} virtual', f'{ch1} virtual', - f'{ch0} virtual diff', f'{ch1} virtual diff'], - raw.info['sfreq'], 'eeg'), raw.first_samp)]) -raw = raw.add_channels([mne.io.RawArray( - np.concatenate([data_comp, data_comp - data_orig]), - mne.create_info([f'{ch0} comp', f'{ch1} comp', - f'{ch0} comp diff', f'{ch1} comp diff'], - raw.info['sfreq'], 'eeg'), raw.first_samp)]) +raw = raw.add_channels( + [ + mne.io.RawArray( + np.concatenate([data_virtual, data_virtual - data_orig]), + mne.create_info( + [ + f"{ch0} virtual", + f"{ch1} virtual", + f"{ch0} virtual diff", + f"{ch1} virtual diff", + ], + raw.info["sfreq"], + "eeg", + ), + raw.first_samp, + ) + ] +) +raw = raw.add_channels( + [ + mne.io.RawArray( + np.concatenate([data_comp, data_comp - data_orig]), + mne.create_info( + [f"{ch0} comp", f"{ch1} comp", f"{ch0} comp diff", f"{ch1} comp diff"], + raw.info["sfreq"], + "eeg", + ), + raw.first_samp, + ) + ] +) raw.plot(scalings=dict(eeg=7e-5)) # %% @@ -332,17 +386,26 @@ raw = raw_data[1] # typically impedances < 25 kOhm are acceptable for active systems and # impedances < 5 kOhm are desirable for a passive system -impedances = rng.random((len(raw.ch_names,))) * 30 +impedances = ( + rng.random( + ( + len( + raw.ch_names, + ) + ) + ) + * 30 +) impedances[10] = 80 # set a few bad impendances impedances[25] = 99 -cmap = LinearSegmentedColormap.from_list(name='impedance_cmap', - colors=['g', 'y', 'r'], N=256) +cmap = LinearSegmentedColormap.from_list( + name="impedance_cmap", colors=["g", "y", "r"], N=256 +) fig, ax = plt.subplots(figsize=(5, 5)) -im, cn = mne.viz.plot_topomap(impedances, raw.info, axes=ax, - cmap=cmap, vlim=(25, 75)) 
-ax.set_title('Electrode Impendances') +im, cn = mne.viz.plot_topomap(impedances, raw.info, axes=ax, cmap=cmap, vlim=(25, 75)) +ax.set_title("Electrode Impendances") cax = fig.colorbar(im, ax=ax) -cax.set_label(r'Impedance (k$\Omega$)') +cax.set_label(r"Impedance (k$\Omega$)") # %% # Summary diff --git a/examples/preprocessing/eeg_csd.py b/examples/preprocessing/eeg_csd.py index 24f33b91e53..7bb19415eaa 100644 --- a/examples/preprocessing/eeg_csd.py +++ b/examples/preprocessing/eeg_csd.py @@ -32,10 +32,11 @@ # %% # Load sample subject data -meg_path = data_path / 'MEG' / 'sample' -raw = mne.io.read_raw_fif(meg_path / 'sample_audvis_raw.fif') -raw = raw.pick_types(meg=False, eeg=True, eog=True, ecg=True, stim=True, - exclude=raw.info['bads']).load_data() +meg_path = data_path / "MEG" / "sample" +raw = mne.io.read_raw_fif(meg_path / "sample_audvis_raw.fif") +raw = raw.pick_types( + meg=False, eeg=True, eog=True, ecg=True, stim=True, exclude=raw.info["bads"] +).load_data() events = mne.find_events(raw) raw.set_eeg_reference(projection=True).apply_proj() @@ -56,19 +57,24 @@ # CSD can also be computed on Evoked (averaged) data. # Here we epoch and average the data so we can demonstrate that. -event_id = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3, - 'visual/right': 4, 'smiley': 5, 'button': 32} -epochs = mne.Epochs(raw, events, event_id=event_id, tmin=-0.2, tmax=.5, - preload=True) -evoked = epochs['auditory'].average() +event_id = { + "auditory/left": 1, + "auditory/right": 2, + "visual/left": 3, + "visual/right": 4, + "smiley": 5, + "button": 32, +} +epochs = mne.Epochs(raw, events, event_id=event_id, tmin=-0.2, tmax=0.5, preload=True) +evoked = epochs["auditory"].average() # %% # First let's look at how CSD affects scalp topography: -times = np.array([-0.1, 0., 0.05, 0.1, 0.15]) +times = np.array([-0.1, 0.0, 0.05, 0.1, 0.15]) evoked_csd = mne.preprocessing.compute_current_source_density(evoked) -evoked.plot_joint(title='Average Reference', show=False) -evoked_csd.plot_joint(title='Current Source Density') +evoked.plot_joint(title="Average Reference", show=False) +evoked_csd.plot_joint(title="Current Source Density") # %% # CSD has parameters ``stiffness`` and ``lambda2`` affecting smoothing and @@ -80,11 +86,12 @@ for i, lambda2 in enumerate([0, 1e-7, 1e-5, 1e-3]): for j, m in enumerate([5, 4, 3, 2]): this_evoked_csd = mne.preprocessing.compute_current_source_density( - evoked, stiffness=m, lambda2=lambda2) + evoked, stiffness=m, lambda2=lambda2 + ) this_evoked_csd.plot_topomap( - 0.1, axes=ax[i, j], contours=4, time_unit='s', - colorbar=False, show=False) - ax[i, j].set_title('stiffness=%i\nλ²=%s' % (m, lambda2)) + 0.1, axes=ax[i, j], contours=4, time_unit="s", colorbar=False, show=False + ) + ax[i, j].set_title("stiffness=%i\nλ²=%s" % (m, lambda2)) # %% # References diff --git a/examples/preprocessing/eog_artifact_histogram.py b/examples/preprocessing/eog_artifact_histogram.py index a6a3e895b3c..6953e8a8ed3 100644 --- a/examples/preprocessing/eog_artifact_histogram.py +++ b/examples/preprocessing/eog_artifact_histogram.py @@ -28,24 +28,24 @@ # %% # Set parameters -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" # Setup for reading the raw data raw = io.read_raw_fif(raw_fname, preload=True) -events = mne.find_events(raw, 'STI 014') +events = mne.find_events(raw, "STI 014") eog_event_id = 512 eog_events = 
mne.preprocessing.find_eog_events(raw, eog_event_id) -raw.add_events(eog_events, 'STI 014') +raw.add_events(eog_events, "STI 014") # Read epochs picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=True, eog=False) tmin, tmax = -0.2, 0.5 -event_ids = {'AudL': 1, 'AudR': 2, 'VisL': 3, 'VisR': 4} +event_ids = {"AudL": 1, "AudR": 2, "VisL": 3, "VisR": 4} epochs = mne.Epochs(raw, events, event_ids, tmin, tmax, picks=picks) # Get the stim channel data -pick_ch = mne.pick_channels(epochs.ch_names, ['STI 014'])[0] +pick_ch = mne.pick_channels(epochs.ch_names, ["STI 014"])[0] data = epochs.get_data()[:, pick_ch, :] data = np.sum((data.astype(int) & eog_event_id) == eog_event_id, axis=0) @@ -53,6 +53,5 @@ # Plot EOG artifact distribution fig, ax = plt.subplots() ax.stem(1e3 * epochs.times, data) -ax.set(xlabel='Times (ms)', - ylabel='Blink counts (from %s trials)' % len(epochs)) +ax.set(xlabel="Times (ms)", ylabel="Blink counts (from %s trials)" % len(epochs)) fig.tight_layout() diff --git a/examples/preprocessing/eog_regression.py b/examples/preprocessing/eog_regression.py index 1d7f6879b9a..6c88cb01d9a 100644 --- a/examples/preprocessing/eog_regression.py +++ b/examples/preprocessing/eog_regression.py @@ -30,14 +30,14 @@ print(__doc__) data_path = sample.data_path() -raw_fname = data_path / 'MEG' / 'sample' / 'sample_audvis_filt-0-40_raw.fif' +raw_fname = data_path / "MEG" / "sample" / "sample_audvis_filt-0-40_raw.fif" # Read raw data raw = mne.io.read_raw_fif(raw_fname, preload=True) -events = mne.find_events(raw, 'STI 014') +events = mne.find_events(raw, "STI 014") # Highpass filter to eliminate slow drifts -raw.filter(0.3, None, picks='all') +raw.filter(0.3, None, picks="all") # %% # Perform regression and remove EOG @@ -57,21 +57,22 @@ # is best visualized by extracting epochs and plotting the evoked potential. 
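
# On the wrapping pattern in this hunk: black tries the shortest split first,
# moving all arguments to a single indented line inside the parentheses, and
# only falls back to one-argument-per-line (with a trailing comma) when even
# that does not fit in 88 columns. A minimal sketch of the two layouts:
stats = dict(mean=0.0, std=1.0)  # fits in 88 columns: stays on one line
stats = dict(
    mean=0.0, std=1.0  # sketch of the single indented-line split tried first
)
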
tmin, tmax = -0.1, 0.5 -event_id = {'visual/left': 3, 'visual/right': 4} -evoked_before = mne.Epochs(raw, events, event_id, tmin, tmax, - baseline=(tmin, 0)).average() -evoked_after = mne.Epochs(raw_clean, events, event_id, tmin, tmax, - baseline=(tmin, 0)).average() +event_id = {"visual/left": 3, "visual/right": 4} +evoked_before = mne.Epochs( + raw, events, event_id, tmin, tmax, baseline=(tmin, 0) +).average() +evoked_after = mne.Epochs( + raw_clean, events, event_id, tmin, tmax, baseline=(tmin, 0) +).average() # Create epochs after EOG correction -epochs_after = mne.Epochs(raw_clean, events, event_id, tmin, tmax, - baseline=(tmin, 0)) +epochs_after = mne.Epochs(raw_clean, events, event_id, tmin, tmax, baseline=(tmin, 0)) evoked_after = epochs_after.average() -fig, ax = plt.subplots(nrows=3, ncols=2, figsize=(10, 7), - sharex=True, sharey='row') +fig, ax = plt.subplots(nrows=3, ncols=2, figsize=(10, 7), sharex=True, sharey="row") evoked_before.plot(axes=ax[:, 0], spatial_colors=True) evoked_after.plot(axes=ax[:, 1], spatial_colors=True) -fig.subplots_adjust(top=0.905, bottom=0.09, left=0.08, right=0.975, - hspace=0.325, wspace=0.145) -fig.suptitle('Before --> After') +fig.subplots_adjust( + top=0.905, bottom=0.09, left=0.08, right=0.975, hspace=0.325, wspace=0.145 +) +fig.suptitle("Before --> After") diff --git a/examples/preprocessing/find_ref_artifacts.py b/examples/preprocessing/find_ref_artifacts.py index f3781a0c1cc..8a08658a174 100644 --- a/examples/preprocessing/find_ref_artifacts.py +++ b/examples/preprocessing/find_ref_artifacts.py @@ -45,7 +45,7 @@ # %% # Read raw data, cropping to 5 minutes to save memory -raw_fname = data_path / 'sample_reference_MEG_noise-raw.fif' +raw_fname = data_path / "sample_reference_MEG_noise-raw.fif" raw = io.read_raw_fif(raw_fname).crop(300, 600).load_data() # %% @@ -53,11 +53,17 @@ # been applied to these data, much of the noise in the reference channels # (bottom of the plot) can still be seen in the standard channels. select_picks = np.concatenate( - (mne.pick_types(raw.info, meg=True)[-32:], - mne.pick_types(raw.info, meg=False, ref_meg=True))) + ( + mne.pick_types(raw.info, meg=True)[-32:], + mne.pick_types(raw.info, meg=False, ref_meg=True), + ) +) plot_kwargs = dict( - duration=100, order=select_picks, n_channels=len(select_picks), - scalings={"mag": 8e-13, "ref_meg": 2e-11}) + duration=100, + order=select_picks, + n_channels=len(select_picks), + scalings={"mag": 8e-13, "ref_meg": 2e-11}, +) raw.plot(**plot_kwargs) # %% @@ -68,12 +74,11 @@ # Run the "together" algorithm. raw_tog = raw.copy() ica_kwargs = dict( - method='picard', + method="picard", fit_params=dict(tol=1e-4), # use a high tol here for speed ) all_picks = mne.pick_types(raw_tog.info, meg=True, ref_meg=True) -ica_tog = ICA(n_components=60, max_iter='auto', allow_ref_meg=True, - **ica_kwargs) +ica_tog = ICA(n_components=60, max_iter="auto", allow_ref_meg=True, **ica_kwargs) ica_tog.fit(raw_tog, picks=all_picks) # low threshold (2.0) here because of cropped data, entire recording can use # a higher threshold (2.5) @@ -100,8 +105,7 @@ # Do ICA only on the reference channels. ref_picks = mne.pick_types(raw_sep.info, meg=False, ref_meg=True) -ica_ref = ICA(n_components=2, max_iter='auto', allow_ref_meg=True, - **ica_kwargs) +ica_ref = ICA(n_components=2, max_iter="auto", allow_ref_meg=True, **ica_kwargs) ica_ref.fit(raw_sep, picks=ref_picks) # Do ICA on both reference and standard channels. 
Here, we can just reuse diff --git a/examples/preprocessing/fnirs_artifact_removal.py b/examples/preprocessing/fnirs_artifact_removal.py index b7236b76636..d669d6ce09c 100644 --- a/examples/preprocessing/fnirs_artifact_removal.py +++ b/examples/preprocessing/fnirs_artifact_removal.py @@ -18,8 +18,10 @@ import os import mne -from mne.preprocessing.nirs import (optical_density, - temporal_derivative_distribution_repair) +from mne.preprocessing.nirs import ( + optical_density, + temporal_derivative_distribution_repair, +) # %% # Import data @@ -31,12 +33,13 @@ # and plot these signals. fnirs_data_folder = mne.datasets.fnirs_motor.data_path() -fnirs_cw_amplitude_dir = os.path.join(fnirs_data_folder, 'Participant-1') +fnirs_cw_amplitude_dir = os.path.join(fnirs_data_folder, "Participant-1") raw_intensity = mne.io.read_raw_nirx(fnirs_cw_amplitude_dir, verbose=True) raw_intensity.load_data().resample(3, npad="auto") raw_od = optical_density(raw_intensity) -new_annotations = mne.Annotations([31, 187, 317], [8, 8, 8], - ["Movement", "Movement", "Movement"]) +new_annotations = mne.Annotations( + [31, 187, 317], [8, 8, 8], ["Movement", "Movement", "Movement"] +) raw_od.set_annotations(new_annotations) raw_od.plot(n_channels=15, duration=400, show_scrollbars=False) @@ -61,10 +64,10 @@ corrupted_data = raw_od.get_data() corrupted_data[:, 298:302] = corrupted_data[:, 298:302] - 0.06 corrupted_data[:, 450:750] = corrupted_data[:, 450:750] + 0.03 -corrupted_od = mne.io.RawArray(corrupted_data, raw_od.info, - first_samp=raw_od.first_samp) -new_annotations.append([95, 145, 245], [10, 10, 10], - ["Spike", "Baseline", "Baseline"]) +corrupted_od = mne.io.RawArray( + corrupted_data, raw_od.info, first_samp=raw_od.first_samp +) +new_annotations.append([95, 145, 245], [10, 10, 10], ["Spike", "Baseline", "Baseline"]) corrupted_od.set_annotations(new_annotations) corrupted_od.plot(n_channels=15, duration=400, show_scrollbars=False) diff --git a/examples/preprocessing/ica_comparison.py b/examples/preprocessing/ica_comparison.py index 7c4a8aa733c..6aa601dd5fa 100644 --- a/examples/preprocessing/ica_comparison.py +++ b/examples/preprocessing/ica_comparison.py @@ -31,13 +31,13 @@ # - 1-30 Hz band-pass filter data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" -raw = mne.io.read_raw_fif(raw_fname).crop(0, 60).pick('meg').load_data() +raw = mne.io.read_raw_fif(raw_fname).crop(0, 60).pick("meg").load_data() reject = dict(mag=5e-12, grad=4000e-13) -raw.filter(1, 30, fir_design='firwin') +raw.filter(1, 30, fir_design="firwin") # %% @@ -45,27 +45,32 @@ def run_ica(method, fit_params=None): - ica = ICA(n_components=20, method=method, fit_params=fit_params, - max_iter='auto', random_state=0) + ica = ICA( + n_components=20, + method=method, + fit_params=fit_params, + max_iter="auto", + random_state=0, + ) t0 = time() ica.fit(raw, reject=reject) fit_time = time() - t0 - title = ('ICA decomposition using %s (took %.1fs)' % (method, fit_time)) + title = "ICA decomposition using %s (took %.1fs)" % (method, fit_time) ica.plot_components(title=title) # %% # FastICA -run_ica('fastica') +run_ica("fastica") # %% # Picard -run_ica('picard') +run_ica("picard") # %% # Infomax -run_ica('infomax') +run_ica("infomax") # %% # Extended Infomax -run_ica('infomax', fit_params=dict(extended=True)) +run_ica("infomax", fit_params=dict(extended=True)) diff --git 
a/examples/preprocessing/interpolate_bad_channels.py b/examples/preprocessing/interpolate_bad_channels.py index 635dffcbfba..7040e24299e 100644 --- a/examples/preprocessing/interpolate_bad_channels.py +++ b/examples/preprocessing/interpolate_bad_channels.py @@ -28,24 +28,24 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname = meg_path / 'sample_audvis-ave.fif' -evoked = mne.read_evokeds(fname, condition='Left Auditory', - baseline=(None, 0)) +meg_path = data_path / "MEG" / "sample" +fname = meg_path / "sample_audvis-ave.fif" +evoked = mne.read_evokeds(fname, condition="Left Auditory", baseline=(None, 0)) # plot with bads -evoked.plot(exclude=[], picks=('grad', 'eeg')) +evoked.plot(exclude=[], picks=("grad", "eeg")) # %% # Compute interpolation (also works with Raw and Epochs objects) evoked_interp = evoked.copy().interpolate_bads(reset_bads=False) -evoked_interp.plot(exclude=[], picks=('grad', 'eeg')) +evoked_interp.plot(exclude=[], picks=("grad", "eeg")) # %% # You can also use minimum-norm for EEG as well as MEG evoked_interp_mne = evoked.copy().interpolate_bads( - reset_bads=False, method=dict(eeg='MNE'), verbose=True) -evoked_interp_mne.plot(exclude=[], picks=('grad', 'eeg')) + reset_bads=False, method=dict(eeg="MNE"), verbose=True +) +evoked_interp_mne.plot(exclude=[], picks=("grad", "eeg")) # %% # References diff --git a/examples/preprocessing/movement_compensation.py b/examples/preprocessing/movement_compensation.py index 3a31648c4a5..97d183533a8 100644 --- a/examples/preprocessing/movement_compensation.py +++ b/examples/preprocessing/movement_compensation.py @@ -24,11 +24,11 @@ print(__doc__) -data_path = mne.datasets.misc.data_path(verbose=True) / 'movement' +data_path = mne.datasets.misc.data_path(verbose=True) / "movement" -head_pos = mne.chpi.read_head_pos(data_path / 'simulated_quats.pos') -raw = mne.io.read_raw_fif(data_path / 'simulated_movement_raw.fif') -raw_stat = mne.io.read_raw_fif(data_path / 'simulated_stationary_raw.fif') +head_pos = mne.chpi.read_head_pos(data_path / "simulated_quats.pos") +raw = mne.io.read_raw_fif(data_path / "simulated_movement_raw.fif") +raw_stat = mne.io.read_raw_fif(data_path / "simulated_stationary_raw.fif") # %% # Visualize the "subject" head movements. By providing the measurement @@ -37,29 +37,31 @@ # be shown in blue, and the destination (if given) shown in red. mne.viz.plot_head_positions( - head_pos, mode='traces', destination=raw.info['dev_head_t'], info=raw.info) + head_pos, mode="traces", destination=raw.info["dev_head_t"], info=raw.info +) # %% # This can also be visualized using a quiver. mne.viz.plot_head_positions( - head_pos, mode='field', destination=raw.info['dev_head_t'], info=raw.info) + head_pos, mode="field", destination=raw.info["dev_head_t"], info=raw.info +) # %% # Process our simulated raw data (taking into account head movements). # extract our resulting events -events = mne.find_events(raw, stim_channel='STI 014') +events = mne.find_events(raw, stim_channel="STI 014") events[:, 2] = 1 raw.plot(events=events) -topo_kwargs = dict(times=[0, 0.1, 0.2], ch_type='mag', vlim=(-500, 500)) +topo_kwargs = dict(times=[0, 0.1, 0.2], ch_type="mag", vlim=(-500, 500)) # %% # First, take the average of stationary data (bilateral auditory patterns). 
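
# Worth noting in the hunks above: black re-wraps only lines that exceed 88
# columns. topo_kwargs, for example, merely has its quotes normalized and
# stays on one line because it already fits. A minimal sketch:
topo_sketch = dict(times=[0, 0.1, 0.2], ch_type="mag")  # fits, left intact
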
evoked_stat = mne.Epochs(raw_stat, events, 1, -0.2, 0.8).average() fig = evoked_stat.plot_topomap(**topo_kwargs) -fig.suptitle('Stationary') +fig.suptitle("Stationary") # %% # Second, take a naive average, which averages across epochs that have been @@ -68,18 +70,18 @@ epochs = mne.Epochs(raw, events, 1, -0.2, 0.8) evoked = epochs.average() fig = evoked.plot_topomap(**topo_kwargs) -fig.suptitle('Moving: naive average') +fig.suptitle("Moving: naive average") # %% # Third, use raw movement compensation (restores pattern). raw_sss = maxwell_filter(raw, head_pos=head_pos) evoked_raw_mc = mne.Epochs(raw_sss, events, 1, -0.2, 0.8).average() fig = evoked_raw_mc.plot_topomap(**topo_kwargs) -fig.suptitle('Moving: movement compensated (raw)') +fig.suptitle("Moving: movement compensated (raw)") # %% # Fourth, use evoked movement compensation. For these data, which contain # very large rotations, it does not as cleanly restore the pattern. evoked_evo_mc = mne.epochs.average_movements(epochs, head_pos=head_pos) fig = evoked_evo_mc.plot_topomap(**topo_kwargs) -fig.suptitle('Moving: movement compensated (evoked)') +fig.suptitle("Moving: movement compensated (evoked)") diff --git a/examples/preprocessing/movement_detection.py b/examples/preprocessing/movement_detection.py index ac90f45f587..2984c53fea2 100644 --- a/examples/preprocessing/movement_detection.py +++ b/examples/preprocessing/movement_detection.py @@ -29,16 +29,17 @@ # Load data data_path = bst_auditory.data_path() -data_path_MEG = data_path / 'MEG' -subject = 'bst_auditory' -subjects_dir = data_path / 'subjects' -trans_fname = data_path / 'MEG' / 'bst_auditory' / 'bst_auditory-trans.fif' -raw_fname1 = data_path_MEG / 'bst_auditory' / 'S01_AEF_20131218_01.ds' -raw_fname2 = data_path_MEG / 'bst_auditory' / 'S01_AEF_20131218_02.ds' +data_path_MEG = data_path / "MEG" +subject = "bst_auditory" +subjects_dir = data_path / "subjects" +trans_fname = data_path / "MEG" / "bst_auditory" / "bst_auditory-trans.fif" +raw_fname1 = data_path_MEG / "bst_auditory" / "S01_AEF_20131218_01.ds" +raw_fname2 = data_path_MEG / "bst_auditory" / "S01_AEF_20131218_02.ds" # read and concatenate two files, ignoring device<->head mismatch raw = read_raw_ctf(raw_fname1, preload=False) mne.io.concatenate_raws( - [raw, read_raw_ctf(raw_fname2, preload=False)], on_mismatch='ignore') + [raw, read_raw_ctf(raw_fname2, preload=False)], on_mismatch="ignore" +) raw.crop(350, 410).load_data() raw.resample(100, npad="auto") @@ -49,15 +50,18 @@ # Get cHPI time series and compute average chpi_locs = mne.chpi.extract_chpi_locs_ctf(raw) head_pos = mne.chpi.compute_head_pos(raw.info, chpi_locs) -original_head_dev_t = mne.transforms.invert_transform( - raw.info['dev_head_t']) +original_head_dev_t = mne.transforms.invert_transform(raw.info["dev_head_t"]) average_head_dev_t = mne.transforms.invert_transform( - compute_average_dev_head_t(raw, head_pos)) + compute_average_dev_head_t(raw, head_pos) +) fig = mne.viz.plot_head_positions(head_pos) -for ax, val, val_ori in zip(fig.axes[::2], average_head_dev_t['trans'][:3, 3], - original_head_dev_t['trans'][:3, 3]): - ax.axhline(1000 * val, color='r') - ax.axhline(1000 * val_ori, color='g') +for ax, val, val_ori in zip( + fig.axes[::2], + average_head_dev_t["trans"][:3, 3], + original_head_dev_t["trans"][:3, 3], +): + ax.axhline(1000 * val, color="r") + ax.axhline(1000 * val_ori, color="g") # The green horizontal lines represent the original head position, whereas the # red lines are the new head position averaged over all the time points. 
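
# The next hunk shows black's numeric-literal normalization: a float written
# with a bare leading or trailing dot gains an explicit zero, changing only
# the spelling, never the value. For example:
assert 0.0015 == .0015  # ".0015" becomes "0.0015"
assert 40.0 == 40.      # "40." becomes "40.0", as in alpha = 40.0 earlier
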
@@ -66,9 +70,10 @@ # Plot raw data with annotated movement # ------------------------------------------------------------------ -mean_distance_limit = .0015 # in meters +mean_distance_limit = 0.0015 # in meters annotation_movement, hpi_disp = annotate_movement( - raw, head_pos, mean_distance_limit=mean_distance_limit) + raw, head_pos, mean_distance_limit=mean_distance_limit +) raw.set_annotations(annotation_movement) raw.plot(n_channels=100, duration=20) @@ -76,7 +81,12 @@ # After checking the annotated movement artifacts, calculate the new transform # and plot it: new_dev_head_t = compute_average_dev_head_t(raw, head_pos) -raw.info['dev_head_t'] = new_dev_head_t -fig = mne.viz.plot_alignment(raw.info, show_axes=True, subject=subject, - trans=trans_fname, subjects_dir=subjects_dir) +raw.info["dev_head_t"] = new_dev_head_t +fig = mne.viz.plot_alignment( + raw.info, + show_axes=True, + subject=subject, + trans=trans_fname, + subjects_dir=subjects_dir, +) mne.viz.set_3d_view(fig, azimuth=90, elevation=60) diff --git a/examples/preprocessing/muscle_detection.py b/examples/preprocessing/muscle_detection.py index 223f93743d9..37bd021d853 100644 --- a/examples/preprocessing/muscle_detection.py +++ b/examples/preprocessing/muscle_detection.py @@ -39,7 +39,7 @@ # Load data data_path = bst_auditory.data_path() -raw_fname = data_path / 'MEG' / 'bst_auditory' / 'S01_AEF_20131218_01.ds' +raw_fname = data_path / "MEG" / "bst_auditory" / "S01_AEF_20131218_01.ds" raw = read_raw_ctf(raw_fname, preload=False) @@ -64,8 +64,12 @@ # Choose one channel type, if there are axial gradiometers and magnetometers, # select magnetometers as they are more sensitive to muscle activity. annot_muscle, scores_muscle = annotate_muscle_zscore( - raw, ch_type="mag", threshold=threshold_muscle, min_length_good=0.2, - filter_freq=[110, 140]) + raw, + ch_type="mag", + threshold=threshold_muscle, + min_length_good=0.2, + filter_freq=[110, 140], +) # %% # Plot muscle z-scores across recording @@ -73,8 +77,8 @@ fig, ax = plt.subplots() ax.plot(raw.times, scores_muscle) -ax.axhline(y=threshold_muscle, color='r') -ax.set(xlabel='time, (s)', ylabel='zscore', title='Muscle activity') +ax.axhline(y=threshold_muscle, color="r") +ax.set(xlabel="time, (s)", ylabel="zscore", title="Muscle activity") # %% # View the annotations # -------------------------------------------------------------------------- diff --git a/examples/preprocessing/muscle_ica.py b/examples/preprocessing/muscle_ica.py index 8abc96f5d6a..8f50615c66a 100644 --- a/examples/preprocessing/muscle_ica.py +++ b/examples/preprocessing/muscle_ica.py @@ -22,7 +22,7 @@ import mne data_path = mne.datasets.sample.data_path() -raw_fname = data_path / 'MEG' / 'sample' / 'sample_audvis_raw.fif' +raw_fname = data_path / "MEG" / "sample" / "sample_audvis_raw.fif" raw = mne.io.read_raw_fif(raw_fname) raw.crop(tmin=100, tmax=130) # take 30 seconds for speed @@ -33,12 +33,13 @@ # ICA works best with a highpass filter applied raw.load_data() -raw.filter(l_freq=1., h_freq=None) +raw.filter(l_freq=1.0, h_freq=None) # %% # Run ICA ica = mne.preprocessing.ICA( - n_components=15, method='picard', max_iter='auto', random_state=97) + n_components=15, method="picard", max_iter="auto", random_state=97 +) ica.fit(raw) # %% @@ -85,8 +86,10 @@ # and ensure that it gets the same components we did manually. 
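
# The print() rewrites just below show that black never merges or splits
# string literals: the implicitly concatenated f-string pieces stay separate
# and only the enclosing call is re-wrapped. A minimal sketch:
msg = (
    "adjacent literals "
    "are one string at compile time"
)
assert msg == "adjacent literals are one string at compile time"
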
muscle_idx_auto, scores = ica.find_bads_muscle(raw) ica.plot_scores(scores, exclude=muscle_idx_auto) -print(f'Manually found muscle artifact ICA components: {muscle_idx}\n' - f'Automatically found muscle artifact ICA components: {muscle_idx_auto}') +print( + f"Manually found muscle artifact ICA components: {muscle_idx}\n" + f"Automatically found muscle artifact ICA components: {muscle_idx_auto}" +) # %% # Let's now replicate this on the EEGBCI dataset @@ -94,24 +97,28 @@ for sub in (1, 2): raw = mne.io.read_raw_edf( - mne.datasets.eegbci.load_data(subject=sub, runs=(1,))[0], preload=True) + mne.datasets.eegbci.load_data(subject=sub, runs=(1,))[0], preload=True + ) mne.datasets.eegbci.standardize(raw) # set channel names - montage = mne.channels.make_standard_montage('standard_1005') + montage = mne.channels.make_standard_montage("standard_1005") raw.set_montage(montage) - raw.filter(l_freq=1., h_freq=None) + raw.filter(l_freq=1.0, h_freq=None) # Run ICA ica = mne.preprocessing.ICA( - n_components=15, method='picard', max_iter='auto', random_state=97) + n_components=15, method="picard", max_iter="auto", random_state=97 + ) ica.fit(raw) ica.plot_sources(raw) muscle_idx_auto, scores = ica.find_bads_muscle(raw) ica.plot_properties(raw, picks=muscle_idx_auto, log_scale=True) ica.plot_scores(scores, exclude=muscle_idx_auto) - print(f'Manually found muscle artifact ICA components: {muscle_idx}\n' - 'Automatically found muscle artifact ICA components: ' - f'{muscle_idx_auto}') + print( + f"Manually found muscle artifact ICA components: {muscle_idx}\n" + "Automatically found muscle artifact ICA components: " + f"{muscle_idx_auto}" + ) # %% # References diff --git a/examples/preprocessing/otp.py b/examples/preprocessing/otp.py index 520d66166ac..a05eaf5c6ce 100644 --- a/examples/preprocessing/otp.py +++ b/examples/preprocessing/otp.py @@ -32,17 +32,17 @@ dipole_number = 1 data_path = bst_phantom_elekta.data_path() -raw = read_raw_fif(data_path / 'kojak_all_200nAm_pp_no_chpi_no_ms_raw.fif') -raw.crop(40., 50.).load_data() +raw = read_raw_fif(data_path / "kojak_all_200nAm_pp_no_chpi_no_ms_raw.fif") +raw.crop(40.0, 50.0).load_data() order = list(range(160, 170)) -raw.copy().filter(0., 40.).plot(order=order, n_channels=10) +raw.copy().filter(0.0, 40.0).plot(order=order, n_channels=10) # %% # Now we can clean the data with OTP, lowpass, and plot. The flux jumps have # been suppressed alongside the random sensor noise. raw_clean = mne.preprocessing.oversampled_temporal_projection(raw) -raw_clean.filter(0., 40.) +raw_clean.filter(0.0, 40.0) raw_clean.plot(order=order, n_channels=10) @@ -52,19 +52,26 @@ # for more information. 
Here we use a version that does single-trial
# localization across the 17 trials that are in our 10-second window:
+
 def compute_bias(raw):
-    events = find_events(raw, 'STI201', verbose=False)
+    events = find_events(raw, "STI201", verbose=False)
     events = events[1:]  # first one has an artifact
     tmin, tmax = -0.2, 0.1
-    epochs = mne.Epochs(raw, events, dipole_number, tmin, tmax,
-                        baseline=(None, -0.01), preload=True, verbose=False)
-    sphere = mne.make_sphere_model(r0=(0., 0., 0.), head_radius=None,
-                                   verbose=False)
-    cov = mne.compute_covariance(epochs, tmax=0, method='oas',
-                                 rank=None, verbose=False)
+    epochs = mne.Epochs(
+        raw,
+        events,
+        dipole_number,
+        tmin,
+        tmax,
+        baseline=(None, -0.01),
+        preload=True,
+        verbose=False,
+    )
+    sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.0), head_radius=None, verbose=False)
+    cov = mne.compute_covariance(epochs, tmax=0, method="oas", rank=None, verbose=False)
     idx = epochs.time_as_index(0.036)[0]
     data = epochs.get_data()[:, :, idx].T
-    evoked = mne.EvokedArray(data, epochs.info, tmin=0.)
+    evoked = mne.EvokedArray(data, epochs.info, tmin=0.0)
     dip = fit_dipole(evoked, cov, sphere, n_jobs=None, verbose=False)[0]
     actual_pos = mne.dipole.get_phantom_dipoles()[0][dipole_number - 1]
     misses = 1000 * np.linalg.norm(dip.pos - actual_pos, axis=-1)
@@ -72,11 +79,15 @@ def compute_bias(raw):
 bias = compute_bias(raw)
-print('Raw bias: %0.1fmm (worst: %0.1fmm)'
-      % (np.mean(bias), np.max(bias)))
+print("Raw bias: %0.1fmm (worst: %0.1fmm)" % (np.mean(bias), np.max(bias)))
 bias_clean = compute_bias(raw_clean)
-print('OTP bias: %0.1fmm (worst: %0.1fmm)'
-      % (np.mean(bias_clean), np.max(bias_clean),))
+print(
+    "OTP bias: %0.1fmm (worst: %0.1fmm)"
+    % (
+        np.mean(bias_clean),
+        np.max(bias_clean),
+    )
+)
 # %%
 # References
diff --git a/examples/preprocessing/shift_evoked.py b/examples/preprocessing/shift_evoked.py
index 3bbe0386416..7b05a1b4714 100644
--- a/examples/preprocessing/shift_evoked.py
+++ b/examples/preprocessing/shift_evoked.py
@@ -20,32 +20,46 @@
 print(__doc__)
 data_path = sample.data_path()
-meg_path = data_path / 'MEG' / 'sample'
-fname = meg_path / 'sample_audvis-ave.fif'
+meg_path = data_path / "MEG" / "sample"
+fname = meg_path / "sample_audvis-ave.fif"
 # Reading evoked data
-condition = 'Left Auditory'
-evoked = mne.read_evokeds(fname, condition=condition, baseline=(None, 0),
-                          proj=True)
+condition = "Left Auditory"
+evoked = mne.read_evokeds(fname, condition=condition, baseline=(None, 0), proj=True)
-ch_names = evoked.info['ch_names']
+ch_names = evoked.info["ch_names"]
 picks = mne.pick_channels(ch_names=ch_names, include=["MEG 2332"])
 # Create subplots
 f, (ax1, ax2, ax3) = plt.subplots(3)
-evoked.plot(exclude=[], picks=picks, axes=ax1,
-            titles=dict(grad='Before time shifting'), time_unit='s')
+evoked.plot(
+    exclude=[],
+    picks=picks,
+    axes=ax1,
+    titles=dict(grad="Before time shifting"),
+    time_unit="s",
+)
 # Apply relative time-shift of 500 ms
 evoked.shift_time(0.5, relative=True)
-evoked.plot(exclude=[], picks=picks, axes=ax2,
-            titles=dict(grad='Relative shift: 500 ms'), time_unit='s')
+evoked.plot(
+    exclude=[],
+    picks=picks,
+    axes=ax2,
+    titles=dict(grad="Relative shift: 500 ms"),
+    time_unit="s",
+)
 # Apply absolute time-shift of 500 ms
 evoked.shift_time(0.5, relative=False)
-evoked.plot(exclude=[], picks=picks, axes=ax3,
-            titles=dict(grad='Absolute shift: 500 ms'), time_unit='s')
+evoked.plot(
+    exclude=[],
+    picks=picks,
+    axes=ax3,
+    titles=dict(grad="Absolute shift: 500 ms"),
+    time_unit="s",
+)
 tight_layout()
diff --git a/examples/preprocessing/virtual_evoked.py b/examples/preprocessing/virtual_evoked.py
index b947226b40b..096165910da 100644
--- a/examples/preprocessing/virtual_evoked.py
+++ b/examples/preprocessing/virtual_evoked.py
@@ -27,35 +27,35 @@
 # read the evoked
 data_path = sample.data_path()
-meg_path = data_path / 'MEG' / 'sample'
-fname = meg_path / 'sample_audvis-ave.fif'
-evoked = mne.read_evokeds(fname, condition='Left Auditory', baseline=(None, 0))
+meg_path = data_path / "MEG" / "sample"
+fname = meg_path / "sample_audvis-ave.fif"
+evoked = mne.read_evokeds(fname, condition="Left Auditory", baseline=(None, 0))
 # %%
 # First, let's remap gradiometers to magnetometers, and plot
 # the original and remapped topomaps of the magnetometers.
 # go from grad + mag to mag and plot original mag
-virt_evoked = evoked.as_type('mag')
-fig = evoked.plot_topomap(ch_type='mag')
-fig.suptitle('mag (original)')
+virt_evoked = evoked.as_type("mag")
+fig = evoked.plot_topomap(ch_type="mag")
+fig.suptitle("mag (original)")
 # %%
 # plot interpolated grad + mag
-fig = virt_evoked.plot_topomap(ch_type='mag')
-fig.suptitle('mag (interpolated from mag + grad)')
+fig = virt_evoked.plot_topomap(ch_type="mag")
+fig.suptitle("mag (interpolated from mag + grad)")
 # %%
 # Now, we remap magnetometers to gradiometers, and plot
 # the original and remapped topomaps of the gradiometers
 # go from grad + mag to grad and plot original grad
-virt_evoked = evoked.as_type('grad')
-fig = evoked.plot_topomap(ch_type='grad')
-fig.suptitle('grad (original)')
+virt_evoked = evoked.as_type("grad")
+fig = evoked.plot_topomap(ch_type="grad")
+fig.suptitle("grad (original)")
 # %%
 # plot interpolated grad + mag
-fig = virt_evoked.plot_topomap(ch_type='grad')
-fig.suptitle('grad (interpolated from mag + grad)')
+fig = virt_evoked.plot_topomap(ch_type="grad")
+fig.suptitle("grad (interpolated from mag + grad)")
diff --git a/examples/preprocessing/xdawn_denoising.py b/examples/preprocessing/xdawn_denoising.py
index aa7c0f48e08..b6eed43d142 100644
--- a/examples/preprocessing/xdawn_denoising.py
+++ b/examples/preprocessing/xdawn_denoising.py
@@ -25,7 +25,7 @@
 # %%
-from mne import (io, compute_raw_covariance, read_events, pick_types, Epochs)
+from mne import io, compute_raw_covariance, read_events, pick_types, Epochs
 from mne.datasets import sample
 from mne.preprocessing import Xdawn
 from mne.viz import plot_epochs_image
@@ -36,27 +36,35 @@
 # %%
 # Set parameters and read data
-meg_path = data_path / 'MEG' / 'sample'
-raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif'
-event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif'
+meg_path = data_path / "MEG" / "sample"
+raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif"
+event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif"
 tmin, tmax = -0.1, 0.3
 event_id = dict(vis_r=4)
 # Setup for reading the raw data
 raw = io.read_raw_fif(raw_fname, preload=True)
-raw.filter(1, 20, fir_design='firwin')  # replace baselining with high-pass
+raw.filter(1, 20, fir_design="firwin")  # replace baselining with high-pass
 events = read_events(event_fname)
-raw.info['bads'] = ['MEG 2443']  # set bad channels
-picks = pick_types(raw.info, meg=True, eeg=False, stim=False, eog=False,
-                   exclude='bads')
+raw.info["bads"] = ["MEG 2443"]  # set bad channels
+picks = pick_types(raw.info, meg=True, eeg=False, stim=False, eog=False, exclude="bads")
 # Epoching
-epochs = Epochs(raw, events, event_id, tmin, tmax, proj=False,
-                picks=picks, baseline=None, preload=True,
-                verbose=False)
+epochs = Epochs(
+    raw,
+    
events, + event_id, + tmin, + tmax, + proj=False, + picks=picks, + baseline=None, + preload=True, + verbose=False, +) # Plot image epoch before xdawn -plot_epochs_image(epochs['vis_r'], picks=[230], vmin=-500, vmax=500) +plot_epochs_image(epochs["vis_r"], picks=[230], vmin=-500, vmax=500) # %% # Now, we estimate a set of xDAWN filters for the epochs (which contain only @@ -78,7 +86,7 @@ epochs_denoised = xd.apply(epochs) # Plot image epoch after Xdawn -plot_epochs_image(epochs_denoised['vis_r'], picks=[230], vmin=-500, vmax=500) +plot_epochs_image(epochs_denoised["vis_r"], picks=[230], vmin=-500, vmax=500) # %% # References diff --git a/examples/simulation/plot_stc_metrics.py b/examples/simulation/plot_stc_metrics.py index 2e53c6bcd02..20912c12cc1 100644 --- a/examples/simulation/plot_stc_metrics.py +++ b/examples/simulation/plot_stc_metrics.py @@ -20,31 +20,36 @@ import mne from mne.datasets import sample from mne.minimum_norm import make_inverse_operator, apply_inverse -from mne.simulation.metrics import (region_localization_error, - f1_score, precision_score, - recall_score, cosine_score, - peak_position_error, - spatial_deviation_error) +from mne.simulation.metrics import ( + region_localization_error, + f1_score, + precision_score, + recall_score, + cosine_score, + peak_position_error, + spatial_deviation_error, +) random_state = 42 # set random state to make this example deterministic # Import sample data data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -subject = 'sample' -evoked_fname = data_path / 'MEG' / subject / 'sample_audvis-ave.fif' +subjects_dir = data_path / "subjects" +subject = "sample" +evoked_fname = data_path / "MEG" / subject / "sample_audvis-ave.fif" info = mne.io.read_info(evoked_fname) -tstep = 1. / info['sfreq'] +tstep = 1.0 / info["sfreq"] # Import forward operator and source space -fwd_fname = data_path / 'MEG' / subject / 'sample_audvis-meg-eeg-oct-6-fwd.fif' +fwd_fname = data_path / "MEG" / subject / "sample_audvis-meg-eeg-oct-6-fwd.fif" fwd = mne.read_forward_solution(fwd_fname) -src = fwd['src'] +src = fwd["src"] # To select source, we use the caudal middle frontal to grow # a region of interest. selected_label = mne.read_labels_from_annot( - subject, regexp='caudalmiddlefrontal-lh', subjects_dir=subjects_dir)[0] + subject, regexp="caudalmiddlefrontal-lh", subjects_dir=subjects_dir +)[0] ############################################################################### @@ -61,22 +66,32 @@ # WHERE? # Region -location = 'center' # Use the center of the label as a seed. -extent = 20. # Extent in mm of the region. +location = "center" # Use the center of the label as a seed. +extent = 20.0 # Extent in mm of the region. label_region = mne.label.select_sources( - subject, selected_label, location=location, extent=extent, - subjects_dir=subjects_dir, random_state=random_state) + subject, + selected_label, + location=location, + extent=extent, + subjects_dir=subjects_dir, + random_state=random_state, +) # Dipole location = 1915 # Use the index of the vertex as a seed -extent = 0. # One dipole source +extent = 0.0 # One dipole source label_dipole = mne.label.select_sources( - subject, selected_label, location=location, extent=extent, - subjects_dir=subjects_dir, random_state=random_state) + subject, + selected_label, + location=location, + extent=extent, + subjects_dir=subjects_dir, + random_state=random_state, +) # WHAT? # Define the time course of the activity -source_time_series = np.sin(2. * np.pi * 18. 
* np.arange(100) * tstep) * 10e-9 +source_time_series = np.sin(2.0 * np.pi * 18.0 * np.arange(100) * tstep) * 10e-9 # WHEN? # Define when the activity occurs using events. @@ -107,20 +122,20 @@ # noise obtained from the noise covariance from the sample data. # Region -raw_region = mne.simulation.simulate_raw(info, source_simulator_region, - forward=fwd) +raw_region = mne.simulation.simulate_raw(info, source_simulator_region, forward=fwd) raw_region = raw_region.pick_types(meg=False, eeg=True, stim=True) cov = mne.make_ad_hoc_cov(raw_region.info) -mne.simulation.add_noise(raw_region, cov, iir_filter=[0.2, -0.2, 0.04], - random_state=random_state) +mne.simulation.add_noise( + raw_region, cov, iir_filter=[0.2, -0.2, 0.04], random_state=random_state +) # Dipole -raw_dipole = mne.simulation.simulate_raw(info, source_simulator_dipole, - forward=fwd) +raw_dipole = mne.simulation.simulate_raw(info, source_simulator_dipole, forward=fwd) raw_dipole = raw_dipole.pick_types(meg=False, eeg=True, stim=True) cov = mne.make_ad_hoc_cov(raw_dipole.info) -mne.simulation.add_noise(raw_dipole, cov, iir_filter=[0.2, -0.2, 0.04], - random_state=random_state) +mne.simulation.add_noise( + raw_dipole, cov, iir_filter=[0.2, -0.2, 0.04], random_state=random_state +) ############################################################################### # Compute evoked from raw data @@ -149,14 +164,14 @@ # same number of time samples. # Region -stc_true_region = \ - source_simulator_region.get_stc(start_sample=0, - stop_sample=len(source_time_series)) +stc_true_region = source_simulator_region.get_stc( + start_sample=0, stop_sample=len(source_time_series) +) # Dipole -stc_true_dipole = \ - source_simulator_dipole.get_stc(start_sample=0, - stop_sample=len(source_time_series)) +stc_true_dipole = source_simulator_dipole.get_stc( + start_sample=0, stop_sample=len(source_time_series) +) ############################################################################### # Reconstruct simulated sources @@ -166,27 +181,29 @@ # Region snr = 30.0 -inv_method = 'sLORETA' -lambda2 = 1.0 / snr ** 2 +inv_method = "sLORETA" +lambda2 = 1.0 / snr**2 -inverse_operator = make_inverse_operator(evoked_region.info, fwd, cov, - loose='auto', depth=0.8, - fixed=True) +inverse_operator = make_inverse_operator( + evoked_region.info, fwd, cov, loose="auto", depth=0.8, fixed=True +) -stc_est_region = apply_inverse(evoked_region, inverse_operator, lambda2, - inv_method, pick_ori=None) +stc_est_region = apply_inverse( + evoked_region, inverse_operator, lambda2, inv_method, pick_ori=None +) # Dipole snr = 3.0 -inv_method = 'sLORETA' -lambda2 = 1.0 / snr ** 2 +inv_method = "sLORETA" +lambda2 = 1.0 / snr**2 -inverse_operator = make_inverse_operator(evoked_dipole.info, fwd, cov, - loose='auto', depth=0.8, - fixed=True) +inverse_operator = make_inverse_operator( + evoked_dipole.info, fwd, cov, loose="auto", depth=0.8, fixed=True +) -stc_est_dipole = apply_inverse(evoked_dipole, inverse_operator, lambda2, - inv_method, pick_ori=None) +stc_est_dipole = apply_inverse( + evoked_dipole, inverse_operator, lambda2, inv_method, pick_ori=None +) ############################################################################### # Compute performance scores for different source amplitude thresholds @@ -201,32 +218,34 @@ # # create a set of scorers -scorers = {'RLE': partial(region_localization_error, src=src), - 'Precision': precision_score, 'Recall': recall_score, - 'F1 score': f1_score} +scorers = { + "RLE": partial(region_localization_error, src=src), + "Precision": 
precision_score, + "Recall": recall_score, + "F1 score": f1_score, +} # compute results region_results = {} for name, scorer in scorers.items(): - region_results[name] = [scorer(stc_true_region, stc_est_region, - threshold=f'{thx}%', per_sample=False) - for thx in thresholds] + region_results[name] = [ + scorer(stc_true_region, stc_est_region, threshold=f"{thx}%", per_sample=False) + for thx in thresholds + ] # Plot the results -f, ((ax1, ax2), (ax3, ax4)) = plt.subplots( - 2, 2, sharex='col', constrained_layout=True) +f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex="col", constrained_layout=True) for ax, (title, results) in zip([ax1, ax2, ax3, ax4], region_results.items()): - ax.plot(thresholds, results, '.-') - ax.set(title=title, ylabel='score', xlabel='Threshold', - xticks=thresholds) + ax.plot(thresholds, results, ".-") + ax.set(title=title, ylabel="score", xlabel="Threshold", xticks=thresholds) -f.suptitle('Performance scores per threshold') # Add Super title -ax1.ticklabel_format(axis='y', style='sci', scilimits=(0, 1)) # tweak RLE +f.suptitle("Performance scores per threshold") # Add Super title +ax1.ticklabel_format(axis="y", style="sci", scilimits=(0, 1)) # tweak RLE # Cosine score with respect to time f, ax1 = plt.subplots(constrained_layout=True) ax1.plot(stc_true_region.times, cosine_score(stc_true_region, stc_est_region)) -ax1.set(title='Cosine score', xlabel='Time', ylabel='Score') +ax1.set(title="Cosine score", xlabel="Time", ylabel="Score") ############################################################################### @@ -236,22 +255,28 @@ # create a set of scorers scorers = { - 'Peak Position Error': peak_position_error, - 'Spatial Deviation Error': spatial_deviation_error, + "Peak Position Error": peak_position_error, + "Spatial Deviation Error": spatial_deviation_error, } # compute results dipole_results = {} for name, scorer in scorers.items(): - dipole_results[name] = [scorer(stc_true_dipole, stc_est_dipole, src=src, - threshold=f'{thx}%', per_sample=False) - for thx in thresholds] + dipole_results[name] = [ + scorer( + stc_true_dipole, + stc_est_dipole, + src=src, + threshold=f"{thx}%", + per_sample=False, + ) + for thx in thresholds + ] # Plot the results for name, results in dipole_results.items(): f, ax1 = plt.subplots(constrained_layout=True) - ax1.plot(thresholds, 100 * np.array(results), '.-') - ax1.set(title=name, ylabel='Error (cm)', xlabel='Threshold', - xticks=thresholds) + ax1.plot(thresholds, 100 * np.array(results), ".-") + ax1.set(title=name, ylabel="Error (cm)", xlabel="Threshold", xticks=thresholds) diff --git a/examples/simulation/simulate_evoked_data.py b/examples/simulation/simulate_evoked_data.py index b906d2df265..0a8d69a66ed 100644 --- a/examples/simulation/simulate_evoked_data.py +++ b/examples/simulation/simulate_evoked_data.py @@ -28,55 +28,65 @@ # %% # Load real data as templates data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw = mne.io.read_raw_fif(meg_path / 'sample_audvis_raw.fif') -proj = mne.read_proj(meg_path / 'sample_audvis_ecg-proj.fif') +meg_path = data_path / "MEG" / "sample" +raw = mne.io.read_raw_fif(meg_path / "sample_audvis_raw.fif") +proj = mne.read_proj(meg_path / "sample_audvis_ecg-proj.fif") raw.add_proj(proj) -raw.info['bads'] = ['MEG 2443', 'EEG 053'] # mark bad channels +raw.info["bads"] = ["MEG 2443", "EEG 053"] # mark bad channels -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -ave_fname = meg_path / 'sample_audvis-no-filter-ave.fif' -cov_fname = meg_path / 
'sample_audvis-cov.fif' +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +ave_fname = meg_path / "sample_audvis-no-filter-ave.fif" +cov_fname = meg_path / "sample_audvis-cov.fif" fwd = mne.read_forward_solution(fwd_fname) -fwd = mne.pick_types_forward(fwd, meg=True, eeg=True, exclude=raw.info['bads']) +fwd = mne.pick_types_forward(fwd, meg=True, eeg=True, exclude=raw.info["bads"]) cov = mne.read_cov(cov_fname) info = mne.io.read_info(ave_fname) -label_names = ['Aud-lh', 'Aud-rh'] -labels = [mne.read_label(meg_path / 'labels' / f'{ln}.label') - for ln in label_names] +label_names = ["Aud-lh", "Aud-rh"] +labels = [mne.read_label(meg_path / "labels" / f"{ln}.label") for ln in label_names] # %% # Generate source time courses from 2 dipoles and the corresponding evoked data -times = np.arange(300, dtype=np.float64) / raw.info['sfreq'] - 0.1 +times = np.arange(300, dtype=np.float64) / raw.info["sfreq"] - 0.1 rng = np.random.RandomState(42) def data_fun(times): """Generate random source time courses.""" - return (50e-9 * np.sin(30. * times) * - np.exp(- (times - 0.15 + 0.05 * rng.randn(1)) ** 2 / 0.01)) - - -stc = simulate_sparse_stc(fwd['src'], n_dipoles=2, times=times, - random_state=42, labels=labels, data_fun=data_fun) + return ( + 50e-9 + * np.sin(30.0 * times) + * np.exp(-((times - 0.15 + 0.05 * rng.randn(1)) ** 2) / 0.01) + ) + + +stc = simulate_sparse_stc( + fwd["src"], + n_dipoles=2, + times=times, + random_state=42, + labels=labels, + data_fun=data_fun, +) # %% # Generate noisy evoked data -picks = mne.pick_types(raw.info, meg=True, exclude='bads') +picks = mne.pick_types(raw.info, meg=True, exclude="bads") iir_filter = fit_iir_model_raw(raw, order=5, picks=picks, tmin=60, tmax=180)[1] nave = 100 # simulate average of 100 epochs -evoked = simulate_evoked(fwd, stc, info, cov, nave=nave, use_cps=True, - iir_filter=iir_filter) +evoked = simulate_evoked( + fwd, stc, info, cov, nave=nave, use_cps=True, iir_filter=iir_filter +) # %% # Plot -plot_sparse_source_estimates(fwd['src'], stc, bgcolor=(1, 1, 1), - opacity=0.5, high_resolution=True) +plot_sparse_source_estimates( + fwd["src"], stc, bgcolor=(1, 1, 1), opacity=0.5, high_resolution=True +) plt.figure() plt.psd(evoked.data[0]) -evoked.plot(time_unit='s') +evoked.plot(time_unit="s") diff --git a/examples/simulation/simulate_raw_data.py b/examples/simulation/simulate_raw_data.py index 6c308792c97..902429717c2 100644 --- a/examples/simulation/simulate_raw_data.py +++ b/examples/simulation/simulate_raw_data.py @@ -22,15 +22,20 @@ import mne from mne import find_events, Epochs, compute_covariance, make_ad_hoc_cov from mne.datasets import sample -from mne.simulation import (simulate_sparse_stc, simulate_raw, - add_noise, add_ecg, add_eog) +from mne.simulation import ( + simulate_sparse_stc, + simulate_raw, + add_noise, + add_ecg, + add_eog, +) print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_raw.fif' -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_raw.fif" +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" # Load real data as the template raw = mne.io.read_raw_fif(raw_fname) @@ -39,7 +44,7 @@ ############################################################################## # Generate dipole time series n_dipoles = 4 # number of dipoles to create -epoch_duration = 2. 
# duration of each epoch/event +epoch_duration = 2.0 # duration of each epoch/event n = 0 # harmonic number rng = np.random.RandomState(0) # random state (make reproducible) @@ -49,24 +54,26 @@ def data_fun(times): global n n_samp = len(times) window = np.zeros(n_samp) - start, stop = [int(ii * float(n_samp) / (2 * n_dipoles)) - for ii in (2 * n, 2 * n + 1)] - window[start:stop] = 1. + start, stop = [ + int(ii * float(n_samp) / (2 * n_dipoles)) for ii in (2 * n, 2 * n + 1) + ] + window[start:stop] = 1.0 n += 1 - data = 25e-9 * np.sin(2. * np.pi * 10. * n * times) + data = 25e-9 * np.sin(2.0 * np.pi * 10.0 * n * times) data *= window return data -times = raw.times[:int(raw.info['sfreq'] * epoch_duration)] +times = raw.times[: int(raw.info["sfreq"] * epoch_duration)] fwd = mne.read_forward_solution(fwd_fname) -src = fwd['src'] -stc = simulate_sparse_stc(src, n_dipoles=n_dipoles, times=times, - data_fun=data_fun, random_state=rng) +src = fwd["src"] +stc = simulate_sparse_stc( + src, n_dipoles=n_dipoles, times=times, data_fun=data_fun, random_state=rng +) # look at our source data fig, ax = plt.subplots(1) ax.plot(times, 1e9 * stc.data.T) -ax.set(ylabel='Amplitude (nAm)', xlabel='Time (s)') +ax.set(ylabel="Amplitude (nAm)", xlabel="Time (s)") mne.viz.utils.plt_show() ############################################################################## @@ -82,7 +89,8 @@ def data_fun(times): # Plot evoked data events = find_events(raw_sim) # only 1 pos, so event number == 1 epochs = Epochs(raw_sim, events, 1, tmin=-0.2, tmax=epoch_duration) -cov = compute_covariance(epochs, tmax=0., method='empirical', - verbose='error') # quick calc +cov = compute_covariance( + epochs, tmax=0.0, method="empirical", verbose="error" +) # quick calc evoked = epochs.average() -evoked.plot_white(cov, time_unit='s') +evoked.plot_white(cov, time_unit="s") diff --git a/examples/simulation/simulated_raw_data_using_subject_anatomy.py b/examples/simulation/simulated_raw_data_using_subject_anatomy.py index 0edb33e7d0f..af13d124383 100644 --- a/examples/simulation/simulated_raw_data_using_subject_anatomy.py +++ b/examples/simulation/simulated_raw_data_using_subject_anatomy.py @@ -34,24 +34,24 @@ # to be given to functions. data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -subject = 'sample' -meg_path = data_path / 'MEG' / subject +subjects_dir = data_path / "subjects" +subject = "sample" +meg_path = data_path / "MEG" / subject # %% # First, we get an info structure from the sample subject. -fname_info = meg_path / 'sample_audvis_raw.fif' +fname_info = meg_path / "sample_audvis_raw.fif" info = mne.io.read_info(fname_info) -tstep = 1 / info['sfreq'] +tstep = 1 / info["sfreq"] # %% # To simulate sources, we also need a source space. It can be obtained from the # forward solution of the sample subject. -fwd_fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' +fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" fwd = mne.read_forward_solution(fwd_fname) -src = fwd['src'] +src = fwd["src"] # %% # To simulate raw data, we need to define when the activity occurs using events @@ -60,16 +60,22 @@ # Here, both are loaded from the sample dataset, but they can also be specified # by the user. -fname_event = meg_path / 'sample_audvis_raw-eve.fif' -fname_cov = meg_path / 'sample_audvis-cov.fif' +fname_event = meg_path / "sample_audvis_raw-eve.fif" +fname_cov = meg_path / "sample_audvis-cov.fif" events = mne.read_events(fname_event) noise_cov = mne.read_cov(fname_cov) # Standard sample event IDs. 
These values will correspond to the third column # in the events matrix. -event_id = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3, - 'visual/right': 4, 'smiley': 5, 'button': 32} +event_id = { + "auditory/left": 1, + "auditory/right": 2, + "visual/left": 3, + "visual/right": 4, + "smiley": 5, + "button": 32, +} # Take only a few events for speed @@ -92,26 +98,28 @@ # times more) than the ipsilateral. activations = { - 'auditory/left': - [('G_temp_sup-G_T_transv-lh', 30), # label, activation (nAm) - ('G_temp_sup-G_T_transv-rh', 60)], - 'auditory/right': - [('G_temp_sup-G_T_transv-lh', 60), - ('G_temp_sup-G_T_transv-rh', 30)], - 'visual/left': - [('S_calcarine-lh', 30), - ('S_calcarine-rh', 60)], - 'visual/right': - [('S_calcarine-lh', 60), - ('S_calcarine-rh', 30)], + "auditory/left": [ + ("G_temp_sup-G_T_transv-lh", 30), # label, activation (nAm) + ("G_temp_sup-G_T_transv-rh", 60), + ], + "auditory/right": [ + ("G_temp_sup-G_T_transv-lh", 60), + ("G_temp_sup-G_T_transv-rh", 30), + ], + "visual/left": [("S_calcarine-lh", 30), ("S_calcarine-rh", 60)], + "visual/right": [("S_calcarine-lh", 60), ("S_calcarine-rh", 30)], } -annot = 'aparc.a2009s' +annot = "aparc.a2009s" # Load the 4 necessary label names. -label_names = sorted(set(activation[0] - for activation_list in activations.values() - for activation in activation_list)) +label_names = sorted( + set( + activation[0] + for activation_list in activations.values() + for activation in activation_list + ) +) region_names = list(activations.keys()) # %% @@ -128,8 +136,9 @@ def data_fun(times, latency, duration): f = 15 # oscillating frequency, beta band [Hz] sigma = 0.375 * duration sinusoid = np.sin(2 * np.pi * f * (times - latency)) - gf = np.exp(- (times - latency - (sigma / 4.) * rng.rand(1)) ** 2 / - (2 * (sigma ** 2))) + gf = np.exp( + -((times - latency - (sigma / 4.0) * rng.rand(1)) ** 2) / (2 * (sigma**2)) + ) return 1e-9 * sinusoid * gf @@ -152,7 +161,7 @@ def data_fun(times, latency, duration): # event, the second is not used. The third one is the event id, which is # different for each of the 4 areas. -times = np.arange(150, dtype=np.float64) / info['sfreq'] +times = np.arange(150, dtype=np.float64) / info["sfreq"] duration = 0.03 rng = np.random.RandomState(7) source_simulator = mne.simulation.SourceSimulator(src, tstep=tstep) @@ -161,20 +170,17 @@ def data_fun(times, latency, duration): events_tmp = events[np.where(events[:, 2] == region_id)[0], :] for i in range(2): label_name = activations[region_name][i][0] - label_tmp = mne.read_labels_from_annot(subject, annot, - subjects_dir=subjects_dir, - regexp=label_name, - verbose=False) + label_tmp = mne.read_labels_from_annot( + subject, annot, subjects_dir=subjects_dir, regexp=label_name, verbose=False + ) label_tmp = label_tmp[0] amplitude_tmp = activations[region_name][i][1] - if region_name.split('/')[1][0] == label_tmp.hemi[0]: + if region_name.split("/")[1][0] == label_tmp.hemi[0]: latency_tmp = 0.115 else: latency_tmp = 0.1 wf_tmp = data_fun(times, latency_tmp, duration) - source_simulator.add_data(label_tmp, - amplitude_tmp * wf_tmp, - events_tmp) + source_simulator.add_data(label_tmp, amplitude_tmp * wf_tmp, events_tmp) # To obtain a SourceEstimate object, we need to use `get_stc()` method of # SourceSimulator class. @@ -203,17 +209,16 @@ def data_fun(times, latency, duration): mne.simulation.add_ecg(raw_sim, random_state=0) # Plot original and simulated raw data. 
-raw_sim.plot(title='Simulated raw data')
+raw_sim.plot(title="Simulated raw data")
 # %%
 # Extract epochs and compute evoked responses
 # --------------------------------------------
 #
-epochs = mne.Epochs(raw_sim, events, event_id, tmin=-0.2, tmax=0.3,
-                    baseline=(None, 0))
-evoked_aud_left = epochs['auditory/left'].average()
-evoked_vis_right = epochs['visual/right'].average()
+epochs = mne.Epochs(raw_sim, events, event_id, tmin=-0.2, tmax=0.3, baseline=(None, 0))
+evoked_aud_left = epochs["auditory/left"].average()
+evoked_vis_right = epochs["visual/right"].average()
 # Visualize the evoked data
 evoked_aud_left.plot(spatial_colors=True)
@@ -229,16 +234,15 @@ def data_fun(times, latency, duration):
 # As expected, when high activations appear in primary auditory areas, primary
 # visual areas will have low activations and vice versa.
-method, lambda2 = 'dSPM', 1. / 9.
+method, lambda2 = "dSPM", 1.0 / 9.0
 inv = mne.minimum_norm.make_inverse_operator(epochs.info, fwd, noise_cov)
-stc_aud = mne.minimum_norm.apply_inverse(
-    evoked_aud_left, inv, lambda2, method)
-stc_vis = mne.minimum_norm.apply_inverse(
-    evoked_vis_right, inv, lambda2, method)
+stc_aud = mne.minimum_norm.apply_inverse(evoked_aud_left, inv, lambda2, method)
+stc_vis = mne.minimum_norm.apply_inverse(evoked_vis_right, inv, lambda2, method)
 stc_diff = stc_aud - stc_vis
-brain = stc_diff.plot(subjects_dir=subjects_dir, initial_time=0.1,
-                      hemi='split', views=['lat', 'med'])
+brain = stc_diff.plot(
+    subjects_dir=subjects_dir, initial_time=0.1, hemi="split", views=["lat", "med"]
+)
 # %%
 # References
diff --git a/examples/simulation/source_simulator.py b/examples/simulation/source_simulator.py
index 93a348e46ca..69cb803c134 100644
--- a/examples/simulation/source_simulator.py
+++ b/examples/simulation/source_simulator.py
@@ -29,38 +29,39 @@ class to generate source estimates and raw data. It is meant to be a brief
 # This will download the data if it is not already on your machine. We also set
 # the subjects directory so we don't need to give it to functions.
 data_path = sample.data_path()
-subjects_dir = data_path / 'subjects'
-subject = 'sample'
+subjects_dir = data_path / "subjects"
+subject = "sample"
 # %%
 # First, we get an info structure from the test subject.
-evoked_fname = data_path / 'MEG' / subject / 'sample_audvis-ave.fif'
+evoked_fname = data_path / "MEG" / subject / "sample_audvis-ave.fif"
 info = mne.io.read_info(evoked_fname)
-tstep = 1. / info['sfreq']
+tstep = 1.0 / info["sfreq"]
 # %%
 # To simulate sources, we also need a source space. It can be obtained from the
 # forward solution of the sample subject.
-fwd_fname = data_path / 'MEG' / subject / 'sample_audvis-meg-eeg-oct-6-fwd.fif'
+fwd_fname = data_path / "MEG" / subject / "sample_audvis-meg-eeg-oct-6-fwd.fif"
 fwd = mne.read_forward_solution(fwd_fname)
-src = fwd['src']
+src = fwd["src"]
 # %%
 # To select a region to activate, we use the caudal middle frontal to grow
 # a region of interest.
 selected_label = mne.read_labels_from_annot(
-    subject, regexp='caudalmiddlefrontal-lh', subjects_dir=subjects_dir)[0]
-location = 'center'  # Use the center of the region as a seed.
-extent = 10.  # Extent in mm of the region.
+    subject, regexp="caudalmiddlefrontal-lh", subjects_dir=subjects_dir
+)[0]
+location = "center"  # Use the center of the region as a seed.
+extent = 10.0  # Extent in mm of the region.
 
label = mne.label.select_sources( - subject, selected_label, location=location, extent=extent, - subjects_dir=subjects_dir) + subject, selected_label, location=location, extent=extent, subjects_dir=subjects_dir +) # %% # Define the time course of the activity for each source of the region to # activate. Here we use a sine wave at 18 Hz with a peak amplitude # of 10 nAm. -source_time_series = np.sin(2. * np.pi * 18. * np.arange(100) * tstep) * 10e-9 +source_time_series = np.sin(2.0 * np.pi * 18.0 * np.arange(100) * tstep) * 10e-9 # %% # Define when the activity occurs using events. The first column is the sample diff --git a/examples/stats/cluster_stats_evoked.py b/examples/stats/cluster_stats_evoked.py index cf2f9d59c18..1e21cdb7617 100644 --- a/examples/stats/cluster_stats_evoked.py +++ b/examples/stats/cluster_stats_evoked.py @@ -28,9 +28,9 @@ # %% # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" tmin = -0.2 tmax = 0.5 @@ -38,22 +38,23 @@ raw = io.read_raw_fif(raw_fname) events = mne.read_events(event_fname) -channel = 'MEG 1332' # include only this channel in analysis +channel = "MEG 1332" # include only this channel in analysis include = [channel] # %% # Read epochs for the channel of interest -picks = mne.pick_types(raw.info, meg=False, eog=True, include=include, - exclude='bads') +picks = mne.pick_types(raw.info, meg=False, eog=True, include=include, exclude="bads") event_id = 1 reject = dict(grad=4000e-13, eog=150e-6) -epochs1 = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=reject) +epochs1 = mne.Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), reject=reject +) condition1 = epochs1.get_data() # as 3D matrix event_id = 2 -epochs2 = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=reject) +epochs2 = mne.Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), reject=reject +) condition2 = epochs2.get_data() # as 3D matrix condition1 = condition1[:, 0, :] # take only one channel to get a 2D array @@ -62,31 +63,36 @@ # %% # Compute statistic threshold = 6.0 -T_obs, clusters, cluster_p_values, H0 = \ - permutation_cluster_test([condition1, condition2], n_permutations=1000, - threshold=threshold, tail=1, n_jobs=None, - out_type='mask') +T_obs, clusters, cluster_p_values, H0 = permutation_cluster_test( + [condition1, condition2], + n_permutations=1000, + threshold=threshold, + tail=1, + n_jobs=None, + out_type="mask", +) # %% # Plot times = epochs1.times fig, (ax, ax2) = plt.subplots(2, 1, figsize=(8, 4)) -ax.set_title('Channel : ' + channel) -ax.plot(times, condition1.mean(axis=0) - condition2.mean(axis=0), - label="ERF Contrast (Event 1 - Event 2)") +ax.set_title("Channel : " + channel) +ax.plot( + times, + condition1.mean(axis=0) - condition2.mean(axis=0), + label="ERF Contrast (Event 1 - Event 2)", +) ax.set_ylabel("MEG (T / m)") ax.legend() for i_c, c in enumerate(clusters): c = c[0] if cluster_p_values[i_c] <= 0.05: - h = ax2.axvspan(times[c.start], times[c.stop - 1], - color='r', alpha=0.3) + h = ax2.axvspan(times[c.start], times[c.stop - 1], color="r", alpha=0.3) else: - ax2.axvspan(times[c.start], times[c.stop - 1], color=(0.3, 0.3, 
0.3), - alpha=0.3) + ax2.axvspan(times[c.start], times[c.stop - 1], color=(0.3, 0.3, 0.3), alpha=0.3) -hf = plt.plot(times, T_obs, 'g') -ax2.legend((h, ), ('cluster p-value < 0.05', )) +hf = plt.plot(times, T_obs, "g") +ax2.legend((h,), ("cluster p-value < 0.05",)) ax2.set_xlabel("time (ms)") ax2.set_ylabel("f-values") diff --git a/examples/stats/fdr_stats_evoked.py b/examples/stats/fdr_stats_evoked.py index b90ab6f9ccd..94239f887df 100644 --- a/examples/stats/fdr_stats_evoked.py +++ b/examples/stats/fdr_stats_evoked.py @@ -30,26 +30,26 @@ # %% # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" event_id, tmin, tmax = 1, -0.2, 0.5 # Setup for reading the raw data raw = io.read_raw_fif(raw_fname) events = mne.read_events(event_fname)[:30] -channel = 'MEG 1332' # include only this channel in analysis +channel = "MEG 1332" # include only this channel in analysis include = [channel] # %% # Read epochs for the channel of interest -picks = mne.pick_types(raw.info, meg=False, eog=True, include=include, - exclude='bads') +picks = mne.pick_types(raw.info, meg=False, eog=True, include=include, exclude="bads") event_id = 1 reject = dict(grad=4000e-13, eog=150e-6) -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=reject) +epochs = mne.Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), reject=reject +) X = epochs.get_data() # as 3D matrix X = X[:, 0, :] # take only one channel to get a 2D array @@ -64,22 +64,43 @@ reject_bonferroni, pval_bonferroni = bonferroni_correction(pval, alpha=alpha) threshold_bonferroni = stats.t.ppf(1.0 - alpha / n_tests, n_samples - 1) -reject_fdr, pval_fdr = fdr_correction(pval, alpha=alpha, method='indep') +reject_fdr, pval_fdr = fdr_correction(pval, alpha=alpha, method="indep") threshold_fdr = np.min(np.abs(T)[reject_fdr]) # %% # Plot times = 1e3 * epochs.times -plt.close('all') -plt.plot(times, T, 'k', label='T-stat') +plt.close("all") +plt.plot(times, T, "k", label="T-stat") xmin, xmax = plt.xlim() -plt.hlines(threshold_uncorrected, xmin, xmax, linestyle='--', colors='k', - label='p=0.05 (uncorrected)', linewidth=2) -plt.hlines(threshold_bonferroni, xmin, xmax, linestyle='--', colors='r', - label='p=0.05 (Bonferroni)', linewidth=2) -plt.hlines(threshold_fdr, xmin, xmax, linestyle='--', colors='b', - label='p=0.05 (FDR)', linewidth=2) +plt.hlines( + threshold_uncorrected, + xmin, + xmax, + linestyle="--", + colors="k", + label="p=0.05 (uncorrected)", + linewidth=2, +) +plt.hlines( + threshold_bonferroni, + xmin, + xmax, + linestyle="--", + colors="r", + label="p=0.05 (Bonferroni)", + linewidth=2, +) +plt.hlines( + threshold_fdr, + xmin, + xmax, + linestyle="--", + colors="b", + label="p=0.05 (FDR)", + linewidth=2, +) plt.legend() plt.xlabel("Time (ms)") plt.ylabel("T-stat") diff --git a/examples/stats/linear_regression_raw.py b/examples/stats/linear_regression_raw.py index 54aef70c8e2..11a21a80305 100644 --- a/examples/stats/linear_regression_raw.py +++ b/examples/stats/linear_regression_raw.py @@ -33,37 +33,47 @@ # Load and preprocess data data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' +meg_path = 
data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" raw = mne.io.read_raw_fif(raw_fname) -raw.pick_types(meg='grad', stim=True, eeg=False).load_data() -raw.filter(1, None, fir_design='firwin') # high-pass +raw.pick_types(meg="grad", stim=True, eeg=False).load_data() +raw.filter(1, None, fir_design="firwin") # high-pass # Set up events events = mne.find_events(raw) -event_id = {'Aud/L': 1, 'Aud/R': 2} -tmin, tmax = -.1, .5 +event_id = {"Aud/L": 1, "Aud/R": 2} +tmin, tmax = -0.1, 0.5 # regular epoching picks = mne.pick_types(raw.info, meg=True) -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, reject=None, - baseline=None, preload=True, verbose=False) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + reject=None, + baseline=None, + preload=True, + verbose=False, +) # rERF -evokeds = linear_regression_raw(raw, events=events, event_id=event_id, - reject=None, tmin=tmin, tmax=tmax) +evokeds = linear_regression_raw( + raw, events=events, event_id=event_id, reject=None, tmin=tmin, tmax=tmax +) # linear_regression_raw returns a dict of evokeds # select conditions similarly to mne.Epochs objects # plot both results, and their difference cond = "Aud/L" fig, (ax1, ax2, ax3) = plt.subplots(3, 1) -params = dict(spatial_colors=True, show=False, ylim=dict(grad=(-200, 200)), - time_unit='s') +params = dict( + spatial_colors=True, show=False, ylim=dict(grad=(-200, 200)), time_unit="s" +) epochs[cond].average().plot(axes=ax1, **params) evokeds[cond].plot(axes=ax2, **params) -contrast = mne.combine_evoked([evokeds[cond], epochs[cond].average()], - weights=[1, -1]) +contrast = mne.combine_evoked([evokeds[cond], epochs[cond].average()], weights=[1, -1]) contrast.plot(axes=ax3, **params) ax1.set_title("Traditional averaging") ax2.set_title("rERF") diff --git a/examples/stats/sensor_permutation_test.py b/examples/stats/sensor_permutation_test.py index 654c9b7153c..7d54df71357 100644 --- a/examples/stats/sensor_permutation_test.py +++ b/examples/stats/sensor_permutation_test.py @@ -28,9 +28,9 @@ # %% # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" event_id = 1 tmin = -0.2 tmax = 0.5 @@ -40,10 +40,19 @@ events = mne.read_events(event_fname) # pick MEG Gradiometers -picks = mne.pick_types(raw.info, meg='grad', eeg=False, stim=False, eog=True, - exclude='bads') -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(grad=4000e-13, eog=150e-6)) +picks = mne.pick_types( + raw.info, meg="grad", eeg=False, stim=False, eog=True, exclude="bads" +) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(grad=4000e-13, eog=150e-6), +) data = epochs.get_data() times = epochs.times @@ -62,15 +71,23 @@ # %% # View location of significantly active sensors -evoked = mne.EvokedArray(-np.log10(p_values)[:, np.newaxis], - epochs.info, tmin=0.) 
+evoked = mne.EvokedArray(-np.log10(p_values)[:, np.newaxis], epochs.info, tmin=0.0) # Extract mask and indices of active sensors in the layout stats_picks = mne.pick_channels(evoked.ch_names, significant_sensors_names) mask = p_values[:, np.newaxis] <= 0.05 -evoked.plot_topomap(ch_type='grad', times=[0], scalings=1, - time_format=None, cmap='Reds', vlim=(0., np.max), - units='-log10(p)', cbar_fmt='-%0.1f', mask=mask, - size=3, show_names=lambda x: x[4:] + ' ' * 20, - time_unit='s') +evoked.plot_topomap( + ch_type="grad", + times=[0], + scalings=1, + time_format=None, + cmap="Reds", + vlim=(0.0, np.max), + units="-log10(p)", + cbar_fmt="-%0.1f", + mask=mask, + size=3, + show_names=lambda x: x[4:] + " " * 20, + time_unit="s", +) diff --git a/examples/stats/sensor_regression.py b/examples/stats/sensor_regression.py index 9a1e42ae7f8..2b17927b28b 100644 --- a/examples/stats/sensor_regression.py +++ b/examples/stats/sensor_regression.py @@ -38,7 +38,7 @@ from mne.datasets import kiloword # Load the data -path = kiloword.data_path() / 'kword_metadata-epo.fif' +path = kiloword.data_path() / "kword_metadata-epo.fif" epochs = mne.read_epochs(path) print(epochs.metadata.head()) @@ -54,8 +54,9 @@ colors = {str(val): val for val in df[name].unique()} epochs.metadata = df.assign(Intercept=1) # Add an intercept for later evokeds = {val: epochs[name + " == " + val].average() for val in colors} -plot_compare_evokeds(evokeds, colors=colors, split_legend=True, - cmap=(name + " Percentile", "viridis")) +plot_compare_evokeds( + evokeds, colors=colors, split_legend=True, cmap=(name + " Percentile", "viridis") +) ############################################################################## # We observe that there appears to be a monotonic dependence of EEG on @@ -66,8 +67,9 @@ names = ["Intercept", name] res = linear_regression(epochs, epochs.metadata[names], names=names) for cond in names: - res[cond].beta.plot_joint(title=cond, ts_args=dict(time_unit='s'), - topomap_args=dict(time_unit='s')) + res[cond].beta.plot_joint( + title=cond, ts_args=dict(time_unit="s"), topomap_args=dict(time_unit="s") + ) ############################################################################## # Because the :func:`~mne.stats.linear_regression` function also estimates @@ -81,4 +83,4 @@ # by dark contour lines. reject_H0, fdr_pvals = fdr_correction(res["Concreteness"].p_val.data) evoked = res["Concreteness"].beta -evoked.plot_image(mask=reject_H0, time_unit='s') +evoked.plot_image(mask=reject_H0, time_unit="s") diff --git a/examples/time_frequency/compute_csd.py b/examples/time_frequency/compute_csd.py index e9a962bb733..0de5482be1e 100644 --- a/examples/time_frequency/compute_csd.py +++ b/examples/time_frequency/compute_csd.py @@ -35,9 +35,9 @@ # %% # Loading the sample dataset. data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname_raw = meg_path / 'sample_audvis_raw.fif' -fname_event = meg_path / 'sample_audvis_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +fname_raw = meg_path / "sample_audvis_raw.fif" +fname_event = meg_path / "sample_audvis_raw-eve.fif" raw = mne.io.read_raw_fif(fname_raw) events = mne.read_events(fname_event) @@ -47,12 +47,20 @@ # measurement units, and thus the scalings, differ across sensors. In this # example, for speed and clarity, we select a single channel type: # gradiometers. 
-picks = mne.pick_types(raw.info, meg='grad') +picks = mne.pick_types(raw.info, meg="grad") # Make some epochs, based on events with trigger code 1 -epochs = mne.Epochs(raw, events, event_id=1, tmin=-0.2, tmax=1, - picks=picks, baseline=(None, 0), - reject=dict(grad=4000e-13), preload=True) +epochs = mne.Epochs( + raw, + events, + event_id=1, + tmin=-0.2, + tmax=1, + picks=picks, + baseline=(None, 0), + reject=dict(grad=4000e-13), + preload=True, +) # %% # Computing CSD matrices using short-term Fourier transform and (adaptive) @@ -85,9 +93,11 @@ # created figures; in this case, each returned list has only one figure # so we use a Python trick of including a comma after our variable name # to assign the figure (not the list) to our ``fig`` variable: -plot_dict = {'Short-time Fourier transform': csd_fft, - 'Adaptive multitapers': csd_mt, - 'Morlet wavelet transform': csd_wav} +plot_dict = { + "Short-time Fourier transform": csd_fft, + "Adaptive multitapers": csd_mt, + "Morlet wavelet transform": csd_wav, +} for title, csd in plot_dict.items(): - fig, = csd.mean().plot() + (fig,) = csd.mean().plot() fig.suptitle(title) diff --git a/examples/time_frequency/compute_source_psd_epochs.py b/examples/time_frequency/compute_source_psd_epochs.py index 1ca42643f49..745fc69717e 100644 --- a/examples/time_frequency/compute_source_psd_epochs.py +++ b/examples/time_frequency/compute_source_psd_epochs.py @@ -24,17 +24,17 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' -fname_raw = meg_path / 'sample_audvis_raw.fif' -fname_event = meg_path / 'sample_audvis_raw-eve.fif' -label_name = 'Aud-lh' -fname_label = meg_path / 'labels' / f'{label_name}.label' -subjects_dir = data_path / 'subjects' +meg_path = data_path / "MEG" / "sample" +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" +fname_raw = meg_path / "sample_audvis_raw.fif" +fname_event = meg_path / "sample_audvis_raw-eve.fif" +label_name = "Aud-lh" +fname_label = meg_path / "labels" / f"{label_name}.label" +subjects_dir = data_path / "subjects" event_id, tmin, tmax = 1, -0.2, 0.5 snr = 1.0 # use smaller SNR for raw data -lambda2 = 1.0 / snr ** 2 +lambda2 = 1.0 / snr**2 method = "dSPM" # use dSPM method (could also be MNE or sLORETA) # Load data @@ -45,19 +45,27 @@ # Set up pick list include = [] -raw.info['bads'] += ['EEG 053'] # bads + 1 more +raw.info["bads"] += ["EEG 053"] # bads + 1 more # pick MEG channels -picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True, - include=include, exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, stim=False, eog=True, include=include, exclude="bads" +) # Read epochs -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(mag=4e-12, grad=4000e-13, - eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(mag=4e-12, grad=4000e-13, eog=150e-6), +) # define frequencies of interest -fmin, fmax = 0., 70. -bandwidth = 4. # bandwidth of the windows in Hz +fmin, fmax = 0.0, 70.0 +bandwidth = 4.0 # bandwidth of the windows in Hz # %% # Compute source space PSD in label @@ -68,14 +76,21 @@ # keep everything in memory. 
n_epochs_use = 10 -stcs = compute_source_psd_epochs(epochs[:n_epochs_use], inverse_operator, - lambda2=lambda2, - method=method, fmin=fmin, fmax=fmax, - bandwidth=bandwidth, label=label, - return_generator=True, verbose=True) +stcs = compute_source_psd_epochs( + epochs[:n_epochs_use], + inverse_operator, + lambda2=lambda2, + method=method, + fmin=fmin, + fmax=fmax, + bandwidth=bandwidth, + label=label, + return_generator=True, + verbose=True, +) # compute average PSD over the first 10 epochs -psd_avg = 0. +psd_avg = 0.0 for i, stc in enumerate(stcs): psd_avg += stc.data psd_avg /= n_epochs_use @@ -85,16 +100,21 @@ # %% # Visualize the 10 Hz PSD: -brain = stc.plot(initial_time=10., hemi='lh', views='lat', # 10 HZ - clim=dict(kind='value', lims=(20, 40, 60)), - smoothing_steps=3, subjects_dir=subjects_dir) -brain.add_label(label, borders=True, color='k') +brain = stc.plot( + initial_time=10.0, + hemi="lh", + views="lat", # 10 HZ + clim=dict(kind="value", lims=(20, 40, 60)), + smoothing_steps=3, + subjects_dir=subjects_dir, +) +brain.add_label(label, borders=True, color="k") # %% # Visualize the entire spectrum: fig, ax = plt.subplots() ax.plot(freqs, psd_avg.mean(axis=0)) -ax.set_xlabel('Freq (Hz)') +ax.set_xlabel("Freq (Hz)") ax.set_xlim(stc.times[[0, -1]]) -ax.set_ylabel('Power Spectral Density') +ax.set_ylabel("Power Spectral Density") diff --git a/examples/time_frequency/source_label_time_frequency.py b/examples/time_frequency/source_label_time_frequency.py index 721c2fc4d2d..da3af06e4dc 100644 --- a/examples/time_frequency/source_label_time_frequency.py +++ b/examples/time_frequency/source_label_time_frequency.py @@ -32,50 +32,66 @@ # %% # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_raw.fif' -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' -label_name = 'Aud-rh' -fname_label = meg_path / 'labels' / f'{label_name}.label' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_raw.fif" +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" +label_name = "Aud-rh" +fname_label = meg_path / "labels" / f"{label_name}.label" tmin, tmax, event_id = -0.2, 0.5, 2 # Setup for reading the raw data raw = io.read_raw_fif(raw_fname) -events = mne.find_events(raw, stim_channel='STI 014') +events = mne.find_events(raw, stim_channel="STI 014") inverse_operator = read_inverse_operator(fname_inv) include = [] -raw.info['bads'] += ['MEG 2443', 'EEG 053'] # bads + 2 more +raw.info["bads"] += ["MEG 2443", "EEG 053"] # bads + 2 more # Picks MEG channels -picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=True, - stim=False, include=include, exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, eog=True, stim=False, include=include, exclude="bads" +) reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6) # Load epochs -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=reject, - preload=True) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=reject, + preload=True, +) # Compute a source estimate per frequency band including and excluding the # evoked response freqs = np.arange(7, 30, 2) # define frequencies of interest label = mne.read_label(fname_label) -n_cycles = freqs / 3. 
# different number of cycles per frequency
+n_cycles = freqs / 3.0  # different number of cycles per frequency
 # subtract the evoked response in order to exclude evoked activity
 epochs_induced = epochs.copy().subtract_evoked()
-plt.close('all')
+plt.close("all")
-for ii, (this_epochs, title) in enumerate(zip([epochs, epochs_induced],
-                                              ['evoked + induced',
-                                               'induced only'])):
+for ii, (this_epochs, title) in enumerate(
+    zip([epochs, epochs_induced], ["evoked + induced", "induced only"])
+):
     # compute the source space power and the inter-trial coherence
     power, itc = source_induced_power(
-        this_epochs, inverse_operator, freqs, label, baseline=(-0.1, 0),
-        baseline_mode='percent', n_cycles=n_cycles, n_jobs=None)
+        this_epochs,
+        inverse_operator,
+        freqs,
+        label,
+        baseline=(-0.1, 0),
+        baseline_mode="percent",
+        n_cycles=n_cycles,
+        n_jobs=None,
+    )
     power = np.mean(power, axis=0)  # average over sources
     itc = np.mean(itc, axis=0)  # average over sources
@@ -85,22 +101,33 @@
     # View time-frequency plots
     plt.subplots_adjust(0.1, 0.08, 0.96, 0.94, 0.2, 0.43)
     plt.subplot(2, 2, 2 * ii + 1)
-    plt.imshow(20 * power,
-               extent=[times[0], times[-1], freqs[0], freqs[-1]],
-               aspect='auto', origin='lower', vmin=0., vmax=30., cmap='RdBu_r')
-    plt.xlabel('Time (s)')
-    plt.ylabel('Frequency (Hz)')
-    plt.title('Power (%s)' % title)
+    plt.imshow(
+        20 * power,
+        extent=[times[0], times[-1], freqs[0], freqs[-1]],
+        aspect="auto",
+        origin="lower",
+        vmin=0.0,
+        vmax=30.0,
+        cmap="RdBu_r",
+    )
+    plt.xlabel("Time (s)")
+    plt.ylabel("Frequency (Hz)")
+    plt.title("Power (%s)" % title)
     plt.colorbar()
     plt.subplot(2, 2, 2 * ii + 2)
-    plt.imshow(itc,
-               extent=[times[0], times[-1], freqs[0], freqs[-1]],
-               aspect='auto', origin='lower', vmin=0, vmax=0.7,
-               cmap='RdBu_r')
-    plt.xlabel('Time (s)')
-    plt.ylabel('Frequency (Hz)')
-    plt.title('ITC (%s)' % title)
+    plt.imshow(
+        itc,
+        extent=[times[0], times[-1], freqs[0], freqs[-1]],
+        aspect="auto",
+        origin="lower",
+        vmin=0,
+        vmax=0.7,
+        cmap="RdBu_r",
+    )
+    plt.xlabel("Time (s)")
+    plt.ylabel("Frequency (Hz)")
+    plt.title("ITC (%s)" % title)
     plt.colorbar()
 plt.show()
diff --git a/examples/time_frequency/source_power_spectrum.py b/examples/time_frequency/source_power_spectrum.py
index 4b6d582d50b..a2aab813930 100644
--- a/examples/time_frequency/source_power_spectrum.py
+++ b/examples/time_frequency/source_power_spectrum.py
@@ -26,37 +26,48 @@
 # %%
 # Set parameters
 data_path = sample.data_path()
-meg_path = data_path / 'MEG' / 'sample'
-raw_fname = meg_path / 'sample_audvis_raw.fif'
-fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif'
-fname_label = meg_path / 'labels' / 'Aud-lh.label'
+meg_path = data_path / "MEG" / "sample"
+raw_fname = meg_path / "sample_audvis_raw.fif"
+fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif"
+fname_label = meg_path / "labels" / "Aud-lh.label"
 # Setup for reading the raw data
 raw = io.read_raw_fif(raw_fname, verbose=False)
-events = mne.find_events(raw, stim_channel='STI 014')
+events = mne.find_events(raw, stim_channel="STI 014")
 inverse_operator = read_inverse_operator(fname_inv)
-raw.info['bads'] = ['MEG 2443', 'EEG 053']
+raw.info["bads"] = ["MEG 2443", "EEG 053"]
 # picks MEG gradiometers
-picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=True,
-                       stim=False, exclude='bads')
+picks = mne.pick_types(
+    raw.info, meg=True, eeg=False, eog=True, stim=False, exclude="bads"
+)
 tmin, tmax = 0, 120  # use the first 120s of data
 fmin, fmax = 4, 100  # look at frequencies between 4 and 100Hz
 n_fft = 2048  # the FFT size 
(n_fft). Ideally a power of 2 label = mne.read_label(fname_label) -stc = compute_source_psd(raw, inverse_operator, lambda2=1. / 9., method="dSPM", - tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, - pick_ori="normal", n_fft=n_fft, label=label, - dB=True) +stc = compute_source_psd( + raw, + inverse_operator, + lambda2=1.0 / 9.0, + method="dSPM", + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + pick_ori="normal", + n_fft=n_fft, + label=label, + dB=True, +) -stc.save('psd_dSPM', overwrite=True) +stc.save("psd_dSPM", overwrite=True) # %% # View PSD of sources in label plt.plot(stc.times, stc.data.T) -plt.xlabel('Frequency (Hz)') -plt.ylabel('PSD (dB)') -plt.title('Source Power Spectrum (PSD)') +plt.xlabel("Frequency (Hz)") +plt.ylabel("PSD (dB)") +plt.title("Source Power Spectrum (PSD)") plt.show() diff --git a/examples/time_frequency/source_power_spectrum_opm.py b/examples/time_frequency/source_power_spectrum_opm.py index 462f79c8eb9..0e5cf5d34c8 100644 --- a/examples/time_frequency/source_power_spectrum_opm.py +++ b/examples/time_frequency/source_power_spectrum_opm.py @@ -35,19 +35,19 @@ print(__doc__) data_path = mne.datasets.opm.data_path() -subject = 'OPM_sample' - -subjects_dir = data_path / 'subjects' -bem_dir = subjects_dir / subject / 'bem' -bem_fname = bem_dir / f'{subject}-5120-5120-5120-bem-sol.fif' -src_fname = bem_dir / f'{subject}-oct6-src.fif' -vv_fname = data_path / 'MEG' / 'SQUID' / 'SQUID_resting_state.fif' -vv_erm_fname = data_path / 'MEG' / 'SQUID' / 'SQUID_empty_room.fif' -vv_trans_fname = data_path / 'MEG' / 'SQUID' / 'SQUID-trans.fif' -opm_fname = data_path / 'MEG' / 'OPM' / 'OPM_resting_state_raw.fif' -opm_erm_fname = data_path / 'MEG' / 'OPM' / 'OPM_empty_room_raw.fif' -opm_trans = mne.transforms.Transform('head', 'mri') # use identity transform -opm_coil_def_fname = data_path / 'MEG' / 'OPM' / 'coil_def.dat' +subject = "OPM_sample" + +subjects_dir = data_path / "subjects" +bem_dir = subjects_dir / subject / "bem" +bem_fname = bem_dir / f"{subject}-5120-5120-5120-bem-sol.fif" +src_fname = bem_dir / f"{subject}-oct6-src.fif" +vv_fname = data_path / "MEG" / "SQUID" / "SQUID_resting_state.fif" +vv_erm_fname = data_path / "MEG" / "SQUID" / "SQUID_empty_room.fif" +vv_trans_fname = data_path / "MEG" / "SQUID" / "SQUID-trans.fif" +opm_fname = data_path / "MEG" / "OPM" / "OPM_resting_state_raw.fif" +opm_erm_fname = data_path / "MEG" / "OPM" / "OPM_empty_room_raw.fif" +opm_trans = mne.transforms.Transform("head", "mri") # use identity transform +opm_coil_def_fname = data_path / "MEG" / "OPM" / "coil_def.dat" ############################################################################## # Load data, resample. We will store the raw objects in dicts with entries @@ -55,28 +55,28 @@ raws = dict() raw_erms = dict() -new_sfreq = 60. 
# Nyquist frequency (30 Hz) < line noise freq (50 Hz) -raws['vv'] = mne.io.read_raw_fif(vv_fname, verbose='error') # ignore naming -raws['vv'].load_data().resample(new_sfreq) -raws['vv'].info['bads'] = ['MEG2233', 'MEG1842'] -raw_erms['vv'] = mne.io.read_raw_fif(vv_erm_fname, verbose='error') -raw_erms['vv'].load_data().resample(new_sfreq) -raw_erms['vv'].info['bads'] = ['MEG2233', 'MEG1842'] - -raws['opm'] = mne.io.read_raw_fif(opm_fname) -raws['opm'].load_data().resample(new_sfreq) -raw_erms['opm'] = mne.io.read_raw_fif(opm_erm_fname) -raw_erms['opm'].load_data().resample(new_sfreq) +new_sfreq = 60.0 # Nyquist frequency (30 Hz) < line noise freq (50 Hz) +raws["vv"] = mne.io.read_raw_fif(vv_fname, verbose="error") # ignore naming +raws["vv"].load_data().resample(new_sfreq) +raws["vv"].info["bads"] = ["MEG2233", "MEG1842"] +raw_erms["vv"] = mne.io.read_raw_fif(vv_erm_fname, verbose="error") +raw_erms["vv"].load_data().resample(new_sfreq) +raw_erms["vv"].info["bads"] = ["MEG2233", "MEG1842"] + +raws["opm"] = mne.io.read_raw_fif(opm_fname) +raws["opm"].load_data().resample(new_sfreq) +raw_erms["opm"] = mne.io.read_raw_fif(opm_erm_fname) +raw_erms["opm"].load_data().resample(new_sfreq) # Make sure our assumptions later hold -assert raws['opm'].info['sfreq'] == raws['vv'].info['sfreq'] +assert raws["opm"].info["sfreq"] == raws["vv"].info["sfreq"] ############################################################################## # Explore data -titles = dict(vv='VectorView', opm='OPM') -kinds = ('vv', 'opm') +titles = dict(vv="VectorView", opm="OPM") +kinds = ("vv", "opm") n_fft = next_fast_len(int(round(4 * new_sfreq))) -print('Using n_fft=%d (%0.1f s)' % (n_fft, n_fft / raws['vv'].info['sfreq'])) +print("Using n_fft=%d (%0.1f s)" % (n_fft, n_fft / raws["vv"].info["sfreq"])) for kind in kinds: fig = raws[kind].compute_psd(n_fft=n_fft, proj=True).plot() fig.suptitle(titles[kind]) @@ -87,37 +87,48 @@ # --------------------- # Here we use a reduced size source space (oct5) just for speed -src = mne.setup_source_space( - subject, 'oct5', add_dist=False, subjects_dir=subjects_dir) +src = mne.setup_source_space(subject, "oct5", add_dist=False, subjects_dir=subjects_dir) # This line removes source-to-source distances that we will not need. # We only do it here to save a bit of memory, in general this is not required. 
-del src[0]['dist'], src[1]['dist'] +del src[0]["dist"], src[1]["dist"] bem = mne.read_bem_solution(bem_fname) # For speed, let's just use a 1-layer BEM -bem = mne.make_bem_solution(bem['surfs'][-1:]) +bem = mne.make_bem_solution(bem["surfs"][-1:]) fwd = dict() # check alignment and generate forward for VectorView -kwargs = dict(azimuth=0, elevation=90, distance=0.6, focalpoint=(0., 0., 0.)) +kwargs = dict(azimuth=0, elevation=90, distance=0.6, focalpoint=(0.0, 0.0, 0.0)) fig = mne.viz.plot_alignment( - raws['vv'].info, trans=vv_trans_fname, subject=subject, - subjects_dir=subjects_dir, dig=True, coord_frame='mri', - surfaces=('head', 'white')) + raws["vv"].info, + trans=vv_trans_fname, + subject=subject, + subjects_dir=subjects_dir, + dig=True, + coord_frame="mri", + surfaces=("head", "white"), +) mne.viz.set_3d_view(figure=fig, **kwargs) -fwd['vv'] = mne.make_forward_solution( - raws['vv'].info, vv_trans_fname, src, bem, eeg=False, verbose=True) +fwd["vv"] = mne.make_forward_solution( + raws["vv"].info, vv_trans_fname, src, bem, eeg=False, verbose=True +) ############################################################################## # And for OPM: with mne.use_coil_def(opm_coil_def_fname): fig = mne.viz.plot_alignment( - raws['opm'].info, trans=opm_trans, subject=subject, - subjects_dir=subjects_dir, dig=False, coord_frame='mri', - surfaces=('head', 'white')) + raws["opm"].info, + trans=opm_trans, + subject=subject, + subjects_dir=subjects_dir, + dig=False, + coord_frame="mri", + surfaces=("head", "white"), + ) mne.viz.set_3d_view(figure=fig, **kwargs) - fwd['opm'] = mne.make_forward_solution( - raws['opm'].info, opm_trans, src, bem, eeg=False, verbose=True) + fwd["opm"] = mne.make_forward_solution( + raws["opm"].info, opm_trans, src, bem, eeg=False, verbose=True + ) del src, bem @@ -131,24 +142,29 @@ topos = dict(vv=dict(), opm=dict()) stcs = dict(vv=dict(), opm=dict()) -snr = 3. -lambda2 = 1. 
/ snr ** 2 +snr = 3.0 +lambda2 = 1.0 / snr**2 for kind in kinds: noise_cov = mne.compute_raw_covariance(raw_erms[kind]) inverse_operator = mne.minimum_norm.make_inverse_operator( - raws[kind].info, forward=fwd[kind], noise_cov=noise_cov, verbose=True) + raws[kind].info, forward=fwd[kind], noise_cov=noise_cov, verbose=True + ) stc_psd, sensor_psd = mne.minimum_norm.compute_source_psd( - raws[kind], inverse_operator, lambda2=lambda2, - n_fft=n_fft, dB=False, return_sensor=True, verbose=True) + raws[kind], + inverse_operator, + lambda2=lambda2, + n_fft=n_fft, + dB=False, + return_sensor=True, + verbose=True, + ) topo_norm = sensor_psd.data.sum(axis=1, keepdims=True) stc_norm = stc_psd.sum() # same operation on MNE object, sum across freqs # Normalize each source point by the total power across freqs for band, limits in freq_bands.items(): data = sensor_psd.copy().crop(*limits).data.sum(axis=1, keepdims=True) - topos[kind][band] = mne.EvokedArray( - 100 * data / topo_norm, sensor_psd.info) - stcs[kind][band] = \ - 100 * stc_psd.copy().crop(*limits).sum() / stc_norm.data + topos[kind][band] = mne.EvokedArray(100 * data / topo_norm, sensor_psd.info) + stcs[kind][band] = 100 * stc_psd.copy().crop(*limits).sum() / stc_norm.data del inverse_operator del fwd, raws, raw_erms @@ -161,22 +177,42 @@ # Alpha # ----- + def plot_band(kind, band): """Plot activity within a frequency band on the subject's brain.""" - title = "%s %s\n(%d-%d Hz)" % ((titles[kind], band,) + freq_bands[band]) + title = "%s %s\n(%d-%d Hz)" % ( + ( + titles[kind], + band, + ) + + freq_bands[band] + ) topos[kind][band].plot_topomap( - times=0., scalings=1., cbar_fmt='%0.1f', vlim=(0, None), - cmap='inferno', time_format=title) + times=0.0, + scalings=1.0, + cbar_fmt="%0.1f", + vlim=(0, None), + cmap="inferno", + time_format=title, + ) brain = stcs[kind][band].plot( - subject=subject, subjects_dir=subjects_dir, views='cau', hemi='both', - time_label=title, title=title, colormap='inferno', - time_viewer=False, show_traces=False, - clim=dict(kind='percent', lims=(70, 85, 99)), smoothing_steps=10) + subject=subject, + subjects_dir=subjects_dir, + views="cau", + hemi="both", + time_label=title, + title=title, + colormap="inferno", + time_viewer=False, + show_traces=False, + clim=dict(kind="percent", lims=(70, 85, 99)), + smoothing_steps=10, + ) brain.show_view(azimuth=0, elevation=0, roll=0) return fig, brain -fig_alpha, brain_alpha = plot_band('vv', 'alpha') +fig_alpha, brain_alpha = plot_band("vv", "alpha") # %% # Beta @@ -184,13 +220,13 @@ def plot_band(kind, band): # Here we also show OPM data, which shows a profile similar to the VectorView # data beneath the sensors. 
VectorView first: -fig_beta, brain_beta = plot_band('vv', 'beta') +fig_beta, brain_beta = plot_band("vv", "beta") # %% # Then OPM: # sphinx_gallery_thumbnail_number = 10 -fig_beta_opm, brain_beta_opm = plot_band('opm', 'beta') +fig_beta_opm, brain_beta_opm = plot_band("opm", "beta") # %% # References diff --git a/examples/time_frequency/source_space_time_frequency.py b/examples/time_frequency/source_space_time_frequency.py index a0a5f944439..61c3959c232 100644 --- a/examples/time_frequency/source_space_time_frequency.py +++ b/examples/time_frequency/source_space_time_frequency.py @@ -28,46 +28,57 @@ # %% # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_raw.fif' -fname_inv = meg_path / 'sample_audvis-meg-oct-6-meg-inv.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_raw.fif" +fname_inv = meg_path / "sample_audvis-meg-oct-6-meg-inv.fif" tmin, tmax, event_id = -0.2, 0.5, 1 # Setup for reading the raw data raw = io.read_raw_fif(raw_fname) -events = mne.find_events(raw, stim_channel='STI 014') +events = mne.find_events(raw, stim_channel="STI 014") inverse_operator = read_inverse_operator(fname_inv) include = [] -raw.info['bads'] += ['MEG 2443', 'EEG 053'] # bads + 2 more +raw.info["bads"] += ["MEG 2443", "EEG 053"] # bads + 2 more # picks MEG gradiometers -picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=True, - stim=False, include=include, exclude='bads') +picks = mne.pick_types( + raw.info, meg=True, eeg=False, eog=True, stim=False, include=include, exclude="bads" +) # Load condition 1 event_id = 1 events = events[:10] # take 10 events to keep the computation time low # Use linear detrend to reduce any edge artifacts -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), reject=dict(grad=4000e-13, eog=150e-6), - preload=True, detrend=1) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + reject=dict(grad=4000e-13, eog=150e-6), + preload=True, + detrend=1, +) # Compute a source estimate per frequency band bands = dict(alpha=[9, 11], beta=[18, 22]) -stcs = source_band_induced_power(epochs, inverse_operator, bands, n_cycles=2, - use_fft=False, n_jobs=None) +stcs = source_band_induced_power( + epochs, inverse_operator, bands, n_cycles=2, use_fft=False, n_jobs=None +) for b, stc in stcs.items(): - stc.save('induced_power_%s' % b, overwrite=True) + stc.save("induced_power_%s" % b, overwrite=True) # %% # plot mean power -plt.plot(stcs['alpha'].times, stcs['alpha'].data.mean(axis=0), label='Alpha') -plt.plot(stcs['beta'].times, stcs['beta'].data.mean(axis=0), label='Beta') -plt.xlabel('Time (ms)') -plt.ylabel('Power') +plt.plot(stcs["alpha"].times, stcs["alpha"].data.mean(axis=0), label="Alpha") +plt.plot(stcs["beta"].times, stcs["beta"].data.mean(axis=0), label="Beta") +plt.xlabel("Time (ms)") +plt.ylabel("Power") plt.legend() -plt.title('Mean source induced power') +plt.title("Mean source induced power") plt.show() diff --git a/examples/time_frequency/temporal_whitening.py b/examples/time_frequency/temporal_whitening.py index 068abad7337..de70216461b 100644 --- a/examples/time_frequency/temporal_whitening.py +++ b/examples/time_frequency/temporal_whitening.py @@ -26,17 +26,17 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_raw.fif' -proj_fname = meg_path / 'sample_audvis_ecg-proj.fif' +meg_path = data_path / 
"MEG" / "sample" +raw_fname = meg_path / "sample_audvis_raw.fif" +proj_fname = meg_path / "sample_audvis_ecg-proj.fif" raw = mne.io.read_raw_fif(raw_fname) proj = mne.read_proj(proj_fname) raw.add_proj(proj) -raw.info['bads'] = ['MEG 2443', 'EEG 053'] # mark bad channels +raw.info["bads"] = ["MEG 2443", "EEG 053"] # mark bad channels # Set up pick list: Gradiometers - bad channels -picks = mne.pick_types(raw.info, meg='grad', exclude='bads') +picks = mne.pick_types(raw.info, meg="grad", exclude="bads") order = 5 # define model order picks = picks[:1] @@ -45,21 +45,21 @@ b, a = fit_iir_model_raw(raw, order=order, picks=picks, tmin=60, tmax=180) d, times = raw[0, 10000:20000] # look at one channel from now on d = d.ravel() # make flat vector -innovation = signal.convolve(d, a, 'valid') +innovation = signal.convolve(d, a, "valid") d_ = signal.lfilter(b, a, innovation) # regenerate the signal d_ = np.r_[d_[0] * np.ones(order), d_] # dummy samples to keep signal length # %% # Plot the different time series and PSDs -plt.close('all') +plt.close("all") plt.figure() -plt.plot(d[:100], label='signal') -plt.plot(d_[:100], label='regenerated signal') +plt.plot(d[:100], label="signal") +plt.plot(d_[:100], label="regenerated signal") plt.legend() plt.figure() -plt.psd(d, Fs=raw.info['sfreq'], NFFT=2048) -plt.psd(innovation, Fs=raw.info['sfreq'], NFFT=2048) -plt.psd(d_, Fs=raw.info['sfreq'], NFFT=2048, linestyle='--') -plt.legend(('Signal', 'Innovation', 'Regenerated signal')) +plt.psd(d, Fs=raw.info["sfreq"], NFFT=2048) +plt.psd(innovation, Fs=raw.info["sfreq"], NFFT=2048) +plt.psd(d_, Fs=raw.info["sfreq"], NFFT=2048, linestyle="--") +plt.legend(("Signal", "Innovation", "Regenerated signal")) plt.show() diff --git a/examples/time_frequency/time_frequency_erds.py b/examples/time_frequency/time_frequency_erds.py index d55122c232b..72b5f36d172 100644 --- a/examples/time_frequency/time_frequency_erds.py +++ b/examples/time_frequency/time_frequency_erds.py @@ -52,7 +52,7 @@ fnames = eegbci.load_data(subject=1, runs=(6, 10, 14)) raw = concatenate_raws([read_raw_edf(f, preload=True) for f in fnames]) -raw.rename_channels(lambda x: x.strip('.')) # remove dots from channel names +raw.rename_channels(lambda x: x.strip(".")) # remove dots from channel names events, _ = mne.events_from_annotations(raw, event_id=dict(T1=2, T2=3)) @@ -61,8 +61,16 @@ tmin, tmax = -1, 4 event_ids = dict(hands=2, feet=3) # map event IDs to tasks -epochs = mne.Epochs(raw, events, event_ids, tmin - 0.5, tmax + 0.5, - picks=('C3', 'Cz', 'C4'), baseline=None, preload=True) +epochs = mne.Epochs( + raw, + events, + event_ids, + tmin - 0.5, + tmax + 0.5, + picks=("C3", "Cz", "C4"), + baseline=None, + preload=True, +) # %% # .. _cnorm-example: @@ -80,20 +88,29 @@ baseline = (-1, 0) # baseline interval (in s) cnorm = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax) # min, center & max ERDS -kwargs = dict(n_permutations=100, step_down_p=0.05, seed=1, - buffer_size=None, out_type='mask') # for cluster test +kwargs = dict( + n_permutations=100, step_down_p=0.05, seed=1, buffer_size=None, out_type="mask" +) # for cluster test # %% # Finally, we perform time/frequency decomposition over all epochs. 
-tfr = tfr_multitaper(epochs, freqs=freqs, n_cycles=freqs, use_fft=True, - return_itc=False, average=False, decim=2) +tfr = tfr_multitaper( + epochs, + freqs=freqs, + n_cycles=freqs, + use_fft=True, + return_itc=False, + average=False, + decim=2, +) tfr.crop(tmin, tmax).apply_baseline(baseline, mode="percent") for event in event_ids: # select desired epochs for visualization tfr_ev = tfr[event] - fig, axes = plt.subplots(1, 4, figsize=(12, 4), - gridspec_kw={"width_ratios": [10, 10, 10, 1]}) + fig, axes = plt.subplots( + 1, 4, figsize=(12, 4), gridspec_kw={"width_ratios": [10, 10, 10, 1]} + ) for ch, ax in enumerate(axes[:-1]): # for each channel # positive clusters _, c1, p1, _ = pcluster_test(tfr_ev.data[:, ch], tail=1, **kwargs) @@ -108,9 +125,16 @@ mask = c[..., p <= 0.05].any(axis=-1) # plot TFR (ERDS map with masking) - tfr_ev.average().plot([ch], cmap="RdBu", cnorm=cnorm, axes=ax, - colorbar=False, show=False, mask=mask, - mask_style="mask") + tfr_ev.average().plot( + [ch], + cmap="RdBu", + cnorm=cnorm, + axes=ax, + colorbar=False, + show=False, + mask=mask, + mask_style="mask", + ) ax.set_title(epochs.ch_names[ch], fontsize=10) ax.axvline(0, linewidth=1, color="black", linestyle=":") # event @@ -139,33 +163,28 @@ df = tfr.to_data_frame(time_format=None, long_format=True) # Map to frequency bands: -freq_bounds = {'_': 0, - 'delta': 3, - 'theta': 7, - 'alpha': 13, - 'beta': 35, - 'gamma': 140} -df['band'] = pd.cut(df['freq'], list(freq_bounds.values()), - labels=list(freq_bounds)[1:]) +freq_bounds = {"_": 0, "delta": 3, "theta": 7, "alpha": 13, "beta": 35, "gamma": 140} +df["band"] = pd.cut( + df["freq"], list(freq_bounds.values()), labels=list(freq_bounds)[1:] +) # Filter to retain only relevant frequency bands: -freq_bands_of_interest = ['delta', 'theta', 'alpha', 'beta'] +freq_bands_of_interest = ["delta", "theta", "alpha", "beta"] df = df[df.band.isin(freq_bands_of_interest)] -df['band'] = df['band'].cat.remove_unused_categories() +df["band"] = df["band"].cat.remove_unused_categories() # Order channels for plotting: -df['channel'] = df['channel'].cat.reorder_categories(('C3', 'Cz', 'C4'), - ordered=True) +df["channel"] = df["channel"].cat.reorder_categories(("C3", "Cz", "C4"), ordered=True) -g = sns.FacetGrid(df, row='band', col='channel', margin_titles=True) -g.map(sns.lineplot, 'time', 'value', 'condition', n_boot=10) -axline_kw = dict(color='black', linestyle='dashed', linewidth=0.5, alpha=0.5) +g = sns.FacetGrid(df, row="band", col="channel", margin_titles=True) +g.map(sns.lineplot, "time", "value", "condition", n_boot=10) +axline_kw = dict(color="black", linestyle="dashed", linewidth=0.5, alpha=0.5) g.map(plt.axhline, y=0, **axline_kw) g.map(plt.axvline, x=0, **axline_kw) g.set(ylim=(None, 1.5)) g.set_axis_labels("Time (s)", "ERDS (%)") g.set_titles(col_template="{col_name}", row_template="{row_name}") -g.add_legend(ncol=2, loc='lower center') +g.add_legend(ncol=2, loc="lower center") g.fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.08) # %% @@ -174,17 +193,27 @@ # Here, we use seaborn to plot the average ERDS in the motor imagery interval # as a function of frequency band and imagery condition: -df_mean = (df.query('time > 1') - .groupby(['condition', 'epoch', 'band', 'channel'])[['value']] - .mean() - .reset_index()) - -g = sns.FacetGrid(df_mean, col='condition', col_order=['hands', 'feet'], - margin_titles=True) -g = (g.map(sns.violinplot, 'channel', 'value', 'band', n_boot=10, - palette='deep', order=['C3', 'Cz', 'C4'], - hue_order=freq_bands_of_interest, - 
linewidth=0.5).add_legend(ncol=4, loc='lower center')) +df_mean = ( + df.query("time > 1") + .groupby(["condition", "epoch", "band", "channel"])[["value"]] + .mean() + .reset_index() +) + +g = sns.FacetGrid( + df_mean, col="condition", col_order=["hands", "feet"], margin_titles=True +) +g = g.map( + sns.violinplot, + "channel", + "value", + "band", + n_boot=10, + palette="deep", + order=["C3", "Cz", "C4"], + hue_order=freq_bands_of_interest, + linewidth=0.5, +).add_legend(ncol=4, loc="lower center") g.map(plt.axhline, **axline_kw) g.set_axis_labels("", "ERDS (%)") diff --git a/examples/time_frequency/time_frequency_global_field_power.py b/examples/time_frequency/time_frequency_global_field_power.py index a9af92cdde9..df816162f1c 100644 --- a/examples/time_frequency/time_frequency_global_field_power.py +++ b/examples/time_frequency/time_frequency_global_field_power.py @@ -54,47 +54,52 @@ # %% # Set parameters data_path = somato.data_path() -subject = '01' -task = 'somato' -raw_fname = (data_path / f'sub-{subject}' / 'meg' / - f'sub-{subject}_task-{task}_meg.fif') +subject = "01" +task = "somato" +raw_fname = data_path / f"sub-{subject}" / "meg" / f"sub-{subject}_task-{task}_meg.fif" # let's explore some frequency bands -iter_freqs = [ - ('Theta', 4, 7), - ('Alpha', 8, 12), - ('Beta', 13, 25), - ('Gamma', 30, 45) -] +iter_freqs = [("Theta", 4, 7), ("Alpha", 8, 12), ("Beta", 13, 25), ("Gamma", 30, 45)] # %% # We create average power time courses for each frequency band # set epoching parameters -event_id, tmin, tmax = 1, -1., 3. +event_id, tmin, tmax = 1, -1.0, 3.0 baseline = None # get the header to extract events raw = mne.io.read_raw_fif(raw_fname) -events = mne.find_events(raw, stim_channel='STI 014') +events = mne.find_events(raw, stim_channel="STI 014") frequency_map = list() for band, fmin, fmax in iter_freqs: # (re)load the data to save memory raw = mne.io.read_raw_fif(raw_fname) - raw.pick_types(meg='grad', eog=True) # we just look at gradiometers + raw.pick_types(meg="grad", eog=True) # we just look at gradiometers raw.load_data() # bandpass filter - raw.filter(fmin, fmax, n_jobs=None, # use more jobs to speed up. - l_trans_bandwidth=1, # make sure filter params are the same - h_trans_bandwidth=1) # in each band and skip "auto" option. + raw.filter( + fmin, + fmax, + n_jobs=None, # use more jobs to speed up. + l_trans_bandwidth=1, # make sure filter params are the same + h_trans_bandwidth=1, + ) # in each band and skip "auto" option. 
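#
# Aside, illustrative only: fixing ``l_trans_bandwidth`` and
# ``h_trans_bandwidth`` at 1 Hz matters because the ``'auto'`` rule scales
# the transition widths with the cutoff frequencies, so each band would
# otherwise get different effective edges. Per MNE's filtering documentation
# the firwin rule is ``min(max(0.25 * l_freq, 2), l_freq)`` for the low edge
# and ``min(max(0.25 * h_freq, 2), sfreq / 2 - h_freq)`` for the high edge;
# the sampling rate below is an assumed stand-in for the somato data.

sfreq = 600.615  # approximate somato sampling rate; illustrative only
for band_name, f_lo, f_hi in [
    ("Theta", 4, 7),
    ("Alpha", 8, 12),
    ("Beta", 13, 25),
    ("Gamma", 30, 45),
]:
    l_trans = min(max(0.25 * f_lo, 2.0), f_lo)
    h_trans = min(max(0.25 * f_hi, 2.0), sfreq / 2.0 - f_hi)
    print(f"{band_name}: auto l_trans={l_trans:.2f} Hz, h_trans={h_trans:.2f} Hz")
# 'auto' would widen the transitions from about 2 Hz (Theta) up to about
# 7.5/11.25 Hz (Gamma), hence the fixed 1 Hz above for comparable band edges.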
# epoch - epochs = mne.Epochs(raw, events, event_id, tmin, tmax, baseline=baseline, - reject=dict(grad=4000e-13, eog=350e-6), - preload=True) + epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + baseline=baseline, + reject=dict(grad=4000e-13, eog=350e-6), + preload=True, + ) # remove evoked response epochs.subtract_evoked() @@ -115,30 +120,34 @@ # Helper function for plotting spread def stat_fun(x): """Return sum of squares.""" - return np.sum(x ** 2, axis=0) + return np.sum(x**2, axis=0) # Plot fig, axes = plt.subplots(4, 1, figsize=(10, 7), sharex=True, sharey=True) -colors = plt.colormaps['winter_r'](np.linspace(0, 1, 4)) +colors = plt.colormaps["winter_r"](np.linspace(0, 1, 4)) for ((freq_name, fmin, fmax), average), color, ax in zip( - frequency_map, colors, axes.ravel()[::-1]): + frequency_map, colors, axes.ravel()[::-1] +): times = average.times * 1e3 - gfp = np.sum(average.data ** 2, axis=0) + gfp = np.sum(average.data**2, axis=0) gfp = mne.baseline.rescale(gfp, times, baseline=(None, 0)) ax.plot(times, gfp, label=freq_name, color=color, linewidth=2.5) - ax.axhline(0, linestyle='--', color='grey', linewidth=2) - ci_low, ci_up = bootstrap_confidence_interval(average.data, random_state=0, - stat_fun=stat_fun) + ax.axhline(0, linestyle="--", color="grey", linewidth=2) + ci_low, ci_up = bootstrap_confidence_interval( + average.data, random_state=0, stat_fun=stat_fun + ) ci_low = rescale(ci_low, average.times, baseline=(None, 0)) ci_up = rescale(ci_up, average.times, baseline=(None, 0)) ax.fill_between(times, gfp + ci_up, gfp - ci_low, color=color, alpha=0.3) ax.grid(True) - ax.set_ylabel('GFP') - ax.annotate('%s (%d-%dHz)' % (freq_name, fmin, fmax), - xy=(0.95, 0.8), - horizontalalignment='right', - xycoords='axes fraction') + ax.set_ylabel("GFP") + ax.annotate( + "%s (%d-%dHz)" % (freq_name, fmin, fmax), + xy=(0.95, 0.8), + horizontalalignment="right", + xycoords="axes fraction", + ) ax.set_xlim(-1000, 3000) -axes.ravel()[-1].set_xlabel('Time [ms]') +axes.ravel()[-1].set_xlabel("Time [ms]") diff --git a/examples/time_frequency/time_frequency_simulated.py b/examples/time_frequency/time_frequency_simulated.py index c84803d7d2f..bf8b1dba6ca 100644 --- a/examples/time_frequency/time_frequency_simulated.py +++ b/examples/time_frequency/time_frequency_simulated.py @@ -26,8 +26,13 @@ from mne import create_info, Epochs from mne.baseline import rescale from mne.io import RawArray -from mne.time_frequency import (tfr_multitaper, tfr_stockwell, tfr_morlet, - tfr_array_morlet, AverageTFR) +from mne.time_frequency import ( + tfr_multitaper, + tfr_stockwell, + tfr_morlet, + tfr_array_morlet, + AverageTFR, +) from mne.viz import centers_to_edges print(__doc__) @@ -39,8 +44,8 @@ # We'll simulate data with a known spectro-temporal structure. sfreq = 1000.0 -ch_names = ['SIM0001', 'SIM0002'] -ch_types = ['grad', 'grad'] +ch_names = ["SIM0001", "SIM0002"] +ch_types = ["grad", "grad"] info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) n_times = 1024 # Just over 1 second epochs @@ -51,8 +56,8 @@ # Add a 50 Hz sinusoidal burst to the noise and ramp it. t = np.arange(n_times, dtype=np.float64) / sfreq -signal = np.sin(np.pi * 2. * 50. * t) # 50 Hz sinusoid signal -signal[np.logical_or(t < 0.45, t > 0.55)] = 0. 
# Hard windowing +signal = np.sin(np.pi * 2.0 * 50.0 * t) # 50 Hz sinusoid signal +signal[np.logical_or(t < 0.45, t > 0.55)] = 0.0 # Hard windowing on_time = np.logical_and(t >= 0.45, t <= 0.55) signal[on_time] *= np.hanning(on_time.sum()) # Ramping data[:, 100:-100] += np.tile(signal, n_epochs) # add signal @@ -60,8 +65,15 @@ raw = RawArray(data, info) events = np.zeros((n_epochs, 3), dtype=int) events[:, 0] = np.arange(n_epochs) * n_times -epochs = Epochs(raw, events, dict(sin50hz=0), tmin=0, tmax=n_times / sfreq, - reject=dict(grad=4000), baseline=None) +epochs = Epochs( + raw, + events, + dict(sin50hz=0), + tmin=0, + tmax=n_times / sfreq, + reject=dict(grad=4000), + baseline=None, +) epochs.average().plot() @@ -85,23 +97,39 @@ # properties, and thus a different TFR. You can trade time resolution or # frequency resolution or both in order to get a reduction in variance. -freqs = np.arange(5., 100., 3.) -vmin, vmax = -3., 3. # Define our color limits. +freqs = np.arange(5.0, 100.0, 3.0) +vmin, vmax = -3.0, 3.0 # Define our color limits. fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True) for n_cycles, time_bandwidth, ax, title in zip( - [freqs / 2, freqs, freqs / 2], # number of cycles - [2.0, 4.0, 8.0], # time bandwidth - axs, - ['Sim: Least smoothing, most variance', - 'Sim: Less frequency smoothing,\nmore time smoothing', - 'Sim: Less time smoothing,\nmore frequency smoothing']): - power = tfr_multitaper(epochs, freqs=freqs, n_cycles=n_cycles, - time_bandwidth=time_bandwidth, return_itc=False) + [freqs / 2, freqs, freqs / 2], # number of cycles + [2.0, 4.0, 8.0], # time bandwidth + axs, + [ + "Sim: Least smoothing, most variance", + "Sim: Less frequency smoothing,\nmore time smoothing", + "Sim: Less time smoothing,\nmore frequency smoothing", + ], +): + power = tfr_multitaper( + epochs, + freqs=freqs, + n_cycles=n_cycles, + time_bandwidth=time_bandwidth, + return_itc=False, + ) ax.set_title(title) # Plot results. Baseline correct based on first 100 ms. - power.plot([0], baseline=(0., 0.1), mode='mean', vmin=vmin, vmax=vmax, - axes=ax, show=False, colorbar=False) + power.plot( + [0], + baseline=(0.0, 0.1), + mode="mean", + vmin=vmin, + vmax=vmax, + axes=ax, + show=False, + colorbar=False, + ) plt.tight_layout() ############################################################################## @@ -119,9 +147,10 @@ fmin, fmax = freqs[[0, -1]] for width, ax in zip((0.2, 0.7, 3.0), axs): power = tfr_stockwell(epochs, fmin=fmin, fmax=fmax, width=width) - power.plot([0], baseline=(0., 0.1), mode='mean', axes=ax, show=False, - colorbar=False) - ax.set_title('Sim: Using S transform, width = {:0.1f}'.format(width)) + power.plot( + [0], baseline=(0.0, 0.1), mode="mean", axes=ax, show=False, colorbar=False + ) + ax.set_title("Sim: Using S transform, width = {:0.1f}".format(width)) plt.tight_layout() # %% @@ -134,14 +163,21 @@ # number of cycles to include in the window. fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True) -all_n_cycles = [1, 3, freqs / 2.] 
+all_n_cycles = [1, 3, freqs / 2.0] for n_cycles, ax in zip(all_n_cycles, axs): - power = tfr_morlet(epochs, freqs=freqs, - n_cycles=n_cycles, return_itc=False) - power.plot([0], baseline=(0., 0.1), mode='mean', vmin=vmin, vmax=vmax, - axes=ax, show=False, colorbar=False) - n_cycles = 'scaled by freqs' if not isinstance(n_cycles, int) else n_cycles - ax.set_title(f'Sim: Using Morlet wavelet, n_cycles = {n_cycles}') + power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False) + power.plot( + [0], + baseline=(0.0, 0.1), + mode="mean", + vmin=vmin, + vmax=vmax, + axes=ax, + show=False, + colorbar=False, + ) + n_cycles = "scaled by freqs" if not isinstance(n_cycles, int) else n_cycles + ax.set_title(f"Sim: Using Morlet wavelet, n_cycles = {n_cycles}") plt.tight_layout() # %% @@ -154,10 +190,9 @@ # the width of this filter is recommended to be about 2 Hz. fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True) -bandwidths = [1., 2., 4.] +bandwidths = [1.0, 2.0, 4.0] for bandwidth, ax in zip(bandwidths, axs): - data = np.zeros((len(ch_names), freqs.size, epochs.times.size), - dtype=complex) + data = np.zeros((len(ch_names), freqs.size, epochs.times.size), dtype=complex) for idx, freq in enumerate(freqs): # Filter raw data and re-epoch to avoid the filter being longer than # the epoch data for low frequencies and short epochs, such as here. @@ -167,24 +202,37 @@ # these are all very similar because the filters are almost the same. # In practice, using the default is usually a wise choice. raw_filter.filter( - l_freq=freq - bandwidth / 2, h_freq=freq + bandwidth / 2, + l_freq=freq - bandwidth / 2, + h_freq=freq + bandwidth / 2, # no negative values for large bandwidth and low freq l_trans_bandwidth=min([4 * bandwidth, freq - bandwidth]), - h_trans_bandwidth=4 * bandwidth) + h_trans_bandwidth=4 * bandwidth, + ) raw_filter.apply_hilbert() - epochs_hilb = Epochs(raw_filter, events, tmin=0, tmax=n_times / sfreq, - baseline=(0, 0.1)) + epochs_hilb = Epochs( + raw_filter, events, tmin=0, tmax=n_times / sfreq, baseline=(0, 0.1) + ) tfr_data = epochs_hilb.get_data() tfr_data = tfr_data * tfr_data.conj() # compute power tfr_data = np.mean(tfr_data, axis=0) # average over epochs data[:, idx] = tfr_data power = AverageTFR(info, data, epochs.times, freqs, nave=n_epochs) - power.plot([0], baseline=(0., 0.1), mode='mean', vmin=-0.1, vmax=0.1, - axes=ax, show=False, colorbar=False) - n_cycles = 'scaled by freqs' if not isinstance(n_cycles, int) else n_cycles - ax.set_title('Sim: Using narrow bandpass filter Hilbert,\n' - f'bandwidth = {bandwidth}, ' - f'transition bandwidth = {4 * bandwidth}') + power.plot( + [0], + baseline=(0.0, 0.1), + mode="mean", + vmin=-0.1, + vmax=0.1, + axes=ax, + show=False, + colorbar=False, + ) + n_cycles = "scaled by freqs" if not isinstance(n_cycles, int) else n_cycles + ax.set_title( + "Sim: Using narrow bandpass filter Hilbert,\n" + f"bandwidth = {bandwidth}, " + f"transition bandwidth = {4 * bandwidth}" + ) plt.tight_layout() # %% @@ -195,13 +243,21 @@ # We can do this by using ``average=False``. In this case, an instance of # :class:`mne.time_frequency.EpochsTFR` is returned. -n_cycles = freqs / 2. 
-power = tfr_morlet(epochs, freqs=freqs, - n_cycles=n_cycles, return_itc=False, average=False) +n_cycles = freqs / 2.0 +power = tfr_morlet( + epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False, average=False +) print(type(power)) avgpower = power.average() -avgpower.plot([0], baseline=(0., 0.1), mode='mean', vmin=vmin, vmax=vmax, - title='Using Morlet wavelets and EpochsTFR', show=False) +avgpower.plot( + [0], + baseline=(0.0, 0.1), + mode="mean", + vmin=vmin, + vmax=vmax, + title="Using Morlet wavelets and EpochsTFR", + show=False, +) # %% # Operating on arrays @@ -212,16 +268,20 @@ # ``(n_epochs, n_channels, n_times)``. They will also return a numpy array # of shape ``(n_epochs, n_channels, n_freqs, n_times)``. -power = tfr_array_morlet(epochs.get_data(), sfreq=epochs.info['sfreq'], - freqs=freqs, n_cycles=n_cycles, - output='avg_power') +power = tfr_array_morlet( + epochs.get_data(), + sfreq=epochs.info["sfreq"], + freqs=freqs, + n_cycles=n_cycles, + output="avg_power", +) # Baseline the output -rescale(power, epochs.times, (0., 0.1), mode='mean', copy=False) +rescale(power, epochs.times, (0.0, 0.1), mode="mean", copy=False) fig, ax = plt.subplots() x, y = centers_to_edges(epochs.times * 1000, freqs) -mesh = ax.pcolormesh(x, y, power[0], cmap='RdBu_r', vmin=vmin, vmax=vmax) -ax.set_title('TFR calculated on a numpy array') -ax.set(ylim=freqs[[0, -1]], xlabel='Time (ms)') +mesh = ax.pcolormesh(x, y, power[0], cmap="RdBu_r", vmin=vmin, vmax=vmax) +ax.set_title("TFR calculated on a numpy array") +ax.set(ylim=freqs[[0, -1]], xlabel="Time (ms)") fig.colorbar(mesh) plt.tight_layout() diff --git a/examples/visualization/3d_to_2d.py b/examples/visualization/3d_to_2d.py index bb692533baa..9eecc33f196 100644 --- a/examples/visualization/3d_to_2d.py +++ b/examples/visualization/3d_to_2d.py @@ -34,12 +34,12 @@ from mne.viz import plot_alignment, set_3d_view, snapshot_brain_montage misc_path = mne.datasets.misc.data_path() -subjects_dir = misc_path / 'ecog' -ecog_data_fname = subjects_dir / 'sample_ecog_ieeg.fif' +subjects_dir = misc_path / "ecog" +ecog_data_fname = subjects_dir / "sample_ecog_ieeg.fif" # We've already clicked and exported -layout_path = Path(dirname(mne.__file__)) / 'data' / 'image' -layout_name = 'custom_layout.lout' +layout_path = Path(dirname(mne.__file__)) / "data" / "image" +layout_name = "custom_layout.lout" # %% # Load data @@ -49,14 +49,14 @@ # a 2D snapshot. raw = read_raw_fif(ecog_data_fname) -raw.pick_channels([f'G{i}' for i in range(1, 257)]) # pick just one grid +raw.pick_channels([f"G{i}" for i in range(1, 257)]) # pick just one grid # Since we loaded in the ecog data from FIF, the coordinates # are in 'head' space, but we actually want them in 'mri' space. # So we will apply the head->mri transform that was used when # generating the dataset (the estimated head->mri transform). montage = raw.get_montage() -trans = mne.coreg.estimate_head_mri_t('sample_ecog', subjects_dir) +trans = mne.coreg.estimate_head_mri_t("sample_ecog", subjects_dir) montage.apply_trans(trans) # %% @@ -68,8 +68,13 @@ # with the electrode positions on that image. We use this in conjunction with # :func:`mne.viz.plot_alignment`, which visualizes electrode positions. 
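#
# Aside, illustrative only: the 3D-to-2D step is conceptually a camera
# projection, i.e. ``snapshot_brain_montage`` maps each sensor's world
# position through the renderer's camera onto image pixels. A toy pinhole
# projection makes the geometry concrete; every number here is made up.

import numpy as np

pts = np.array(
    [
        [0.02, 0.01, 0.05],  # toy sensor positions in meters
        [0.04, -0.01, 0.06],
    ]
)
cam = np.array([0.0, 0.0, 0.50])  # toy camera on +z, looking down -z
focal_px = 800.0  # toy focal length, in pixels
rel = pts - cam  # sensor positions relative to the camera
depth = -rel[:, 2]  # distance along the viewing direction
xy_px = focal_px * rel[:, :2] / depth[:, None]  # perspective divide
print(xy_px)  # 2D pixel offsets from the image center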
-fig = plot_alignment(raw.info, trans=trans, subject='sample_ecog', - subjects_dir=subjects_dir, surfaces=dict(pial=0.9)) +fig = plot_alignment( + raw.info, + trans=trans, + subject="sample_ecog", + subjects_dir=subjects_dir, + surfaces=dict(pial=0.9), +) set_3d_view(figure=fig, azimuth=20, elevation=80) xy, im = snapshot_brain_montage(fig, montage) @@ -84,9 +89,9 @@ # This allows us to use matplotlib to create arbitrary 2d scatterplots fig2, ax = plt.subplots(figsize=(10, 10)) ax.imshow(im) -cmap = ax.scatter(*xy_pts.T, c=beta_power, s=100, cmap='coolwarm') +cmap = ax.scatter(*xy_pts.T, c=beta_power, s=100, cmap="coolwarm") cbar = fig2.colorbar(cmap) -cbar.ax.set_ylabel('Beta Power') +cbar.ax.set_ylabel("Beta Power") ax.set_axis_off() # fig2.savefig('./brain.png', bbox_inches='tight') # For ClickableImage @@ -126,6 +131,6 @@ y = (1 - lt.pos[:, 1]) * float(im.shape[0]) # Flip the y-position fig, ax = plt.subplots() ax.imshow(im) -ax.scatter(x, y, s=80, color='r') +ax.scatter(x, y, s=80, color="r") fig.tight_layout() ax.set_axis_off() diff --git a/examples/visualization/brain.py b/examples/visualization/brain.py index 5b31bc7b106..35a7ac77bfd 100644 --- a/examples/visualization/brain.py +++ b/examples/visualization/brain.py @@ -29,8 +29,8 @@ print(__doc__) data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -sample_dir = data_path / 'MEG' / 'sample' +subjects_dir = data_path / "subjects" +sample_dir = data_path / "MEG" / "sample" # %% # Add source information @@ -38,16 +38,21 @@ # # Plot source information. -brain_kwargs = dict(alpha=0.1, background='white', cortex='low_contrast') -brain = mne.viz.Brain('sample', subjects_dir=subjects_dir, **brain_kwargs) +brain_kwargs = dict(alpha=0.1, background="white", cortex="low_contrast") +brain = mne.viz.Brain("sample", subjects_dir=subjects_dir, **brain_kwargs) -stc = mne.read_source_estimate(sample_dir / 'sample_audvis-meg') +stc = mne.read_source_estimate(sample_dir / "sample_audvis-meg") stc.crop(0.09, 0.1) -kwargs = dict(fmin=stc.data.min(), fmax=stc.data.max(), alpha=0.25, - smoothing_steps='nearest', time=stc.times) -brain.add_data(stc.lh_data, hemi='lh', vertices=stc.lh_vertno, **kwargs) -brain.add_data(stc.rh_data, hemi='rh', vertices=stc.rh_vertno, **kwargs) +kwargs = dict( + fmin=stc.data.min(), + fmax=stc.data.max(), + alpha=0.25, + smoothing_steps="nearest", + time=stc.times, +) +brain.add_data(stc.lh_data, hemi="lh", vertices=stc.lh_vertno, **kwargs) +brain.add_data(stc.rh_data, hemi="rh", vertices=stc.rh_vertno, **kwargs) # %% # Modify the view of the brain @@ -55,7 +60,7 @@ # # You can adjust the view of the brain using ``show_view`` method. -brain = mne.viz.Brain('sample', subjects_dir=subjects_dir, **brain_kwargs) +brain = mne.viz.Brain("sample", subjects_dir=subjects_dir, **brain_kwargs) brain.show_view(azimuth=190, elevation=70, distance=350, focalpoint=(0, 0, 20)) # %% @@ -73,8 +78,8 @@ # .. note:: The MNE sample dataset contains only a subselection of the # Freesurfer labels created during the ``recon-all``. -brain = mne.viz.Brain('sample', subjects_dir=subjects_dir, **brain_kwargs) -brain.add_label('BA44', hemi='lh', color='green', borders=True) +brain = mne.viz.Brain("sample", subjects_dir=subjects_dir, **brain_kwargs) +brain.add_label("BA44", hemi="lh", color="green", borders=True) brain.show_view(azimuth=190, elevation=70, distance=350, focalpoint=(0, 0, 20)) # %% @@ -83,7 +88,7 @@ # # Add a head image using the ``add_head`` method. 
-brain = mne.viz.Brain('sample', subjects_dir=subjects_dir, **brain_kwargs) +brain = mne.viz.Brain("sample", subjects_dir=subjects_dir, **brain_kwargs) brain.add_head(alpha=0.5) # %% @@ -93,9 +98,9 @@ # To put into context the data that generated the source time course, # the sensor positions can be displayed as well. -brain = mne.viz.Brain('sample', subjects_dir=subjects_dir, **brain_kwargs) -evoked = mne.read_evokeds(sample_dir / 'sample_audvis-ave.fif')[0] -trans = mne.read_trans(sample_dir / 'sample_audvis_raw-trans.fif') +brain = mne.viz.Brain("sample", subjects_dir=subjects_dir, **brain_kwargs) +evoked = mne.read_evokeds(sample_dir / "sample_audvis-ave.fif")[0] +trans = mne.read_trans(sample_dir / "sample_audvis_raw-trans.fif") brain.add_sensors(evoked.info, trans) brain.show_view(distance=500) # move back to show sensors @@ -106,9 +111,9 @@ # Dipole modeling as in :ref:`tut-dipole-orientations` can be plotted on the # brain as well. -brain = mne.viz.Brain('sample', subjects_dir=subjects_dir, **brain_kwargs) -dip = mne.read_dipole(sample_dir / 'sample_audvis_set1.dip') -cmap = plt.colormaps['YlOrRd'] +brain = mne.viz.Brain("sample", subjects_dir=subjects_dir, **brain_kwargs) +dip = mne.read_dipole(sample_dir / "sample_audvis_set1.dip") +cmap = plt.colormaps["YlOrRd"] colors = [cmap(gof / dip.gof.max()) for gof in dip.gof] brain.add_dipole(dip, trans, colors=colors, scales=list(dip.amplitude * 1e8)) brain.show_view(azimuth=-20, elevation=60, distance=300) @@ -123,8 +128,8 @@ fig, ax = plt.subplots() ax.imshow(img) -ax.axis('off') +ax.axis("off") cax = fig.add_axes([0.9, 0.1, 0.05, 0.8]) norm = Normalize(vmin=0, vmax=dip.gof.max()) fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), cax=cax) -fig.suptitle('Dipole Fits Scaled by Amplitude and Colored by GOF') +fig.suptitle("Dipole Fits Scaled by Amplitude and Colored by GOF") diff --git a/examples/visualization/channel_epochs_image.py b/examples/visualization/channel_epochs_image.py index bb52c11c44b..618330ec44d 100644 --- a/examples/visualization/channel_epochs_image.py +++ b/examples/visualization/channel_epochs_image.py @@ -33,9 +33,9 @@ # %% # Set parameters -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" event_id, tmin, tmax = 1, -0.2, 0.4 # Setup for reading the raw data @@ -43,12 +43,21 @@ events = mne.read_events(event_fname) # Set up pick list: EEG + MEG - bad channels (modify to your needs) -raw.info['bads'] = ['MEG 2443', 'EEG 053'] +raw.info["bads"] = ["MEG 2443", "EEG 053"] # Create epochs, here for gradiometers + EOG only for simplicity -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True, - picks=('grad', 'eog'), baseline=(None, 0), preload=True, - reject=dict(grad=4000e-13, eog=150e-6)) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + proj=True, + picks=("grad", "eog"), + baseline=(None, 0), + preload=True, + reject=dict(grad=4000e-13, eog=150e-6), +) # %% # Show event-related fields images @@ -56,26 +65,36 @@ # and order with spectral reordering # If you don't have scikit-learn installed set order_func to None from sklearn.manifold import spectral_embedding # noqa -from sklearn.metrics.pairwise import rbf_kernel # noqa +from sklearn.metrics.pairwise import rbf_kernel # noqa def order_func(times, data): this_data = 
data[:, (times > 0.0) & (times < 0.350)] - this_data /= np.sqrt(np.sum(this_data ** 2, axis=1))[:, np.newaxis] - return np.argsort(spectral_embedding(rbf_kernel(this_data, gamma=1.), - n_components=1, random_state=0).ravel()) + this_data /= np.sqrt(np.sum(this_data**2, axis=1))[:, np.newaxis] + return np.argsort( + spectral_embedding( + rbf_kernel(this_data, gamma=1.0), n_components=1, random_state=0 + ).ravel() + ) good_pick = 97 # channel with a clear evoked response bad_pick = 98 # channel with no evoked response # We'll also plot a sample time onset for each trial -plt_times = np.linspace(0, .2, len(epochs)) - -plt.close('all') -mne.viz.plot_epochs_image(epochs, [good_pick, bad_pick], sigma=.5, - order=order_func, vmin=-250, vmax=250, - overlay_times=plt_times, show=True) +plt_times = np.linspace(0, 0.2, len(epochs)) + +plt.close("all") +mne.viz.plot_epochs_image( + epochs, + [good_pick, bad_pick], + sigma=0.5, + order=order_func, + vmin=-250, + vmax=250, + overlay_times=plt_times, + show=True, +) # %% # References diff --git a/examples/visualization/eeg_on_scalp.py b/examples/visualization/eeg_on_scalp.py index 7ad5438b9dc..f27bc63ecdd 100644 --- a/examples/visualization/eeg_on_scalp.py +++ b/examples/visualization/eeg_on_scalp.py @@ -19,15 +19,22 @@ print(__doc__) data_path = mne.datasets.sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -trans = mne.read_trans(meg_path / 'sample_audvis_raw-trans.fif') -raw = mne.io.read_raw_fif(meg_path / 'sample_audvis_raw.fif') +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +trans = mne.read_trans(meg_path / "sample_audvis_raw-trans.fif") +raw = mne.io.read_raw_fif(meg_path / "sample_audvis_raw.fif") # Plot electrode locations on scalp -fig = plot_alignment(raw.info, trans, subject='sample', dig=False, - eeg=['original', 'projected'], meg=[], - coord_frame='head', subjects_dir=subjects_dir) +fig = plot_alignment( + raw.info, + trans, + subject="sample", + dig=False, + eeg=["original", "projected"], + meg=[], + coord_frame="head", + subjects_dir=subjects_dir, +) # Set viewing angle set_3d_view(figure=fig, azimuth=135, elevation=80) diff --git a/examples/visualization/evoked_arrowmap.py b/examples/visualization/evoked_arrowmap.py index 7ce3f1df093..294be182c7c 100644 --- a/examples/visualization/evoked_arrowmap.py +++ b/examples/visualization/evoked_arrowmap.py @@ -33,13 +33,13 @@ print(__doc__) path = sample.data_path() -fname = path / 'MEG' / 'sample' / 'sample_audvis-ave.fif' +fname = path / "MEG" / "sample" / "sample_audvis-ave.fif" # load evoked data -condition = 'Left Auditory' +condition = "Left Auditory" evoked = read_evokeds(fname, condition=condition, baseline=(None, 0)) -evoked_mag = evoked.copy().pick_types(meg='mag') -evoked_grad = evoked.copy().pick_types(meg='grad') +evoked_mag = evoked.copy().pick_types(meg="mag") +evoked_grad = evoked.copy().pick_types(meg="grad") # %% # Plot magnetometer data as an arrowmap along with the topoplot at the time @@ -57,8 +57,11 @@ # %% # Plot gradiometer data as an arrowmap along with the topoplot at the time # of the maximum sensor space activity: -plot_arrowmap(evoked_grad.data[:, max_time_idx], info_from=evoked_grad.info, - info_to=evoked_mag.info) +plot_arrowmap( + evoked_grad.data[:, max_time_idx], + info_from=evoked_grad.info, + info_to=evoked_mag.info, +) # %% # Since Vectorview 102 system perform sparse spatial sampling of the magnetic @@ -68,10 +71,14 @@ # Plot gradiometer data as an arrowmap along with the 
topoplot at the time # of the maximum sensor space activity: path = bst_raw.data_path() -raw_fname = (path / 'MEG' / 'bst_raw' / - 'subj001_somatosensory_20111109_01_AUX-f.ds') +raw_fname = path / "MEG" / "bst_raw" / "subj001_somatosensory_20111109_01_AUX-f.ds" raw_ctf = mne.io.read_raw_ctf(raw_fname) raw_ctf_info = mne.pick_info( - raw_ctf.info, mne.pick_types(raw_ctf.info, meg=True, ref_meg=False)) -plot_arrowmap(evoked_grad.data[:, max_time_idx], info_from=evoked_grad.info, - info_to=raw_ctf_info, scale=6e-10) + raw_ctf.info, mne.pick_types(raw_ctf.info, meg=True, ref_meg=False) +) +plot_arrowmap( + evoked_grad.data[:, max_time_idx], + info_from=evoked_grad.info, + info_to=raw_ctf_info, + scale=6e-10, +) diff --git a/examples/visualization/evoked_topomap.py b/examples/visualization/evoked_topomap.py index abeb527757e..20bb9611497 100644 --- a/examples/visualization/evoked_topomap.py +++ b/examples/visualization/evoked_topomap.py @@ -29,11 +29,11 @@ print(__doc__) path = sample.data_path() -fname = path / 'MEG' / 'sample' / 'sample_audvis-ave.fif' +fname = path / "MEG" / "sample" / "sample_audvis-ave.fif" # load evoked corresponding to a specific condition # from the fif file and subtract baseline -condition = 'Left Auditory' +condition = "Left Auditory" evoked = read_evokeds(fname, condition=condition, baseline=(None, 0)) # %% @@ -45,28 +45,28 @@ # topographies will be shown. We select timepoints from 50 to 150 ms with a # step of 20ms and plot magnetometer data: times = np.arange(0.05, 0.151, 0.02) -evoked.plot_topomap(times, ch_type='mag') +evoked.plot_topomap(times, ch_type="mag") # %% # If times is set to None at most 10 regularly spaced topographies will be # shown: -evoked.plot_topomap(ch_type='mag') +evoked.plot_topomap(ch_type="mag") # %% # We can use ``nrows`` and ``ncols`` parameter to create multiline plots # with more timepoints. all_times = np.arange(-0.2, 0.5, 0.03) -evoked.plot_topomap(all_times, ch_type='mag', ncols=8, nrows='auto') +evoked.plot_topomap(all_times, ch_type="mag", ncols=8, nrows="auto") # %% # Instead of showing topographies at specific time points we can compute # averages of 50 ms bins centered on these time points to reduce the noise in # the topographies: -evoked.plot_topomap(times, ch_type='mag', average=0.05) +evoked.plot_topomap(times, ch_type="mag", average=0.05) # %% # We can plot gradiometer data (plots the RMS for each pair of gradiometers) -evoked.plot_topomap(times, ch_type='grad') +evoked.plot_topomap(times, ch_type="grad") # %% # Additional :func:`~mne.viz.plot_topomap` options @@ -79,8 +79,7 @@ # * ``res`` - to control the resolution of the topographies (lower resolution # means faster plotting) # * ``contours`` to define how many contour lines should be plotted -evoked.plot_topomap(times, ch_type='mag', cmap='Spectral_r', res=32, - contours=4) +evoked.plot_topomap(times, ch_type="mag", cmap="Spectral_r", res=32, contours=4) # %% # If you look at the edges of the head circle of a single topomap you'll see @@ -94,17 +93,24 @@ # The default value ``extrapolate='auto'`` will use ``'local'`` for MEG sensors # and ``'head'`` otherwise. 
Here we show each option: -extrapolations = ['local', 'head', 'box'] +extrapolations = ["local", "head", "box"] fig, axes = plt.subplots(figsize=(7.5, 4.5), nrows=2, ncols=3) # Here we look at EEG channels, and use a custom head sphere to get all the # sensors to be well within the drawn head surface -for axes_row, ch_type in zip(axes, ('mag', 'eeg')): +for axes_row, ch_type in zip(axes, ("mag", "eeg")): for ax, extr in zip(axes_row, extrapolations): - evoked.plot_topomap(0.1, ch_type=ch_type, size=2, extrapolate=extr, - axes=ax, show=False, colorbar=False, - sphere=(0., 0., 0., 0.09)) - ax.set_title('%s %s' % (ch_type.upper(), extr), fontsize=14) + evoked.plot_topomap( + 0.1, + ch_type=ch_type, + size=2, + extrapolate=extr, + axes=ax, + show=False, + colorbar=False, + sphere=(0.0, 0.0, 0.0, 0.09), + ) + ax.set_title("%s %s" % (ch_type.upper(), extr), fontsize=14) fig.tight_layout() # %% @@ -114,10 +120,11 @@ # Now we plot magnetometer data as topomap at a single time point: 100 ms # post-stimulus, add channel labels, title and adjust plot margins: -fig = evoked.plot_topomap(0.1, ch_type='mag', show_names=True, colorbar=False, - size=6, res=128) +fig = evoked.plot_topomap( + 0.1, ch_type="mag", show_names=True, colorbar=False, size=6, res=128 +) fig.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.88) -fig.suptitle('Auditory response') +fig.suptitle("Auditory response") # %% # We can also highlight specific channels by adding a mask, to e.g. mark @@ -128,8 +135,8 @@ # Select times and plot times = (0.09, 0.1, 0.11) -mask_params = dict(markersize=10, markerfacecolor='y') -evoked.plot_topomap(times, ch_type='mag', mask=mask, mask_params=mask_params) +mask_params = dict(markersize=10, markerfacecolor="y") +evoked.plot_topomap(times, ch_type="mag", mask=mask, mask_params=mask_params) # %% # Or by manually picking the channels to highlight at different times: @@ -137,16 +144,17 @@ times = (0.09, 0.1, 0.11) _times = ((np.abs(evoked.times - t)).argmin() for t in times) significant_channels = [ - ('MEG 0231', 'MEG 1611', 'MEG 1621', 'MEG 1631', 'MEG 1811'), - ('MEG 2411', 'MEG 2421'), - ('MEG 1621')] + ("MEG 0231", "MEG 1611", "MEG 1621", "MEG 1631", "MEG 1811"), + ("MEG 2411", "MEG 2421"), + ("MEG 1621"), +] _channels = [np.in1d(evoked.ch_names, ch) for ch in significant_channels] -mask = np.zeros(evoked.data.shape, dtype='bool') +mask = np.zeros(evoked.data.shape, dtype="bool") for _chs, _time in zip(_channels, _times): mask[_chs, _time] = True -evoked.plot_topomap(times, ch_type='mag', mask=mask, mask_params=mask_params) +evoked.plot_topomap(times, ch_type="mag", mask=mask, mask_params=mask_params) # %% # Interpolating topomaps @@ -162,18 +170,18 @@ # The default cubic interpolation is the smoothest and is great for # publications. -evoked.plot_topomap(times, ch_type='eeg', image_interp='cubic') +evoked.plot_topomap(times, ch_type="eeg", image_interp="cubic") # %% # The linear interpolation might be helpful in some cases. -evoked.plot_topomap(times, ch_type='eeg', image_interp='linear') +evoked.plot_topomap(times, ch_type="eeg", image_interp="linear") # %% # The nearest (Voronoi, no interpolation) interpolation is especially helpful # for debugging and seeing the values assigned to the topomap unaltered. 
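#
# Aside, illustrative only: "nearest" amounts to a Voronoi partition of the
# sensor plane, where every pixel takes the value of its closest channel.
# A minimal sketch of that idea on a synthetic layout (this is not MNE's
# actual implementation):

import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(0)
sensor_xy = rng.uniform(-1, 1, size=(32, 2))  # fake 2D sensor layout
values = rng.standard_normal(32)  # one value per "channel"
gx, gy = np.meshgrid(np.linspace(-1, 1, 64), np.linspace(-1, 1, 64))
grid = np.column_stack([gx.ravel(), gy.ravel()])
_, nearest = cKDTree(sensor_xy).query(grid)  # closest sensor per pixel
image = values[nearest].reshape(64, 64)  # piecewise-constant "topomap"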
-evoked.plot_topomap(times, ch_type='eeg', image_interp='nearest', contours=0) +evoked.plot_topomap(times, ch_type="eeg", image_interp="nearest", contours=0) # %% # Animating the topomap @@ -184,5 +192,4 @@ # sphinx_gallery_thumbnail_number = 9 times = np.arange(0.05, 0.151, 0.01) -fig, anim = evoked.animate_topomap( - times=times, ch_type='mag', frame_rate=2, blit=False) +fig, anim = evoked.animate_topomap(times=times, ch_type="mag", frame_rate=2, blit=False) diff --git a/examples/visualization/evoked_whitening.py b/examples/visualization/evoked_whitening.py index 7a5f7552cc1..1d1575a83b6 100644 --- a/examples/visualization/evoked_whitening.py +++ b/examples/visualization/evoked_whitening.py @@ -35,21 +35,30 @@ # Set parameters data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" raw = io.read_raw_fif(raw_fname, preload=True) -raw.filter(1, 40, fir_design='firwin') -raw.info['bads'] += ['MEG 2443'] # bads + 1 more +raw.filter(1, 40, fir_design="firwin") +raw.info["bads"] += ["MEG 2443"] # bads + 1 more events = mne.read_events(event_fname) # let's look at rare events, button presses event_id, tmin, tmax = 2, -0.2, 0.5 reject = dict(mag=4e-12, grad=4000e-13, eeg=80e-6) -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=('meg', 'eeg'), - baseline=None, reject=reject, preload=True) +epochs = mne.Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=("meg", "eeg"), + baseline=None, + reject=reject, + preload=True, +) # Uncomment next line to use fewer samples and study regularization effects # epochs = epochs[:20] # For your data, use as many samples as you can! @@ -57,24 +66,32 @@ # %% # Compute covariance using automated regularization method_params = dict(diagonal_fixed=dict(mag=0.01, grad=0.01, eeg=0.01)) -noise_covs = compute_covariance(epochs, tmin=None, tmax=0, method='auto', - return_estimators=True, n_jobs=None, - projs=None, rank=None, - method_params=method_params, verbose=True) +noise_covs = compute_covariance( + epochs, + tmin=None, + tmax=0, + method="auto", + return_estimators=True, + n_jobs=None, + projs=None, + rank=None, + method_params=method_params, + verbose=True, +) # With "return_estimator=True" all estimated covariances sorted # by log-likelihood are returned. -print('Covariance estimates sorted from best to worst') +print("Covariance estimates sorted from best to worst") for c in noise_covs: - print("%s : %s" % (c['method'], c['loglik'])) + print("%s : %s" % (c["method"], c["loglik"])) # %% # Show the evoked data: evoked = epochs.average() -evoked.plot(time_unit='s') # plot evoked response +evoked.plot(time_unit="s") # plot evoked response # %% # We can then show whitening for our various noise covariance estimates. @@ -85,4 +102,4 @@ # # For the Global field power we expect a value of 1. 
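#
# Aside, illustrative only: ``plot_white`` multiplies the data by a whitener
# built from each covariance (roughly C^-1/2), so noise-only samples end up
# with unit variance per channel. A rough sketch using the ``noise_covs`` and
# ``evoked`` objects above; MNE additionally accounts for the effective rank
# of the covariance, so its numbers will differ slightly.

import numpy as np
from mne.cov import compute_whitener

W, ch_names = compute_whitener(noise_covs[0], evoked.info)  # best estimate
picks = [evoked.ch_names.index(ch) for ch in ch_names]
whitened = W @ evoked.data[picks]
gfp = (whitened**2).mean(axis=0)  # whitened global field power over time
print(gfp[evoked.times < 0].mean())  # baseline is noise-only: expect ~1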
-evoked.plot_white(noise_covs, time_unit='s') +evoked.plot_white(noise_covs, time_unit="s") diff --git a/examples/visualization/meg_sensors.py b/examples/visualization/meg_sensors.py index 9d5ccd6411c..3685ee68543 100644 --- a/examples/visualization/meg_sensors.py +++ b/examples/visualization/meg_sensors.py @@ -17,8 +17,13 @@ import mne from mne.datasets import sample, spm_face, testing -from mne.io import (read_raw_artemis123, read_raw_bti, read_raw_ctf, - read_raw_fif, read_raw_kit) +from mne.io import ( + read_raw_artemis123, + read_raw_bti, + read_raw_ctf, + read_raw_fif, + read_raw_kit, +) from mne.viz import plot_alignment, set_3d_title print(__doc__) @@ -29,48 +34,52 @@ # Neuromag # -------- -kwargs = dict(eeg=False, coord_frame='meg', show_axes=True, verbose=True) +kwargs = dict(eeg=False, coord_frame="meg", show_axes=True, verbose=True) -raw = read_raw_fif( - sample.data_path() / 'MEG' / 'sample' / 'sample_audvis_raw.fif') -fig = plot_alignment(raw.info, meg=('helmet', 'sensors'), **kwargs) -set_3d_title(figure=fig, title='Neuromag') +raw = read_raw_fif(sample.data_path() / "MEG" / "sample" / "sample_audvis_raw.fif") +fig = plot_alignment(raw.info, meg=("helmet", "sensors"), **kwargs) +set_3d_title(figure=fig, title="Neuromag") # %% # CTF # --- raw = read_raw_ctf( - spm_face.data_path() / 'MEG' / 'spm' / 'SPM_CTF_MEG_example_faces1_3D.ds') -fig = plot_alignment(raw.info, meg=('helmet', 'sensors', 'ref'), **kwargs) -set_3d_title(figure=fig, title='CTF 275') + spm_face.data_path() / "MEG" / "spm" / "SPM_CTF_MEG_example_faces1_3D.ds" +) +fig = plot_alignment(raw.info, meg=("helmet", "sensors", "ref"), **kwargs) +set_3d_title(figure=fig, title="CTF 275") # %% # BTi # --- -bti_path = root_path / 'io' / 'bti' / 'tests' / 'data' -raw = read_raw_bti(bti_path / 'test_pdf_linux', - bti_path / 'test_config_linux', - bti_path / 'test_hs_linux') -fig = plot_alignment(raw.info, meg=('helmet', 'sensors', 'ref'), **kwargs) -set_3d_title(figure=fig, title='Magnes 3600wh') +bti_path = root_path / "io" / "bti" / "tests" / "data" +raw = read_raw_bti( + bti_path / "test_pdf_linux", + bti_path / "test_config_linux", + bti_path / "test_hs_linux", +) +fig = plot_alignment(raw.info, meg=("helmet", "sensors", "ref"), **kwargs) +set_3d_title(figure=fig, title="Magnes 3600wh") # %% # KIT # --- -kit_path = root_path / 'io' / 'kit' / 'tests' / 'data' -raw = read_raw_kit(kit_path / 'test.sqd') -fig = plot_alignment(raw.info, meg=('helmet', 'sensors'), **kwargs) -set_3d_title(figure=fig, title='KIT') +kit_path = root_path / "io" / "kit" / "tests" / "data" +raw = read_raw_kit(kit_path / "test.sqd") +fig = plot_alignment(raw.info, meg=("helmet", "sensors"), **kwargs) +set_3d_title(figure=fig, title="KIT") # %% # Artemis123 # ---------- raw = read_raw_artemis123( - testing.data_path() / 'ARTEMIS123' / - 'Artemis_Data_2017-04-14-10h-38m-59s_Phantom_1k_HPI_1s.bin') -fig = plot_alignment(raw.info, meg=('helmet', 'sensors', 'ref'), **kwargs) -set_3d_title(figure=fig, title='Artemis123') + testing.data_path() + / "ARTEMIS123" + / "Artemis_Data_2017-04-14-10h-38m-59s_Phantom_1k_HPI_1s.bin" +) +fig = plot_alignment(raw.info, meg=("helmet", "sensors", "ref"), **kwargs) +set_3d_title(figure=fig, title="Artemis123") diff --git a/examples/visualization/mne_helmet.py b/examples/visualization/mne_helmet.py index c6c155bcfd3..1085cbfc044 100644 --- a/examples/visualization/mne_helmet.py +++ b/examples/visualization/mne_helmet.py @@ -13,23 +13,44 @@ import mne sample_path = mne.datasets.sample.data_path() -subjects_dir = 
sample_path / 'subjects' -fname_evoked = sample_path / 'MEG' / 'sample' / 'sample_audvis-ave.fif' -fname_inv = (sample_path / 'MEG' / 'sample' / - 'sample_audvis-meg-oct-6-meg-inv.fif') -fname_trans = sample_path / 'MEG' / 'sample' / 'sample_audvis_raw-trans.fif' +subjects_dir = sample_path / "subjects" +fname_evoked = sample_path / "MEG" / "sample" / "sample_audvis-ave.fif" +fname_inv = sample_path / "MEG" / "sample" / "sample_audvis-meg-oct-6-meg-inv.fif" +fname_trans = sample_path / "MEG" / "sample" / "sample_audvis_raw-trans.fif" inv = mne.minimum_norm.read_inverse_operator(fname_inv) -evoked = mne.read_evokeds(fname_evoked, baseline=(None, 0), - proj=True, verbose=False, condition='Left Auditory') -maps = mne.make_field_map(evoked, trans=fname_trans, ch_type='meg', - subject='sample', subjects_dir=subjects_dir) +evoked = mne.read_evokeds( + fname_evoked, + baseline=(None, 0), + proj=True, + verbose=False, + condition="Left Auditory", +) +maps = mne.make_field_map( + evoked, + trans=fname_trans, + ch_type="meg", + subject="sample", + subjects_dir=subjects_dir, +) time = 0.083 fig = mne.viz.create_3d_figure((256, 256)) mne.viz.plot_alignment( - evoked.info, subject='sample', subjects_dir=subjects_dir, fig=fig, - trans=fname_trans, meg='sensors', eeg=False, surfaces='pial', - coord_frame='mri') + evoked.info, + subject="sample", + subjects_dir=subjects_dir, + fig=fig, + trans=fname_trans, + meg="sensors", + eeg=False, + surfaces="pial", + coord_frame="mri", +) evoked.plot_field(maps, time=time, fig=fig, time_label=None, vmax=5e-13) mne.viz.set_3d_view( - fig, azimuth=40, elevation=87, focalpoint=(0., -0.01, 0.04), roll=-25, - distance=0.55) + fig, + azimuth=40, + elevation=87, + focalpoint=(0.0, -0.01, 0.04), + roll=-25, + distance=0.55, +) diff --git a/examples/visualization/montage_sgskip.py b/examples/visualization/montage_sgskip.py index 96ab574499e..521e4e87a16 100644 --- a/examples/visualization/montage_sgskip.py +++ b/examples/visualization/montage_sgskip.py @@ -28,15 +28,18 @@ for current_montage in get_builtin_montages(): montage = mne.channels.make_standard_montage(current_montage) - info = mne.create_info( - ch_names=montage.ch_names, sfreq=100., ch_types='eeg') + info = mne.create_info(ch_names=montage.ch_names, sfreq=100.0, ch_types="eeg") info.set_montage(montage) - sphere = mne.make_sphere_model(r0='auto', head_radius='auto', info=info) + sphere = mne.make_sphere_model(r0="auto", head_radius="auto", info=info) fig = mne.viz.plot_alignment( # Plot options - show_axes=True, dig='fiducials', surfaces='head', + show_axes=True, + dig="fiducials", + surfaces="head", trans=mne.Transform("head", "mri", trans=np.eye(4)), # identity - bem=sphere, info=info) + bem=sphere, + info=info, + ) set_3d_view(figure=fig, azimuth=135, elevation=80) set_3d_title(figure=fig, title=current_montage) @@ -49,15 +52,19 @@ for current_montage in get_builtin_montages(): montage = mne.channels.make_standard_montage(current_montage) # Create dummy info - info = mne.create_info( - ch_names=montage.ch_names, sfreq=100., ch_types='eeg') + info = mne.create_info(ch_names=montage.ch_names, sfreq=100.0, ch_types="eeg") info.set_montage(montage) fig = mne.viz.plot_alignment( # Plot options - show_axes=True, dig='fiducials', surfaces='head', mri_fiducials=True, - subject='fsaverage', subjects_dir=subjects_dir, info=info, - coord_frame='mri', - trans='fsaverage', # transform from head coords to fsaverage's MRI + show_axes=True, + dig="fiducials", + surfaces="head", + mri_fiducials=True, + subject="fsaverage", + 
subjects_dir=subjects_dir, + info=info, + coord_frame="mri", + trans="fsaverage", # transform from head coords to fsaverage's MRI ) set_3d_view(figure=fig, azimuth=135, elevation=80) set_3d_title(figure=fig, title=current_montage) diff --git a/examples/visualization/parcellation.py b/examples/visualization/parcellation.py index 7118a2594b5..9e416c97c48 100644 --- a/examples/visualization/parcellation.py +++ b/examples/visualization/parcellation.py @@ -26,37 +26,58 @@ # %% import mne + Brain = mne.viz.get_brain_class() -subjects_dir = mne.datasets.sample.data_path() / 'subjects' -mne.datasets.fetch_hcp_mmp_parcellation(subjects_dir=subjects_dir, - verbose=True) +subjects_dir = mne.datasets.sample.data_path() / "subjects" +mne.datasets.fetch_hcp_mmp_parcellation(subjects_dir=subjects_dir, verbose=True) -mne.datasets.fetch_aparc_sub_parcellation(subjects_dir=subjects_dir, - verbose=True) +mne.datasets.fetch_aparc_sub_parcellation(subjects_dir=subjects_dir, verbose=True) labels = mne.read_labels_from_annot( - 'fsaverage', 'HCPMMP1', 'lh', subjects_dir=subjects_dir) - -brain = Brain('fsaverage', 'lh', 'inflated', subjects_dir=subjects_dir, - cortex='low_contrast', background='white', size=(800, 600)) -brain.add_annotation('HCPMMP1') -aud_label = [label for label in labels if label.name == 'L_A1_ROI-lh'][0] + "fsaverage", "HCPMMP1", "lh", subjects_dir=subjects_dir +) + +brain = Brain( + "fsaverage", + "lh", + "inflated", + subjects_dir=subjects_dir, + cortex="low_contrast", + background="white", + size=(800, 600), +) +brain.add_annotation("HCPMMP1") +aud_label = [label for label in labels if label.name == "L_A1_ROI-lh"][0] brain.add_label(aud_label, borders=False) # %% # We can also plot a combined set of labels (23 per hemisphere). -brain = Brain('fsaverage', 'lh', 'inflated', subjects_dir=subjects_dir, - cortex='low_contrast', background='white', size=(800, 600)) -brain.add_annotation('HCPMMP1_combined') +brain = Brain( + "fsaverage", + "lh", + "inflated", + subjects_dir=subjects_dir, + cortex="low_contrast", + background="white", + size=(800, 600), +) +brain.add_annotation("HCPMMP1_combined") # %% # We can add another custom parcellation -brain = Brain('fsaverage', 'lh', 'inflated', subjects_dir=subjects_dir, - cortex='low_contrast', background='white', size=(800, 600)) -brain.add_annotation('aparc_sub') +brain = Brain( + "fsaverage", + "lh", + "inflated", + subjects_dir=subjects_dir, + cortex="low_contrast", + background="white", + size=(800, 600), +) +brain.add_annotation("aparc_sub") # %% # References diff --git a/examples/visualization/publication_figure.py b/examples/visualization/publication_figure.py index f86cc44075d..f753c72a2c8 100644 --- a/examples/visualization/publication_figure.py +++ b/examples/visualization/publication_figure.py @@ -22,8 +22,7 @@ import numpy as np import matplotlib.pyplot as plt -from mpl_toolkits.axes_grid1 import (make_axes_locatable, ImageGrid, - inset_locator) +from mpl_toolkits.axes_grid1 import make_axes_locatable, ImageGrid, inset_locator import mne @@ -36,12 +35,12 @@ # start by loading some :ref:`example data `. 
data_path = mne.datasets.sample.data_path() -subjects_dir = data_path / 'subjects' -fname_stc = data_path / 'MEG' / 'sample' / 'sample_audvis-meg-eeg-lh.stc' -fname_evoked = data_path / 'MEG' / 'sample' / 'sample_audvis-ave.fif' +subjects_dir = data_path / "subjects" +fname_stc = data_path / "MEG" / "sample" / "sample_audvis-meg-eeg-lh.stc" +fname_evoked = data_path / "MEG" / "sample" / "sample_audvis-ave.fif" -evoked = mne.read_evokeds(fname_evoked, 'Left Auditory') -evoked.pick_types(meg='grad').apply_baseline((None, 0.)) +evoked = mne.read_evokeds(fname_evoked, "Left Auditory") +evoked.pick_types(meg="grad").apply_baseline((None, 0.0)) max_t = evoked.get_peak()[1] stc = mne.read_source_estimate(fname_stc) @@ -51,9 +50,16 @@ evoked.plot() -stc.plot(views='lat', hemi='split', size=(800, 400), subject='sample', - subjects_dir=subjects_dir, initial_time=max_t, - time_viewer=False, show_traces=False) +stc.plot( + views="lat", + hemi="split", + size=(800, 400), + subject="sample", + subjects_dir=subjects_dir, + initial_time=max_t, + time_viewer=False, + show_traces=False, +) # %% # To make a publication-ready figure, first we'll re-plot the brain on a white @@ -61,14 +67,24 @@ # While we're at it, let's change the colormap, set custom colormap limits and # remove the default colorbar (so we can add a smaller, vertical one later): -colormap = 'viridis' -clim = dict(kind='value', lims=[4, 8, 12]) +colormap = "viridis" +clim = dict(kind="value", lims=[4, 8, 12]) # Plot the STC, get the brain image, crop it: -brain = stc.plot(views='lat', hemi='split', size=(800, 400), subject='sample', - subjects_dir=subjects_dir, initial_time=max_t, background='w', - colorbar=False, clim=clim, colormap=colormap, - time_viewer=False, show_traces=False) +brain = stc.plot( + views="lat", + hemi="split", + size=(800, 400), + subject="sample", + subjects_dir=subjects_dir, + initial_time=max_t, + background="w", + colorbar=False, + clim=clim, + colormap=colormap, + time_viewer=False, + show_traces=False, +) screenshot = brain.screenshot() brain.close() @@ -87,10 +103,11 @@ # before/after results fig = plt.figure(figsize=(4, 4)) axes = ImageGrid(fig, 111, nrows_ncols=(2, 1), axes_pad=0.5) -for ax, image, title in zip(axes, [screenshot, cropped_screenshot], - ['Before', 'After']): +for ax, image, title in zip( + axes, [screenshot, cropped_screenshot], ["Before", "After"] +): ax.imshow(image) - ax.set_title('{} cropping'.format(title)) + ax.set_title("{} cropping".format(title)) # %% # A lot of figure settings can be adjusted after the figure is created, but @@ -99,14 +116,16 @@ # script generates several figures that you want to all have the same style: # Tweak the figure style -plt.rcParams.update({ - 'ytick.labelsize': 'small', - 'xtick.labelsize': 'small', - 'axes.labelsize': 'small', - 'axes.titlesize': 'medium', - 'grid.color': '0.75', - 'grid.linestyle': ':', -}) +plt.rcParams.update( + { + "ytick.labelsize": "small", + "xtick.labelsize": "small", + "axes.labelsize": "small", + "axes.titlesize": "medium", + "grid.color": "0.75", + "grid.linestyle": ":", + } +) # %% # Now let's create our custom figure. There are lots of ways to do this step. 
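# (Editor's sketch, not part of the original patch: a side note on the
# rcParams block above, before the custom-figure hunk below. The same style
# settings can be scoped to a single block instead of mutating global
# rcParams; ``matplotlib.pyplot.rc_context`` is standard matplotlib, and the
# throwaway figure here is illustrative only.)
with plt.rc_context({"grid.color": "0.75", "grid.linestyle": ":"}):
    fig_demo, ax_demo = plt.subplots()
    ax_demo.grid(True)  # the dotted grey grid applies only inside this block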
@@ -119,8 +138,9 @@ # sphinx_gallery_thumbnail_number = 4 # figsize unit is inches -fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(4.5, 3.), - gridspec_kw=dict(height_ratios=[3, 4])) +fig, axes = plt.subplots( + nrows=2, ncols=1, figsize=(4.5, 3.0), gridspec_kw=dict(height_ratios=[3, 4]) +) # alternate way #1: using subplot2grid # fig = plt.figure(figsize=(4.5, 3.)) @@ -138,42 +158,55 @@ # plot the evoked in the desired subplot, and add a line at peak activation evoked.plot(axes=axes[evoked_idx]) -peak_line = axes[evoked_idx].axvline(max_t, color='#66CCEE', ls='--') +peak_line = axes[evoked_idx].axvline(max_t, color="#66CCEE", ls="--") # custom legend axes[evoked_idx].legend( - [axes[evoked_idx].lines[0], peak_line], ['MEG data', 'Peak time'], - frameon=True, columnspacing=0.1, labelspacing=0.1, - fontsize=8, fancybox=True, handlelength=1.8) + [axes[evoked_idx].lines[0], peak_line], + ["MEG data", "Peak time"], + frameon=True, + columnspacing=0.1, + labelspacing=0.1, + fontsize=8, + fancybox=True, + handlelength=1.8, +) # remove the "N_ave" annotation for text in list(axes[evoked_idx].texts): text.remove() # Remove spines and add grid axes[evoked_idx].grid(True) axes[evoked_idx].set_axisbelow(True) -for key in ('top', 'right'): +for key in ("top", "right"): axes[evoked_idx].spines[key].set(visible=False) # Tweak the ticks and limits axes[evoked_idx].set( - yticks=np.arange(-200, 201, 100), xticks=np.arange(-0.2, 0.51, 0.1)) -axes[evoked_idx].set( - ylim=[-225, 225], xlim=[-0.2, 0.5]) + yticks=np.arange(-200, 201, 100), xticks=np.arange(-0.2, 0.51, 0.1) +) +axes[evoked_idx].set(ylim=[-225, 225], xlim=[-0.2, 0.5]) # now add the brain to the lower axes axes[brain_idx].imshow(cropped_screenshot) -axes[brain_idx].axis('off') +axes[brain_idx].axis("off") # add a vertical colorbar with the same properties as the 3D one divider = make_axes_locatable(axes[brain_idx]) -cax = divider.append_axes('right', size='5%', pad=0.2) -cbar = mne.viz.plot_brain_colorbar(cax, clim, colormap, label='Activation (F)') +cax = divider.append_axes("right", size="5%", pad=0.2) +cbar = mne.viz.plot_brain_colorbar(cax, clim, colormap, label="Activation (F)") # tweak margins and spacing -fig.subplots_adjust( - left=0.15, right=0.9, bottom=0.01, top=0.9, wspace=0.1, hspace=0.5) +fig.subplots_adjust(left=0.15, right=0.9, bottom=0.01, top=0.9, wspace=0.1, hspace=0.5) # add subplot labels -for ax, label in zip(axes, 'AB'): - ax.text(0.03, ax.get_position().ymax, label, transform=fig.transFigure, - fontsize=12, fontweight='bold', va='top', ha='left') +for ax, label in zip(axes, "AB"): + ax.text( + 0.03, + ax.get_position().ymax, + label, + transform=fig.transFigure, + fontsize=12, + fontweight="bold", + va="top", + ha="left", + ) # %% # Custom timecourse with montage inset @@ -206,10 +239,9 @@ to_plot = [f"EEG {i:03}" for i in range(1, 5)] # get the data for plotting in a short time interval from 10 to 20 seconds -start = int(raw.info['sfreq'] * 10) -stop = int(raw.info['sfreq'] * 20) -data, times = raw.get_data(picks=to_plot, - start=start, stop=stop, return_times=True) +start = int(raw.info["sfreq"] * 10) +stop = int(raw.info["sfreq"] * 20) +data, times = raw.get_data(picks=to_plot, start=start, stop=stop, return_times=True) # Scale the data from the MNE internal unit V to µV data *= 1e6 diff --git a/examples/visualization/roi_erpimage_by_rt.py b/examples/visualization/roi_erpimage_by_rt.py index e803b3cb14b..a8d2bae8d58 100644 --- a/examples/visualization/roi_erpimage_by_rt.py +++ 
b/examples/visualization/roi_erpimage_by_rt.py @@ -31,24 +31,48 @@ # %% # Load EEGLAB example data (a small EEG dataset) data_path = mne.datasets.testing.data_path() -fname = data_path / 'EEGLAB' / 'test_raw.set' +fname = data_path / "EEGLAB" / "test_raw.set" event_id = {"rt": 1, "square": 2} # must be specified for str events raw = mne.io.read_raw_eeglab(fname) mapping = { - 'EEG 000': 'Fpz', 'EEG 001': 'EOG1', 'EEG 002': 'F3', 'EEG 003': 'Fz', - 'EEG 004': 'F4', 'EEG 005': 'EOG2', 'EEG 006': 'FC5', 'EEG 007': 'FC1', - 'EEG 008': 'FC2', 'EEG 009': 'FC6', 'EEG 010': 'T7', 'EEG 011': 'C3', - 'EEG 012': 'C4', 'EEG 013': 'Cz', 'EEG 014': 'T8', 'EEG 015': 'CP5', - 'EEG 016': 'CP1', 'EEG 017': 'CP2', 'EEG 018': 'CP6', 'EEG 019': 'P7', - 'EEG 020': 'P3', 'EEG 021': 'Pz', 'EEG 022': 'P4', 'EEG 023': 'P8', - 'EEG 024': 'PO7', 'EEG 025': 'PO3', 'EEG 026': 'POz', 'EEG 027': 'PO4', - 'EEG 028': 'PO8', 'EEG 029': 'O1', 'EEG 030': 'Oz', 'EEG 031': 'O2' + "EEG 000": "Fpz", + "EEG 001": "EOG1", + "EEG 002": "F3", + "EEG 003": "Fz", + "EEG 004": "F4", + "EEG 005": "EOG2", + "EEG 006": "FC5", + "EEG 007": "FC1", + "EEG 008": "FC2", + "EEG 009": "FC6", + "EEG 010": "T7", + "EEG 011": "C3", + "EEG 012": "C4", + "EEG 013": "Cz", + "EEG 014": "T8", + "EEG 015": "CP5", + "EEG 016": "CP1", + "EEG 017": "CP2", + "EEG 018": "CP6", + "EEG 019": "P7", + "EEG 020": "P3", + "EEG 021": "Pz", + "EEG 022": "P4", + "EEG 023": "P8", + "EEG 024": "PO7", + "EEG 025": "PO3", + "EEG 026": "POz", + "EEG 027": "PO4", + "EEG 028": "PO8", + "EEG 029": "O1", + "EEG 030": "Oz", + "EEG 031": "O2", } raw.rename_channels(mapping) -raw.set_channel_types({"EOG1": 'eog', "EOG2": 'eog'}) -raw.set_montage('standard_1020') +raw.set_channel_types({"EOG1": "eog", "EOG2": "eog"}) +raw.set_montage("standard_1020") events = mne.events_from_annotations(raw, event_id)[0] @@ -61,11 +85,11 @@ tmax = 0.7 sfreq = raw.info["sfreq"] reference_id, target_id = 2, 1 -new_events, rts = define_target_events(events, reference_id, target_id, sfreq, - tmin=0., tmax=tmax, new_id=2) +new_events, rts = define_target_events( + events, reference_id, target_id, sfreq, tmin=0.0, tmax=tmax, new_id=2 +) -epochs = mne.Epochs(raw, events=new_events, tmax=tmax + 0.1, - event_id={"square": 2}) +epochs = mne.Epochs(raw, events=new_events, tmax=tmax + 0.1, event_id={"square": 2}) # %% # Plot using :term:`global field power` @@ -76,13 +100,23 @@ selections = make_1020_channel_selections(epochs.info, midline="12z") # The actual plots (GFP) -epochs.plot_image(group_by=selections, order=order, sigma=1.5, - overlay_times=rts / 1000., combine='gfp', - ts_args=dict(vlines=[0, rts.mean() / 1000.])) +epochs.plot_image( + group_by=selections, + order=order, + sigma=1.5, + overlay_times=rts / 1000.0, + combine="gfp", + ts_args=dict(vlines=[0, rts.mean() / 1000.0]), +) # %% # Plot using median -epochs.plot_image(group_by=selections, order=order, sigma=1.5, - overlay_times=rts / 1000., combine='median', - ts_args=dict(vlines=[0, rts.mean() / 1000.])) +epochs.plot_image( + group_by=selections, + order=order, + sigma=1.5, + overlay_times=rts / 1000.0, + combine="median", + ts_args=dict(vlines=[0, rts.mean() / 1000.0]), +) diff --git a/examples/visualization/sensor_noise_level.py b/examples/visualization/sensor_noise_level.py index 55b220ba1c0..ca5c70d0233 100644 --- a/examples/visualization/sensor_noise_level.py +++ b/examples/visualization/sensor_noise_level.py @@ -19,13 +19,14 @@ data_path = mne.datasets.sample.data_path() raw_erm = mne.io.read_raw_fif( - data_path / 'MEG' / 'sample' / 
'ernoise_raw.fif', preload=True + data_path / "MEG" / "sample" / "ernoise_raw.fif", preload=True ) # %% # We can plot the absolute noise levels: -raw_erm.compute_psd(tmax=10).plot(average=True, spatial_colors=False, - dB=False, xscale='log') +raw_erm.compute_psd(tmax=10).plot( + average=True, spatial_colors=False, dB=False, xscale="log" +) # %% # References # ---------- diff --git a/examples/visualization/ssp_projs_sensitivity_map.py b/examples/visualization/ssp_projs_sensitivity_map.py index 2c8259d7a24..d51c498e423 100644 --- a/examples/visualization/ssp_projs_sensitivity_map.py +++ b/examples/visualization/ssp_projs_sensitivity_map.py @@ -24,10 +24,10 @@ data_path = sample.data_path() -subjects_dir = data_path / 'subjects' -meg_path = data_path / 'MEG' / 'sample' -fname = meg_path / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -ecg_fname = meg_path / 'sample_audvis_ecg-proj.fif' +subjects_dir = data_path / "subjects" +meg_path = data_path / "MEG" / "sample" +fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +ecg_fname = meg_path / "sample_audvis_ecg-proj.fif" fwd = read_forward_solution(fname) @@ -36,7 +36,7 @@ projs = projs[::2] # Compute sensitivity map -ssp_ecg_map = sensitivity_map(fwd, ch_type='grad', projs=projs, mode='angle') +ssp_ecg_map = sensitivity_map(fwd, ch_type="grad", projs=projs, mode="angle") # %% # Show sensitivity map @@ -44,6 +44,10 @@ plt.hist(ssp_ecg_map.data.ravel()) plt.show() -args = dict(clim=dict(kind='value', lims=(0.2, 0.6, 1.)), smoothing_steps=7, - hemi='rh', subjects_dir=subjects_dir) -ssp_ecg_map.plot(subject='sample', time_label='ECG SSP sensitivity', **args) +args = dict( + clim=dict(kind="value", lims=(0.2, 0.6, 1.0)), + smoothing_steps=7, + hemi="rh", + subjects_dir=subjects_dir, +) +ssp_ecg_map.plot(subject="sample", time_label="ECG SSP sensitivity", **args) diff --git a/examples/visualization/topo_compare_conditions.py b/examples/visualization/topo_compare_conditions.py index 6687ba37576..742565fc1fd 100644 --- a/examples/visualization/topo_compare_conditions.py +++ b/examples/visualization/topo_compare_conditions.py @@ -31,9 +31,9 @@ # %% # Set parameters -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' -event_fname = meg_path / 'sample_audvis_filt-0-40_raw-eve.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" +event_fname = meg_path / "sample_audvis_filt-0-40_raw-eve.fif" tmin = -0.2 tmax = 0.5 @@ -45,20 +45,20 @@ reject = dict(grad=4000e-13, mag=4e-12) # Create epochs including different events -event_id = {'audio/left': 1, 'audio/right': 2, - 'visual/left': 3, 'visual/right': 4} -epochs = mne.Epochs(raw, events, event_id, tmin, tmax, - picks='meg', baseline=(None, 0), reject=reject) +event_id = {"audio/left": 1, "audio/right": 2, "visual/left": 3, "visual/right": 4} +epochs = mne.Epochs( + raw, events, event_id, tmin, tmax, picks="meg", baseline=(None, 0), reject=reject +) # Generate list of evoked objects from conditions names -evokeds = [epochs[name].average() for name in ('left', 'right')] +evokeds = [epochs[name].average() for name in ("left", "right")] # %% # Show topography for two different conditions -colors = 'blue', 'red' -title = 'MNE sample data\nleft vs right (A/V combined)' +colors = "blue", "red" +title = "MNE sample data\nleft vs right (A/V combined)" -plot_evoked_topo(evokeds, color=colors, title=title, background_color='w') +plot_evoked_topo(evokeds, color=colors, title=title, background_color="w") plt.show() diff --git 
a/examples/visualization/topo_customized.py b/examples/visualization/topo_customized.py index e9106a1e8d2..02c0435b25f 100644 --- a/examples/visualization/topo_customized.py +++ b/examples/visualization/topo_customized.py @@ -30,18 +30,17 @@ print(__doc__) data_path = sample.data_path() -meg_path = data_path / 'MEG' / 'sample' -raw_fname = meg_path / 'sample_audvis_filt-0-40_raw.fif' +meg_path = data_path / "MEG" / "sample" +raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif" raw = io.read_raw_fif(raw_fname, preload=True) -raw.filter(1, 20, fir_design='firwin') +raw.filter(1, 20, fir_design="firwin") picks = mne.pick_types(raw.info, meg=True, exclude=[]) tmin, tmax = 0, 120 # use the first 120s of data fmin, fmax = 2, 20 # look at frequencies between 2 and 20Hz n_fft = 2048 # the FFT size (n_fft). Ideally a power of 2 -spectrum = raw.compute_psd( - picks=picks, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax) +spectrum = raw.compute_psd(picks=picks, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax) psds, freqs = spectrum.get_data(exclude=(), return_freqs=True) psds = 20 * np.log10(psds) # scale to dB @@ -53,17 +52,19 @@ def my_callback(ax, ch_idx): in the plot. To work with the viz internals, this function should only take two parameters, the axis and the channel or data index. """ - ax.plot(freqs, psds[ch_idx], color='red') - ax.set_xlabel('Frequency (Hz)') - ax.set_ylabel('Power (dB)') + ax.plot(freqs, psds[ch_idx], color="red") + ax.set_xlabel("Frequency (Hz)") + ax.set_ylabel("Power (dB)") -for ax, idx in iter_topography(raw.info, - fig_facecolor='white', - axis_facecolor='white', - axis_spinecolor='white', - on_pick=my_callback): - ax.plot(psds[idx], color='red') +for ax, idx in iter_topography( + raw.info, + fig_facecolor="white", + axis_facecolor="white", + axis_spinecolor="white", + on_pick=my_callback, +): + ax.plot(psds[idx], color="red") -plt.gcf().suptitle('Power spectral densities') +plt.gcf().suptitle("Power spectral densities") plt.show() diff --git a/examples/visualization/xhemi.py b/examples/visualization/xhemi.py index bb5a4971d4d..693d702629c 100644 --- a/examples/visualization/xhemi.py +++ b/examples/visualization/xhemi.py @@ -19,26 +19,31 @@ import mne data_dir = mne.datasets.sample.data_path() -subjects_dir = data_dir / 'subjects' -stc_path = data_dir / 'MEG' / 'sample' / 'sample_audvis-meg-eeg' -stc = mne.read_source_estimate(stc_path, 'sample') +subjects_dir = data_dir / "subjects" +stc_path = data_dir / "MEG" / "sample" / "sample_audvis-meg-eeg" +stc = mne.read_source_estimate(stc_path, "sample") # First, morph the data to fsaverage_sym, for which we have left_right # registrations: -stc = mne.compute_source_morph(stc, 'sample', 'fsaverage_sym', smooth=5, - warn=False, - subjects_dir=subjects_dir).apply(stc) +stc = mne.compute_source_morph( + stc, "sample", "fsaverage_sym", smooth=5, warn=False, subjects_dir=subjects_dir +).apply(stc) # Compute a morph-matrix mapping the right to the left hemisphere, # and vice-versa. 
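# (Editor's note, not part of the original patch: with ``xhemi=True`` the
# morph sends each hemisphere's data to the mirrored vertices of the opposite
# hemisphere of ``fsaverage_sym``, so the ``stc - stc_xhemi`` subtraction
# below compares left and right at matched locations.)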
-morph = mne.compute_source_morph(stc, 'fsaverage_sym', 'fsaverage_sym', - spacing=stc.vertices, warn=False, - subjects_dir=subjects_dir, xhemi=True, - verbose='error') # creating morph map +morph = mne.compute_source_morph( + stc, + "fsaverage_sym", + "fsaverage_sym", + spacing=stc.vertices, + warn=False, + subjects_dir=subjects_dir, + xhemi=True, + verbose="error", +) # creating morph map stc_xhemi = morph.apply(stc) # Now we can subtract them and plot the result: diff = stc - stc_xhemi -diff.plot(hemi='lh', subjects_dir=subjects_dir, initial_time=0.07, - size=(800, 600)) +diff.plot(hemi="lh", subjects_dir=subjects_dir, initial_time=0.07, size=(800, 600)) diff --git a/logo/generate_mne_logos.py b/logo/generate_mne_logos.py index 072710182be..34b77788750 100644 --- a/logo/generate_mne_logos.py +++ b/logo/generate_mne_logos.py @@ -23,18 +23,24 @@ dpi = 300 center_fudge = np.array([15, 30]) # compensate for font bounding box padding tagline_scale_fudge = 0.97 # to get justification right -tagline_offset_fudge = np.array([0, -100.]) +tagline_offset_fudge = np.array([0, -100.0]) # font, etc -rcp = {'font.sans-serif': ['Primetime'], 'font.style': 'normal', - 'font.weight': 'black', 'font.variant': 'normal', 'figure.dpi': dpi, - 'savefig.dpi': dpi, 'contour.negative_linestyle': 'solid'} +rcp = { + "font.sans-serif": ["Primetime"], + "font.style": "normal", + "font.weight": "black", + "font.variant": "normal", + "figure.dpi": dpi, + "savefig.dpi": dpi, + "contour.negative_linestyle": "solid", +} plt.rcdefaults() rcParams.update(rcp) # initialize figure (no axes, margins, etc) fig = plt.figure(1, figsize=(5, 2.25), frameon=False, dpi=dpi) -ax = plt.Axes(fig, [0., 0., 1., 1.]) +ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) ax.set_axis_off() fig.add_axes(ax) @@ -44,10 +50,12 @@ y = np.arange(-3.0, 3.0, delta) X, Y = np.meshgrid(x, y) xy = np.array([X, Y]).transpose(1, 2, 0) -Z1 = multivariate_normal.pdf(xy, mean=[-5.0, 0.9], - cov=np.array([[8.0, 1.0], [1.0, 7.0]]) ** 2) -Z2 = multivariate_normal.pdf(xy, mean=[2.6, -2.5], - cov=np.array([[15.0, 2.5], [2.5, 2.5]]) ** 2) +Z1 = multivariate_normal.pdf( + xy, mean=[-5.0, 0.9], cov=np.array([[8.0, 1.0], [1.0, 7.0]]) ** 2 +) +Z2 = multivariate_normal.pdf( + xy, mean=[2.6, -2.5], cov=np.array([[15.0, 2.5], [2.5, 2.5]]) ** 2 +) Z = Z2 - 0.7 * Z1 # color map: field gradient (yellow-red-gray-blue-cyan) @@ -56,36 +64,46 @@ # 'blue': ((0, 0, 0), (0.4, 0, 0), (0.5, 0.5, 0.5), (0.6, 1, 1), (1, 1, 1)), # noqa # 'green': ((0, 1, 1), (0.4, 0, 0), (0.5, 0.5, 0.5), (0.6, 0, 0), (1, 1, 1)), # noqa # } -yrtbc = {'red': ((0.0, 1.0, 1.0), (0.5, 1.0, 0.0), (1.0, 0.0, 0.0)), - 'blue': ((0.0, 0.0, 0.0), (0.5, 0.0, 1.0), (1.0, 1.0, 1.0)), - 'green': ((0.0, 1.0, 1.0), (0.5, 0.0, 0.0), (1.0, 1.0, 1.0)), - 'alpha': ((0.0, 1.0, 1.0), (0.4, 0.8, 0.8), (0.5, 0.2, 0.2), - (0.6, 0.8, 0.8), (1.0, 1.0, 1.0))} +yrtbc = { + "red": ((0.0, 1.0, 1.0), (0.5, 1.0, 0.0), (1.0, 0.0, 0.0)), + "blue": ((0.0, 0.0, 0.0), (0.5, 0.0, 1.0), (1.0, 1.0, 1.0)), + "green": ((0.0, 1.0, 1.0), (0.5, 0.0, 0.0), (1.0, 1.0, 1.0)), + "alpha": ( + (0.0, 1.0, 1.0), + (0.4, 0.8, 0.8), + (0.5, 0.2, 0.2), + (0.6, 0.8, 0.8), + (1.0, 1.0, 1.0), + ), +} # color map: field lines (red | blue) -redbl = {'red': ((0., 1., 1.), (0.5, 1., 0.), (1., 0., 0.)), - 'blue': ((0., 0., 0.), (0.5, 0., 1.), (1., 1., 1.)), - 'green': ((0., 0., 0.), (1., 0., 0.)), - 'alpha': ((0., 0.4, 0.4), (1., 0.4, 0.4))} -mne_field_grad_cols = LinearSegmentedColormap('mne_grad', yrtbc) -mne_field_line_cols = LinearSegmentedColormap('mne_line', redbl) 
+redbl = { + "red": ((0.0, 1.0, 1.0), (0.5, 1.0, 0.0), (1.0, 0.0, 0.0)), + "blue": ((0.0, 0.0, 0.0), (0.5, 0.0, 1.0), (1.0, 1.0, 1.0)), + "green": ((0.0, 0.0, 0.0), (1.0, 0.0, 0.0)), + "alpha": ((0.0, 0.4, 0.4), (1.0, 0.4, 0.4)), +} +mne_field_grad_cols = LinearSegmentedColormap("mne_grad", yrtbc) +mne_field_line_cols = LinearSegmentedColormap("mne_line", redbl) # plot gradient and contour lines -im = ax.imshow(Z, cmap=mne_field_grad_cols, aspect='equal', zorder=1) +im = ax.imshow(Z, cmap=mne_field_grad_cols, aspect="equal", zorder=1) cs = ax.contour(Z, 9, cmap=mne_field_line_cols, linewidths=1, zorder=1) xlim, ylim = ax.get_xbound(), ax.get_ybound() plot_dims = np.r_[np.diff(xlim), np.diff(ylim)] rect = Rectangle( - [xlim[0], ylim[0]], plot_dims[0], plot_dims[1], facecolor='w', zorder=0.5) + [xlim[0], ylim[0]], plot_dims[0], plot_dims[1], facecolor="w", zorder=0.5 +) # create MNE clipping mask -mne_path = TextPath((0, 0), 'MNE') +mne_path = TextPath((0, 0), "MNE") dims = mne_path.vertices.max(0) - mne_path.vertices.min(0) -vert = mne_path.vertices - dims / 2. +vert = mne_path.vertices - dims / 2.0 mult = (plot_dims / dims).min() mult = [mult, -mult] # y axis is inverted (origin at top left) -offset = plot_dims / 2. - center_fudge +offset = plot_dims / 2.0 - center_fudge mne_clip = Path(offset + vert * mult, mne_path.codes) -ax.add_patch(PathPatch(mne_clip, color='w', zorder=0, linewidth=0)) +ax.add_patch(PathPatch(mne_clip, color="w", zorder=0, linewidth=0)) # apply clipping mask to field gradient and lines im.set_clip_path(mne_clip, transform=im.get_transform()) ax.add_patch(rect) @@ -96,64 +114,78 @@ mne_corners = mne_clip.get_extents().corners() # add tagline -rcParams.update({'font.sans-serif': ['Cooper Hewitt'], 'font.weight': '300'}) -tag_path = TextPath((0, 0), 'MEG + EEG ANALYSIS & VISUALIZATION') +rcParams.update({"font.sans-serif": ["Cooper Hewitt"], "font.weight": "300"}) +tag_path = TextPath((0, 0), "MEG + EEG ANALYSIS & VISUALIZATION") dims = tag_path.vertices.max(0) - tag_path.vertices.min(0) -vert = tag_path.vertices - dims / 2. 
+vert = tag_path.vertices - dims / 2.0 mult = tagline_scale_fudge * (plot_dims / dims).min() mult = [mult, -mult] # y axis is inverted -offset = mne_corners[-1] - np.array([mne_clip.get_extents().size[0] / 2., - -dims[1]]) - tagline_offset_fudge +offset = ( + mne_corners[-1] + - np.array([mne_clip.get_extents().size[0] / 2.0, -dims[1]]) + - tagline_offset_fudge +) tag_clip = Path(offset + vert * mult, tag_path.codes) -tag_patch = PathPatch(tag_clip, facecolor='k', edgecolor='none', zorder=10) +tag_patch = PathPatch(tag_clip, facecolor="k", edgecolor="none", zorder=10) ax.add_patch(tag_patch) yl = ax.get_ylim() -yy = np.max([tag_clip.vertices.max(0)[-1], - tag_clip.vertices.min(0)[-1]]) +yy = np.max([tag_clip.vertices.max(0)[-1], tag_clip.vertices.min(0)[-1]]) ax.set_ylim(np.ceil(yy), yl[-1]) # only save actual image extent plus a bit of padding plt.draw() -static_dir = op.join(op.dirname(__file__), '..', 'doc', '_static') +static_dir = op.join(op.dirname(__file__), "..", "doc", "_static") assert op.isdir(static_dir) -plt.savefig(op.join(static_dir, 'mne_logo.svg'), transparent=True) -tag_patch.set_facecolor('w') -rect.set_facecolor('0.5') -plt.savefig(op.join(static_dir, 'mne_logo_dark.svg'), transparent=True) -tag_patch.set_facecolor('k') -rect.set_facecolor('w') +plt.savefig(op.join(static_dir, "mne_logo.svg"), transparent=True) +tag_patch.set_facecolor("w") +rect.set_facecolor("0.5") +plt.savefig(op.join(static_dir, "mne_logo_dark.svg"), transparent=True) +tag_patch.set_facecolor("k") +rect.set_facecolor("w") # modify to make the splash screen -data_dir = op.join(op.dirname(__file__), '..', 'mne', 'icons') -ax.patches[-1].set_facecolor('w') +data_dir = op.join(op.dirname(__file__), "..", "mne", "icons") +ax.patches[-1].set_facecolor("w") for coll in list(ax.collections): coll.remove() -bounds = np.array([ - [mne_path.vertices[:, ii].min(), mne_path.vertices[:, ii].max()] - for ii in range(2)]) -bounds *= (plot_dims / dims) +bounds = np.array( + [ + [mne_path.vertices[:, ii].min(), mne_path.vertices[:, ii].max()] + for ii in range(2) + ] +) +bounds *= plot_dims / dims xy = np.mean(bounds, axis=1) - [100, 0] r = np.diff(bounds, axis=1).max() * 1.2 w, h = r, r * (2 / 3) box_xy = [xy[0] - w * 0.5, xy[1] - h * (2 / 5)] ax.set_ylim(box_xy[1] + h * 1.001, box_xy[1] - h * 0.001) patch = FancyBboxPatch( - box_xy, w, h, clip_on=False, zorder=-1, fc='k', ec='none', alpha=0.75, - boxstyle="round,rounding_size=200.0", mutation_aspect=1) + box_xy, + w, + h, + clip_on=False, + zorder=-1, + fc="k", + ec="none", + alpha=0.75, + boxstyle="round,rounding_size=200.0", + mutation_aspect=1, +) ax.add_patch(patch) fig.set_size_inches((512 / dpi, 512 * (h / w) / dpi)) -plt.savefig(op.join(data_dir, 'mne_splash.png'), transparent=True) +plt.savefig(op.join(data_dir, "mne_splash.png"), transparent=True) patch.remove() # modify to make an icon ax.patches.pop(-1) # no tag line for our icon -patch = Ellipse(xy, r, r, clip_on=False, zorder=-1, fc='k') +patch = Ellipse(xy, r, r, clip_on=False, zorder=-1, fc="k") ax.add_patch(patch) ax.set_ylim(xy[1] + r / 1.9, xy[1] - r / 1.9) fig.set_size_inches((256 / dpi, 256 / dpi)) # Qt does not support clip paths in SVG rendering so we have to use PNG here # then use "optipng -o7" on it afterward (14% reduction in file size) -plt.savefig(op.join(data_dir, 'mne_default_icon.png'), transparent=True) +plt.savefig(op.join(data_dir, "mne_default_icon.png"), transparent=True) plt.close() # 188x45 image @@ -162,31 +194,33 @@ h_px = 45 center_fudge = np.array([60, 0]) scale_fudge = 2.1 
-rcParams.update({'font.sans-serif': ['Primetime'], 'font.weight': 'black'}) -x = np.linspace(-1., 1., w_px // 2) -y = np.linspace(-1., 1., h_px // 2) +rcParams.update({"font.sans-serif": ["Primetime"], "font.weight": "black"}) +x = np.linspace(-1.0, 1.0, w_px // 2) +y = np.linspace(-1.0, 1.0, h_px // 2) X, Y = np.meshgrid(x, y) # initialize figure (no axes, margins, etc) -fig = plt.figure(1, figsize=(w_px / dpi, h_px / dpi), facecolor='k', - frameon=False, dpi=dpi) -ax = plt.Axes(fig, [0., 0., 1., 1.]) +fig = plt.figure( + 1, figsize=(w_px / dpi, h_px / dpi), facecolor="k", frameon=False, dpi=dpi +) +ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) ax.set_axis_off() fig.add_axes(ax) # plot rainbow -ax.imshow(X, cmap=mne_field_grad_cols, aspect='equal', zorder=1) -ax.imshow(np.ones_like(X) * 0.5, cmap='Greys', aspect='equal', zorder=0, - clim=[0, 1]) +ax.imshow(X, cmap=mne_field_grad_cols, aspect="equal", zorder=1) +ax.imshow(np.ones_like(X) * 0.5, cmap="Greys", aspect="equal", zorder=0, clim=[0, 1]) plot_dims = np.r_[np.diff(ax.get_xbound()), np.diff(ax.get_ybound())] # MNE text in white -mne_path = TextPath((0, 0), 'MNE') +mne_path = TextPath((0, 0), "MNE") dims = mne_path.vertices.max(0) - mne_path.vertices.min(0) -vert = mne_path.vertices - dims / 2. +vert = mne_path.vertices - dims / 2.0 mult = scale_fudge * (plot_dims / dims).min() mult = [mult, -mult] # y axis is inverted (origin at top left) -offset = np.array([scale_fudge, 1.]) * \ - np.array([-dims[0], plot_dims[-1]]) / 2. - center_fudge +offset = ( + np.array([scale_fudge, 1.0]) * np.array([-dims[0], plot_dims[-1]]) / 2.0 + - center_fudge +) mne_clip = Path(offset + vert * mult, mne_path.codes) -mne_patch = PathPatch(mne_clip, facecolor='0.5', edgecolor='none', zorder=10) +mne_patch = PathPatch(mne_clip, facecolor="0.5", edgecolor="none", zorder=10) ax.add_patch(mne_patch) # adjust xlim and ylim mne_corners = mne_clip.get_extents().corners() @@ -194,11 +228,10 @@ xmax, ymax = np.max(mne_corners, axis=0) xl = ax.get_xlim() yl = ax.get_ylim() -xpad = np.abs(np.diff([xmin, xl[1]])) / 20. -ypad = np.abs(np.diff([ymax, ymin])) / 20. 
+xpad = np.abs(np.diff([xmin, xl[1]])) / 20.0 +ypad = np.abs(np.diff([ymax, ymin])) / 20.0 ax.set_xlim(xmin - xpad, xl[1] + xpad) ax.set_ylim(ymax + ypad, ymin - ypad) plt.draw() -plt.savefig(op.join(static_dir, 'mne_logo_small.svg'), - dpi=dpi, transparent=True) +plt.savefig(op.join(static_dir, "mne_logo_small.svg"), dpi=dpi, transparent=True) plt.close() diff --git a/mne/__init__.py b/mne/__init__.py index 27a2846887e..4457f310986 100644 --- a/mne/__init__.py +++ b/mne/__init__.py @@ -18,95 +18,214 @@ try: from importlib.metadata import version + __version__ = version("mne") except Exception: try: from ._version import __version__ except ImportError: - __version__ = '0.0.0' + __version__ = "0.0.0" # have to import verbose first since it's needed by many things -from .utils import (set_log_level, set_log_file, verbose, set_config, - get_config, get_config_path, set_cache_dir, - set_memmap_min_size, grand_average, sys_info, open_docs, - use_log_level) -from .io.pick import (pick_types, pick_channels, - pick_channels_regexp, pick_channels_forward, - pick_types_forward, pick_channels_cov, - pick_channels_evoked, pick_info, - channel_type, channel_indices_by_type) +from .utils import ( + set_log_level, + set_log_file, + verbose, + set_config, + get_config, + get_config_path, + set_cache_dir, + set_memmap_min_size, + grand_average, + sys_info, + open_docs, + use_log_level, +) +from .io.pick import ( + pick_types, + pick_channels, + pick_channels_regexp, + pick_channels_forward, + pick_types_forward, + pick_channels_cov, + pick_channels_evoked, + pick_info, + channel_type, + channel_indices_by_type, +) from .io.base import concatenate_raws, match_channel_orders from .io.meas_info import create_info, Info from .io.proj import Projection from .io.kit import read_epochs_kit from .io.eeglab import read_epochs_eeglab -from .io.reference import (set_eeg_reference, set_bipolar_reference, - add_reference_channels) +from .io.reference import ( + set_eeg_reference, + set_bipolar_reference, + add_reference_channels, +) from .io.what import what -from .bem import (make_sphere_model, make_bem_model, make_bem_solution, - read_bem_surfaces, write_bem_surfaces, write_head_bem, - read_bem_solution, write_bem_solution) -from .cov import (read_cov, write_cov, Covariance, compute_raw_covariance, - compute_covariance, whiten_evoked, make_ad_hoc_cov) -from .event import (read_events, write_events, find_events, merge_events, - pick_events, make_fixed_length_events, concatenate_events, - find_stim_steps, AcqParserFIF, count_events) -from ._freesurfer import (head_to_mni, head_to_mri, read_talxfm, - get_volume_labels_from_aseg, read_freesurfer_lut, - vertex_to_mni, read_lta) -from .forward import (read_forward_solution, apply_forward, apply_forward_raw, - average_forward_solutions, Forward, - write_forward_solution, make_forward_solution, - convert_forward_solution, make_field_map, - make_forward_dipole, use_coil_def) -from .source_estimate import (read_source_estimate, - SourceEstimate, VectorSourceEstimate, - VolSourceEstimate, VolVectorSourceEstimate, - MixedSourceEstimate, MixedVectorSourceEstimate, - grade_to_tris, - spatial_src_adjacency, - spatial_tris_adjacency, - spatial_dist_adjacency, - spatial_inter_hemi_adjacency, - spatio_temporal_src_adjacency, - spatio_temporal_tris_adjacency, - spatio_temporal_dist_adjacency, - extract_label_time_course, stc_near_sensors) -from .surface import (read_surface, write_surface, decimate_surface, read_tri, - get_head_surf, get_meg_helmet_surf, dig_mri_distances, - 
warp_montage_volume, get_montage_volume_labels) +from .bem import ( + make_sphere_model, + make_bem_model, + make_bem_solution, + read_bem_surfaces, + write_bem_surfaces, + write_head_bem, + read_bem_solution, + write_bem_solution, +) +from .cov import ( + read_cov, + write_cov, + Covariance, + compute_raw_covariance, + compute_covariance, + whiten_evoked, + make_ad_hoc_cov, +) +from .event import ( + read_events, + write_events, + find_events, + merge_events, + pick_events, + make_fixed_length_events, + concatenate_events, + find_stim_steps, + AcqParserFIF, + count_events, +) +from ._freesurfer import ( + head_to_mni, + head_to_mri, + read_talxfm, + get_volume_labels_from_aseg, + read_freesurfer_lut, + vertex_to_mni, + read_lta, +) +from .forward import ( + read_forward_solution, + apply_forward, + apply_forward_raw, + average_forward_solutions, + Forward, + write_forward_solution, + make_forward_solution, + convert_forward_solution, + make_field_map, + make_forward_dipole, + use_coil_def, +) +from .source_estimate import ( + read_source_estimate, + SourceEstimate, + VectorSourceEstimate, + VolSourceEstimate, + VolVectorSourceEstimate, + MixedSourceEstimate, + MixedVectorSourceEstimate, + grade_to_tris, + spatial_src_adjacency, + spatial_tris_adjacency, + spatial_dist_adjacency, + spatial_inter_hemi_adjacency, + spatio_temporal_src_adjacency, + spatio_temporal_tris_adjacency, + spatio_temporal_dist_adjacency, + extract_label_time_course, + stc_near_sensors, +) +from .surface import ( + read_surface, + write_surface, + decimate_surface, + read_tri, + get_head_surf, + get_meg_helmet_surf, + dig_mri_distances, + warp_montage_volume, + get_montage_volume_labels, +) from .morph_map import read_morph_map -from .morph import (SourceMorph, read_source_morph, grade_to_vertices, - compute_source_morph) -from .source_space import (read_source_spaces, - write_source_spaces, setup_source_space, - setup_volume_source_space, SourceSpaces, - add_source_space_distances, morph_source_spaces, - get_volume_labels_from_src) -from .annotations import (Annotations, read_annotations, annotations_from_events, - events_from_annotations) -from .epochs import (BaseEpochs, Epochs, EpochsArray, read_epochs, - concatenate_epochs, make_fixed_length_epochs) -from .evoked import (Evoked, EvokedArray, read_evokeds, write_evokeds, - combine_evoked) -from .label import (read_label, label_sign_flip, - write_label, stc_to_label, grow_labels, Label, split_label, - BiHemiLabel, read_labels_from_annot, write_labels_to_annot, - random_parcellation, morph_labels, labels_to_stc) +from .morph import ( + SourceMorph, + read_source_morph, + grade_to_vertices, + compute_source_morph, +) +from .source_space import ( + read_source_spaces, + write_source_spaces, + setup_source_space, + setup_volume_source_space, + SourceSpaces, + add_source_space_distances, + morph_source_spaces, + get_volume_labels_from_src, +) +from .annotations import ( + Annotations, + read_annotations, + annotations_from_events, + events_from_annotations, +) +from .epochs import ( + BaseEpochs, + Epochs, + EpochsArray, + read_epochs, + concatenate_epochs, + make_fixed_length_epochs, +) +from .evoked import Evoked, EvokedArray, read_evokeds, write_evokeds, combine_evoked +from .label import ( + read_label, + label_sign_flip, + write_label, + stc_to_label, + grow_labels, + Label, + split_label, + BiHemiLabel, + read_labels_from_annot, + write_labels_to_annot, + random_parcellation, + morph_labels, + labels_to_stc, +) from .misc import parse_config, 
read_reject_parameters -from .coreg import (create_default_subject, scale_bem, scale_mri, scale_labels, - scale_source_space) -from .transforms import (read_trans, write_trans, - transform_surface_to, Transform) -from .proj import (read_proj, write_proj, compute_proj_epochs, - compute_proj_evoked, compute_proj_raw, sensitivity_map) +from .coreg import ( + create_default_subject, + scale_bem, + scale_mri, + scale_labels, + scale_source_space, +) +from .transforms import read_trans, write_trans, transform_surface_to, Transform +from .proj import ( + read_proj, + write_proj, + compute_proj_epochs, + compute_proj_evoked, + compute_proj_raw, + sensitivity_map, +) from .dipole import read_dipole, Dipole, DipoleFixed, fit_dipole -from .channels import (equalize_channels, rename_channels, find_layout, - read_vectorview_selection) +from .channels import ( + equalize_channels, + rename_channels, + find_layout, + read_vectorview_selection, +) from .report import Report, open_report -from .io import (read_epochs_fieldtrip, read_evoked_besa, - read_evoked_fieldtrip, read_evokeds_mff) +from .io import ( + read_epochs_fieldtrip, + read_evoked_besa, + read_evoked_fieldtrip, + read_evokeds_mff, +) from .rank import compute_rank from . import beamformer diff --git a/mne/__main__.py b/mne/__main__.py index 414754c1885..5a3bfa5abb6 100644 --- a/mne/__main__.py +++ b/mne/__main__.py @@ -3,5 +3,5 @@ from .commands.utils import main -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mne/_freesurfer.py b/mne/_freesurfer.py index d92ac40e807..9b774dafc0d 100644 --- a/mne/_freesurfer.py +++ b/mne/_freesurfer.py @@ -12,31 +12,45 @@ from .bem import _bem_find_surface, read_bem_surfaces from .io.constants import FIFF from .io.meas_info import read_fiducials -from .transforms import (apply_trans, invert_transform, combine_transforms, - _ensure_trans, read_ras_mni_t, Transform) +from .transforms import ( + apply_trans, + invert_transform, + combine_transforms, + _ensure_trans, + read_ras_mni_t, + Transform, +) from .surface import read_surface, _read_mri_surface -from .utils import (verbose, _validate_type, _check_fname, _check_option, - get_subjects_dir, _import_nibabel, logger) +from .utils import ( + verbose, + _validate_type, + _check_fname, + _check_option, + get_subjects_dir, + _import_nibabel, + logger, +) def _check_subject_dir(subject, subjects_dir): """Check that the Freesurfer subject directory is as expected.""" subjects_dir = Path(get_subjects_dir(subjects_dir, raise_error=True)) - for img_name in ('T1', 'brain', 'aseg'): + for img_name in ("T1", "brain", "aseg"): if not (subjects_dir / subject / "mri" / f"{img_name}.mgz").is_file(): - raise ValueError('Freesurfer recon-all subject folder ' - 'is incorrect or improperly formatted, ' - f'got {subjects_dir / subject}') + raise ValueError( + "Freesurfer recon-all subject folder " + "is incorrect or improperly formatted, " + f"got {subjects_dir / subject}" + ) return subjects_dir / subject def _get_aseg(aseg, subject, subjects_dir): """Check that the anatomical segmentation file exists and load it.""" - nib = _import_nibabel('load aseg') + nib = _import_nibabel("load aseg") subjects_dir = Path(get_subjects_dir(subjects_dir, raise_error=True)) - if not aseg.endswith('aseg'): - raise RuntimeError( - f'`aseg` file path must end with "aseg", got {aseg}') + if not aseg.endswith("aseg"): + raise RuntimeError(f'`aseg` file path must end with "aseg", got {aseg}') aseg = _check_fname( subjects_dir / subject / "mri" / (aseg + ".mgz"), 
overwrite="read", @@ -47,7 +61,7 @@ def _get_aseg(aseg, subject, subjects_dir): return aseg, aseg_data -def _reorient_image(img, axcodes='RAS'): +def _reorient_image(img, axcodes="RAS"): """Reorient an image to a given orientation. Parameters @@ -69,11 +83,12 @@ def _reorient_image(img, axcodes='RAS'): ----- .. versionadded:: 0.24 """ - nib = _import_nibabel('reorient MRI image') + nib = _import_nibabel("reorient MRI image") orig_data = np.array(img.dataobj).astype(np.float32) # reorient data to RAS ornt = nib.orientations.axcodes2ornt( - nib.orientations.aff2axcodes(img.affine)).astype(int) + nib.orientations.aff2axcodes(img.affine) + ).astype(int) ras_ornt = nib.orientations.axcodes2ornt(axcodes) ornt_trans = nib.orientations.ornt_transform(ornt, ras_ornt) img_data = nib.orientations.apply_orientation(orig_data, ornt_trans) @@ -105,7 +120,7 @@ def _mri_orientation(orientation): .. versionadded:: 0.21 .. versionchanged:: 0.24 """ - _check_option('orientation', orientation, ('coronal', 'axial', 'sagittal')) + _check_option("orientation", orientation, ("coronal", "axial", "sagittal")) axis = dict(coronal=1, axial=2, sagittal=0)[orientation] x, y = sorted(set([0, 1, 2]).difference(set([axis]))) return axis, x, y @@ -114,72 +129,81 @@ def _mri_orientation(orientation): def _get_mri_info_data(mri, data): # Read the segmentation data using nibabel if data: - _import_nibabel('load MRI atlas data') + _import_nibabel("load MRI atlas data") out = dict() - _, out['vox_mri_t'], out['mri_ras_t'], dims, _, mgz = _read_mri_info( - mri, return_img=True) + _, out["vox_mri_t"], out["mri_ras_t"], dims, _, mgz = _read_mri_info( + mri, return_img=True + ) out.update( - mri_width=dims[0], mri_height=dims[1], - mri_depth=dims[1], mri_volume_name=mri) + mri_width=dims[0], mri_height=dims[1], mri_depth=dims[1], mri_volume_name=mri + ) if data: assert mgz is not None - out['mri_vox_t'] = invert_transform(out['vox_mri_t']) - out['data'] = np.asarray(mgz.dataobj) + out["mri_vox_t"] = invert_transform(out["vox_mri_t"]) + out["data"] = np.asarray(mgz.dataobj) return out def _get_mgz_header(fname): """Adapted from nibabel to quickly extract header info.""" - fname = _check_fname(fname, overwrite='read', must_exist=True, - name='MRI image') + fname = _check_fname(fname, overwrite="read", must_exist=True, name="MRI image") if fname.suffix != ".mgz": - raise OSError('Filename must end with .mgz') - header_dtd = [('version', '>i4'), ('dims', '>i4', (4,)), - ('type', '>i4'), ('dof', '>i4'), ('goodRASFlag', '>i2'), - ('delta', '>f4', (3,)), ('Mdc', '>f4', (3, 3)), - ('Pxyz_c', '>f4', (3,))] + raise OSError("Filename must end with .mgz") + header_dtd = [ + ("version", ">i4"), + ("dims", ">i4", (4,)), + ("type", ">i4"), + ("dof", ">i4"), + ("goodRASFlag", ">i2"), + ("delta", ">f4", (3,)), + ("Mdc", ">f4", (3, 3)), + ("Pxyz_c", ">f4", (3,)), + ] header_dtype = np.dtype(header_dtd) - with GzipFile(fname, 'rb') as fid: + with GzipFile(fname, "rb") as fid: hdr_str = fid.read(header_dtype.itemsize) - header = np.ndarray(shape=(), dtype=header_dtype, - buffer=hdr_str) + header = np.ndarray(shape=(), dtype=header_dtype, buffer=hdr_str) # dims - dims = header['dims'].astype(int) + dims = header["dims"].astype(int) dims = dims[:3] if len(dims) == 4 else dims # vox2ras_tkr - delta = header['delta'] + delta = header["delta"] ds = np.array(delta, float) ns = np.array(dims * ds) / 2.0 - v2rtkr = np.array([[-ds[0], 0, 0, ns[0]], - [0, 0, ds[2], -ns[2]], - [0, -ds[1], 0, ns[1]], - [0, 0, 0, 1]], dtype=np.float32) + v2rtkr = np.array( + [ + 
[-ds[0], 0, 0, ns[0]], + [0, 0, ds[2], -ns[2]], + [0, -ds[1], 0, ns[1]], + [0, 0, 0, 1], + ], + dtype=np.float32, + ) # ras2vox d = np.diag(delta) pcrs_c = dims / 2.0 - Mdc = header['Mdc'].T - pxyz_0 = header['Pxyz_c'] - np.dot(Mdc, np.dot(d, pcrs_c)) + Mdc = header["Mdc"].T + pxyz_0 = header["Pxyz_c"] - np.dot(Mdc, np.dot(d, pcrs_c)) M = np.eye(4, 4) M[0:3, 0:3] = np.dot(Mdc, d) M[0:3, 3] = pxyz_0.T - header = dict(dims=dims, vox2ras_tkr=v2rtkr, vox2ras=M, - zooms=header['delta']) + header = dict(dims=dims, vox2ras_tkr=v2rtkr, vox2ras=M, zooms=header["delta"]) return header def _get_atlas_values(vol_info, rr): # Transform MRI coordinates (where our surfaces live) to voxels - rr_vox = apply_trans(vol_info['mri_vox_t'], rr) - good = ((rr_vox >= -.5) & - (rr_vox < np.array(vol_info['data'].shape, int) - 0.5)).all(-1) + rr_vox = apply_trans(vol_info["mri_vox_t"], rr) + good = ( + (rr_vox >= -0.5) & (rr_vox < np.array(vol_info["data"].shape, int) - 0.5) + ).all(-1) idx = np.round(rr_vox[good].T).astype(np.int64) values = np.full(rr.shape[0], np.nan) - values[good] = vol_info['data'][tuple(idx)] + values[good] = vol_info["data"][tuple(idx)] return values -def get_volume_labels_from_aseg(mgz_fname, return_colors=False, - atlas_ids=None): +def get_volume_labels_from_aseg(mgz_fname, return_colors=False, atlas_ids=None): """Return a list of names and colors of segmented volumes. Parameters @@ -214,7 +238,7 @@ def get_volume_labels_from_aseg(mgz_fname, return_colors=False, .. versionadded:: 0.9.0 """ - nib = _import_nibabel('load MRI atlas data') + nib = _import_nibabel("load MRI atlas data") mgz_fname = _check_fname( mgz_fname, overwrite="read", must_exist=True, name="mgz_fname" ) @@ -224,12 +248,13 @@ def get_volume_labels_from_aseg(mgz_fname, return_colors=False, if atlas_ids is None: atlas_ids, colors = read_freesurfer_lut() elif return_colors: - raise ValueError('return_colors must be False if atlas_ids are ' - 'provided') + raise ValueError("return_colors must be False if atlas_ids are " "provided") # restrict to the ones in the MRI, sorted by label name keep = np.in1d(list(atlas_ids.values()), want) - keys = sorted((key for ki, key in enumerate(atlas_ids.keys()) if keep[ki]), - key=lambda x: atlas_ids[x]) + keys = sorted( + (key for ki, key in enumerate(atlas_ids.keys()) if keep[ki]), + key=lambda x: atlas_ids[x], + ) if return_colors: colors = [colors[k] for k in keys] out = keys, colors @@ -243,8 +268,16 @@ def get_volume_labels_from_aseg(mgz_fname, return_colors=False, @verbose -def head_to_mri(pos, subject, mri_head_t, subjects_dir=None, *, - kind='mri', unscale=False, verbose=None): +def head_to_mri( + pos, + subject, + mri_head_t, + subjects_dir=None, + *, + kind="mri", + unscale=False, + verbose=None, +): """Convert pos from head coordinate system to MRI ones. Parameters @@ -279,23 +312,24 @@ def head_to_mri(pos, subject, mri_head_t, subjects_dir=None, *, This function requires nibabel. 
""" from .coreg import read_mri_cfg - _validate_type(kind, str, 'kind') - _check_option('kind', kind, ('ras', 'mri')) + + _validate_type(kind, str, "kind") + _check_option("kind", kind, ("ras", "mri")) subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) t1_fname = subjects_dir / subject / "mri" / "T1.mgz" - head_mri_t = _ensure_trans(mri_head_t, 'head', 'mri') - if kind == 'ras': + head_mri_t = _ensure_trans(mri_head_t, "head", "mri") + if kind == "ras": _, _, mri_ras_t, _, _ = _read_mri_info(t1_fname) - head_ras_t = combine_transforms(head_mri_t, mri_ras_t, 'head', 'ras') + head_ras_t = combine_transforms(head_mri_t, mri_ras_t, "head", "ras") head_dest_t = head_ras_t else: - assert kind == 'mri' + assert kind == "mri" head_dest_t = head_mri_t pos_dest = apply_trans(head_dest_t, pos) # unscale if requested if unscale: params = read_mri_cfg(subject, subjects_dir) - pos_dest /= params['scale'] + pos_dest /= params["scale"] pos_dest *= 1e3 # mm return pos_dest @@ -303,6 +337,7 @@ def head_to_mri(pos, subject, mri_head_t, subjects_dir=None, *, ############################################################################## # Surface to MNI conversion + @verbose def vertex_to_mni(vertices, hemis, subject, subjects_dir=None, verbose=None): """Convert the array of vertices for a hemisphere to MNI coordinates. @@ -332,33 +367,30 @@ def vertex_to_mni(vertices, hemis, subject, subjects_dir=None, verbose=None): hemis = [hemis] * len(vertices) if not len(hemis) == len(vertices): - raise ValueError('hemi and vertices must match in length') + raise ValueError("hemi and vertices must match in length") subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) - surfs = [ - subjects_dir / subject / "surf" / f"{h}.white" - for h in ["lh", "rh"] - ] + surfs = [subjects_dir / subject / "surf" / f"{h}.white" for h in ["lh", "rh"]] # read surface locations in MRI space rr = [read_surface(s)[0] for s in surfs] # take point locations in MRI space and convert to MNI coordinates xfm = read_talxfm(subject, subjects_dir) - xfm['trans'][:3, 3] *= 1000. # m->mm + xfm["trans"][:3, 3] *= 1000.0 # m->mm data = np.array([rr[h][v, :] for h, v in zip(hemis, vertices)]) if singleton: data = data[0] - return apply_trans(xfm['trans'], data) + return apply_trans(xfm["trans"], data) ############################################################################## # Volume to MNI conversion + @verbose -def head_to_mni(pos, subject, mri_head_t, subjects_dir=None, - verbose=None): +def head_to_mni(pos, subject, mri_head_t, subjects_dir=None, verbose=None): """Convert pos from head coordinate system to MNI ones. Parameters @@ -384,9 +416,12 @@ def head_to_mni(pos, subject, mri_head_t, subjects_dir=None, # before we go from head to MRI (surface RAS) head_mni_t = combine_transforms( - _ensure_trans(mri_head_t, 'head', 'mri'), - read_talxfm(subject, subjects_dir), 'head', 'mni_tal') - return apply_trans(head_mni_t, pos) * 1000. + _ensure_trans(mri_head_t, "head", "mri"), + read_talxfm(subject, subjects_dir), + "head", + "mni_tal", + ) + return apply_trans(head_mni_t, pos) * 1000.0 @verbose @@ -424,20 +459,17 @@ def get_mni_fiducials(subject, subjects_dir=None, verbose=None): # transformation, and/or project the points onto the head surface # (if available). fname_fids_fs = ( - Path(__file__).parent - / "data" - / "fsaverage" - / "fsaverage-fiducials.fif" + Path(__file__).parent / "data" / "fsaverage" / "fsaverage-fiducials.fif" ) # Read fsaverage fiducials file and subject Talairach. 
fids, coord_frame = read_fiducials(fname_fids_fs) assert coord_frame == FIFF.FIFFV_COORD_MRI - if subject == 'fsaverage': + if subject == "fsaverage": return fids # special short-circuit for fsaverage mni_mri_t = invert_transform(read_talxfm(subject, subjects_dir)) for f in fids: - f['r'] = apply_trans(mni_mri_t, f['r']) + f["r"] = apply_trans(mni_mri_t, f["r"]) return fids @@ -463,34 +495,40 @@ def estimate_head_mri_t(subject, subjects_dir=None, verbose=None): subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) lpa, nasion, rpa = get_mni_fiducials(subject, subjects_dir) - montage = make_dig_montage(lpa=lpa['r'], nasion=nasion['r'], rpa=rpa['r'], - coord_frame='mri') + montage = make_dig_montage( + lpa=lpa["r"], nasion=nasion["r"], rpa=rpa["r"], coord_frame="mri" + ) return invert_transform(compute_native_head_t(montage)) def _ensure_image_in_surface_RAS(image, subject, subjects_dir): """Check if the image is in Freesurfer surface RAS space.""" - nib = _import_nibabel('load a volume image') + nib = _import_nibabel("load a volume image") if not isinstance(image, nib.spatialimages.SpatialImage): image = nib.load(image) image = nib.MGHImage(image.dataobj.astype(np.float32), image.affine) - fs_img = nib.load(op.join(subjects_dir, subject, 'mri', 'brain.mgz')) + fs_img = nib.load(op.join(subjects_dir, subject, "mri", "brain.mgz")) if not np.allclose(image.affine, fs_img.affine, atol=1e-6): - raise RuntimeError('The `image` is not aligned to Freesurfer ' - 'surface RAS space. This space is required as ' - 'it is the space where the anatomical ' - 'segmentation and reconstructed surfaces are') + raise RuntimeError( + "The `image` is not aligned to Freesurfer " + "surface RAS space. This space is required as " + "it is the space where the anatomical " + "segmentation and reconstructed surfaces are" + ) return image # returns MGH image for header def _get_affine_from_lta_info(lines): """Get the vox2ras affine from lta file info.""" - volume_data = np.loadtxt( - [line.split('=')[1] for line in lines]) + volume_data = np.loadtxt([line.split("=")[1] for line in lines]) # get the size of the volume (number of voxels), slice resolution. # the matrix of directional cosines and the ras at the center of the bore - dims, deltas, dir_cos, center_ras = \ - volume_data[0], volume_data[1], volume_data[2:5], volume_data[5] + dims, deltas, dir_cos, center_ras = ( + volume_data[0], + volume_data[1], + volume_data[2:5], + volume_data[5], + ) dir_cos_delta = dir_cos.T * deltas vol_center = (dir_cos_delta @ dims[:3]) / 2 affine = np.eye(4) @@ -514,11 +552,11 @@ def read_lta(fname, verbose=None): affine : ndarray The affine transformation described by the lta file. """ - _check_fname(fname, 'read', must_exist=True) - with open(fname, 'r') as fid: + _check_fname(fname, "read", must_exist=True) + with open(fname, "r") as fid: lines = fid.readlines() # 0 is linear vox2vox, 1 is linear ras2ras - trans_type = int(lines[0].split('=')[1].strip()[0]) + trans_type = int(lines[0].split("=")[1].strip()[0]) assert trans_type in (0, 1) affine = np.loadtxt(lines[5:9]) if trans_type == 1: @@ -556,7 +594,7 @@ def read_talxfm(subject, subjects_dir=None, verbose=None): subjects_dir = get_subjects_dir(subjects_dir) # Setup the RAS to MNI transform ras_mni_t = read_ras_mni_t(subject, subjects_dir) - ras_mni_t['trans'][:3, 3] /= 1000. # mm->m + ras_mni_t["trans"][:3, 3] /= 1000.0 # mm->m # We want to get from Freesurfer surface RAS ('mri') to MNI ('mni_tal'). 
# This file only gives us RAS (non-zero origin) ('ras') to MNI ('mni_tal'). @@ -568,33 +606,36 @@ def read_talxfm(subject, subjects_dir=None, verbose=None): if not path.is_file(): path = subjects_dir / subject / "mri" / "T1.mgz" if not path.is_file(): - raise OSError('mri not found: %s' % path) + raise OSError("mri not found: %s" % path) _, _, mri_ras_t, _, _ = _read_mri_info(path) - mri_mni_t = combine_transforms(mri_ras_t, ras_mni_t, 'mri', 'mni_tal') + mri_mni_t = combine_transforms(mri_ras_t, ras_mni_t, "mri", "mni_tal") return mri_mni_t def _check_mri(mri, subject, subjects_dir): """Check whether an mri exists in the Freesurfer subject directory.""" - _validate_type(mri, 'path-like', 'mri') + _validate_type(mri, "path-like", "mri") if op.isfile(mri) and op.basename(mri) != mri: return mri if not op.isfile(mri): if subject is None: raise FileNotFoundError( - f'MRI file {mri!r} not found and no subject provided') + f"MRI file {mri!r} not found and no subject provided" + ) subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True)) - mri = op.join(subjects_dir, subject, 'mri', mri) + mri = op.join(subjects_dir, subject, "mri", mri) if not op.isfile(mri): - raise FileNotFoundError(f'MRI file {mri!r} not found') + raise FileNotFoundError(f"MRI file {mri!r} not found") if op.basename(mri) == mri: - err = (f'Ambiguous filename - found {mri!r} in current folder.\n' - 'If this is correct prefix name with relative or absolute path') + err = ( + f"Ambiguous filename - found {mri!r} in current folder.\n" + "If this is correct prefix name with relative or absolute path" + ) raise OSError(err) return mri -def _read_mri_info(path, units='m', return_img=False, use_nibabel=False): +def _read_mri_info(path, units="m", return_img=False, use_nibabel=False): # This is equivalent but 100x slower, so only use nibabel if we need to # (later): if use_nibabel: @@ -606,29 +647,28 @@ def _read_mri_info(path, units='m', return_img=False, use_nibabel=False): zooms = hdr.get_zooms()[:3] else: hdr = _get_mgz_header(path) - n_orig = hdr['vox2ras'] - t_orig = hdr['vox2ras_tkr'] - dims = hdr['dims'] - zooms = hdr['zooms'] + n_orig = hdr["vox2ras"] + t_orig = hdr["vox2ras_tkr"] + dims = hdr["dims"] + zooms = hdr["zooms"] # extract the MRI_VOXEL to RAS (non-zero origin) transform - vox_ras_t = Transform('mri_voxel', 'ras', n_orig) + vox_ras_t = Transform("mri_voxel", "ras", n_orig) # extract the MRI_VOXEL to MRI transform - vox_mri_t = Transform('mri_voxel', 'mri', t_orig) + vox_mri_t = Transform("mri_voxel", "mri", t_orig) # construct the MRI to RAS (non-zero origin) transform - mri_ras_t = combine_transforms( - invert_transform(vox_mri_t), vox_ras_t, 'mri', 'ras') + mri_ras_t = combine_transforms(invert_transform(vox_mri_t), vox_ras_t, "mri", "ras") - assert units in ('m', 'mm') - if units == 'm': + assert units in ("m", "mm") + if units == "m": conv = np.array([[1e-3, 1e-3, 1e-3, 1]]).T # scaling and translation terms - vox_ras_t['trans'] *= conv - vox_mri_t['trans'] *= conv + vox_ras_t["trans"] *= conv + vox_mri_t["trans"] *= conv # just the translation term - mri_ras_t['trans'][:, 3:4] *= conv + mri_ras_t["trans"][:, 3:4] *= conv out = (vox_ras_t, vox_mri_t, mri_ras_t, dims, zooms) if return_img: @@ -653,8 +693,8 @@ def read_freesurfer_lut(fname=None): Mapping from label names to colors. 
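    Examples
    --------
    A minimal usage sketch; ``Left-Hippocampus`` is a standard FreeSurfer
    label that should be present in the default LUT::

        atlas_ids, colors = read_freesurfer_lut()
        hippo_id = atlas_ids["Left-Hippocampus"]  # integer anatomical id
        hippo_rgba = colors["Left-Hippocampus"]  # length-4 RGBA float array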
""" lut = _get_lut(fname) - names, ids = lut['name'], lut['id'] - colors = np.array([lut['R'], lut['G'], lut['B'], lut['A']], float).T + names, ids = lut["name"], lut["id"] + colors = np.array([lut["R"], lut["G"], lut["B"], lut["A"]], float).T atlas_ids = dict(zip(names, ids)) colors = dict(zip(names, colors)) return atlas_ids, colors @@ -664,22 +704,28 @@ def _get_lut(fname=None): """Get a FreeSurfer LUT.""" if fname is None: fname = Path(__file__).parent / "data" / "FreeSurferColorLUT.txt" - _check_fname(fname, 'read', must_exist=True) - dtype = [('id', ' 0 + assert len(lut["name"]) > 0 return lut @@ -709,44 +755,50 @@ def _get_head_surface(surf, subject, subjects_dir, bem=None, verbose=None): ----- .. versionadded: 0.24 """ - _check_option( - 'surf', surf, ('auto', 'head', 'outer_skin', 'head-dense', 'seghead')) - if surf in ('auto', 'head', 'outer_skin'): + _check_option("surf", surf, ("auto", "head", "outer_skin", "head-dense", "seghead")) + if surf in ("auto", "head", "outer_skin"): if bem is not None: try: - return _bem_find_surface(bem, 'head') + return _bem_find_surface(bem, "head") except RuntimeError: - logger.info('Could not find the surface for ' - 'head in the provided BEM model, ' - 'looking in the subject directory.') + logger.info( + "Could not find the surface for " + "head in the provided BEM model, " + "looking in the subject directory." + ) if subject is None: - if surf == 'auto': + if surf == "auto": return - raise ValueError('To plot the head surface, the BEM/sphere' - ' model must contain a head surface ' - 'or "subject" must be provided (got ' - 'None)') - subject_dir = op.join( - get_subjects_dir(subjects_dir, raise_error=True), subject) - if surf in ('head-dense', 'seghead'): - try_fnames = [op.join(subject_dir, 'bem', f'{subject}-head-dense.fif'), - op.join(subject_dir, 'surf', 'lh.seghead')] + raise ValueError( + "To plot the head surface, the BEM/sphere" + " model must contain a head surface " + 'or "subject" must be provided (got ' + "None)" + ) + subject_dir = op.join(get_subjects_dir(subjects_dir, raise_error=True), subject) + if surf in ("head-dense", "seghead"): + try_fnames = [ + op.join(subject_dir, "bem", f"{subject}-head-dense.fif"), + op.join(subject_dir, "surf", "lh.seghead"), + ] else: try_fnames = [ - op.join(subject_dir, 'bem', 'outer_skin.surf'), - op.join(subject_dir, 'bem', 'flash', 'outer_skin.surf'), - op.join(subject_dir, 'bem', f'{subject}-head-sparse.fif'), - op.join(subject_dir, 'bem', f'{subject}-head.fif'), + op.join(subject_dir, "bem", "outer_skin.surf"), + op.join(subject_dir, "bem", "flash", "outer_skin.surf"), + op.join(subject_dir, "bem", f"{subject}-head-sparse.fif"), + op.join(subject_dir, "bem", f"{subject}-head.fif"), ] for fname in try_fnames: if op.exists(fname): - logger.info(f'Using {op.basename(fname)} for head surface.') - if op.splitext(fname)[-1] == '.fif': - return read_bem_surfaces(fname, on_defects='warn')[0] + logger.info(f"Using {op.basename(fname)} for head surface.") + if op.splitext(fname)[-1] == ".fif": + return read_bem_surfaces(fname, on_defects="warn")[0] else: return _read_mri_surface(fname) - raise OSError('No head surface found for subject ' - f'{subject} after trying:\n' + '\n'.join(try_fnames)) + raise OSError( + "No head surface found for subject " + f"{subject} after trying:\n" + "\n".join(try_fnames) + ) @verbose @@ -776,29 +828,32 @@ def _get_skull_surface(surf, subject, subjects_dir, bem=None, verbose=None): """ if bem is not None: try: - return _bem_find_surface(bem, surf + '_skull') + return 
_bem_find_surface(bem, surf + "_skull") except RuntimeError: - logger.info('Could not find the surface for ' - 'skull in the provided BEM model, ' - 'looking in the subject directory.') + logger.info( + "Could not find the surface for " + "skull in the provided BEM model, " + "looking in the subject directory." + ) subjects_dir = Path(get_subjects_dir(subjects_dir, raise_error=True)) fname = _check_fname( subjects_dir / subject / "bem" / (surf + "_skull.surf"), overwrite="read", must_exist=True, - name=f"{surf} skull surface" + name=f"{surf} skull surface", ) return _read_mri_surface(fname) def _estimate_talxfm_rigid(subject, subjects_dir): from .coreg import fit_matched_points, _trans_from_params + xfm = read_talxfm(subject, subjects_dir) # XYZ+origin + halfway pts_tal = np.concatenate([np.eye(4)[:, :3], np.eye(3) * 0.5]) pts_subj = apply_trans(invert_transform(xfm), pts_tal) # we fit with scaling enabled, but then discard it (we just need # the rigid-body components) - params = fit_matched_points(pts_subj, pts_tal, scale=3, out='params') + params = fit_matched_points(pts_subj, pts_tal, scale=3, out="params") rigid = _trans_from_params((True, True, False), params[:6]) return rigid diff --git a/mne/_ola.py b/mne/_ola.py index a4ecad26a66..df92f771bf6 100644 --- a/mne/_ola.py +++ b/mne/_ola.py @@ -10,6 +10,7 @@ ############################################################################### # Class for interpolation between adjacent points + class _Interp2: r"""Interpolate between two points. @@ -41,56 +42,62 @@ class _Interp2: """ - def __init__(self, control_points, values, interp='hann'): + def __init__(self, control_points, values, interp="hann"): # set up interpolation self.control_points = np.array(control_points, int).ravel() - if not np.array_equal(np.unique(self.control_points), - self.control_points): - raise ValueError('Control points must be sorted and unique') + if not np.array_equal(np.unique(self.control_points), self.control_points): + raise ValueError("Control points must be sorted and unique") if len(self.control_points) == 0: - raise ValueError('Must be at least one control point') + raise ValueError("Must be at least one control point") if not (self.control_points >= 0).all(): - raise ValueError('All control points must be positive (got %s)' - % (self.control_points[:3],)) + raise ValueError( + "All control points must be positive (got %s)" + % (self.control_points[:3],) + ) if isinstance(values, np.ndarray): values = [values] if isinstance(values, (list, tuple)): for v in values: if not (v is None or isinstance(v, np.ndarray)): - raise TypeError('All entries in "values" must be ndarray ' - 'or None, got %s' % (type(v),)) + raise TypeError( + 'All entries in "values" must be ndarray ' + "or None, got %s" % (type(v),) + ) if v is not None and v.shape[0] != len(self.control_points): - raise ValueError('Values, if provided, must be the same ' - 'length as the number of control points ' - '(%s), got %s' - % (len(self.control_points), v.shape[0])) + raise ValueError( + "Values, if provided, must be the same " + "length as the number of control points " + "(%s), got %s" % (len(self.control_points), v.shape[0]) + ) use_values = values def val(pt): idx = np.where(control_points == pt)[0][0] return [v[idx] if v is not None else None for v in use_values] + values = val self.values = values self.n_last = None self._position = 0 # start at zero self._left_idx = 0 self._left = self._right = self._use_interp = None - known_types = ('cos2', 'linear', 'zero', 'hann') + known_types = 
("cos2", "linear", "zero", "hann") if interp not in known_types: - raise ValueError('interp must be one of %s, got "%s"' - % (known_types, interp)) + raise ValueError( + 'interp must be one of %s, got "%s"' % (known_types, interp) + ) self._interp = interp def feed_generator(self, n_pts): """Feed data and get interpolators as a generator.""" self.n_last = 0 - n_pts = _ensure_int(n_pts, 'n_pts') + n_pts = _ensure_int(n_pts, "n_pts") original_position = self._position stop = self._position + n_pts - logger.debug('Feed %s (%s-%s)' % (n_pts, self._position, stop)) + logger.debug("Feed %s (%s-%s)" % (n_pts, self._position, stop)) used = np.zeros(n_pts, bool) if self._left is None: # first one - logger.debug(' Eval @ %s (%s)' % (0, self.control_points[0])) + logger.debug(" Eval @ %s (%s)" % (0, self.control_points[0])) self._left = self.values(self.control_points[0]) if len(self.control_points) == 1: self._right = self._left @@ -98,9 +105,8 @@ def feed_generator(self, n_pts): # Left zero-order hold condition if self._position < self.control_points[self._left_idx]: - n_use = min(self.control_points[self._left_idx] - self._position, - n_pts) - logger.debug(' Left ZOH %s' % n_use) + n_use = min(self.control_points[self._left_idx] - self._position, n_pts) + logger.debug(" Left ZOH %s" % n_use) this_sl = slice(None, n_use) assert used[this_sl].size == n_use assert not used[this_sl].any() @@ -125,35 +131,36 @@ def feed_generator(self, n_pts): self._left_idx += 1 self._use_interp = None # need to recreate it eval_pt = self.control_points[self._left_idx + 1] - logger.debug(' Eval @ %s (%s)' - % (self._left_idx + 1, eval_pt)) + logger.debug(" Eval @ %s (%s)" % (self._left_idx + 1, eval_pt)) self._right = self.values(eval_pt) assert self._right is not None left_point = self.control_points[self._left_idx] right_point = self.control_points[self._left_idx + 1] if self._use_interp is None: interp_span = right_point - left_point - if self._interp == 'zero': + if self._interp == "zero": self._use_interp = None - elif self._interp == 'linear': - self._use_interp = np.linspace(1., 0., interp_span, - endpoint=False) + elif self._interp == "linear": + self._use_interp = np.linspace( + 1.0, 0.0, interp_span, endpoint=False + ) else: # self._interp in ('cos2', 'hann'): self._use_interp = np.cos( - np.linspace(0, np.pi / 2., interp_span, - endpoint=False)) + np.linspace(0, np.pi / 2.0, interp_span, endpoint=False) + ) self._use_interp *= self._use_interp n_use = min(stop, right_point) - self._position if n_use > 0: - logger.debug(' Interp %s %s (%s-%s)' % (self._interp, n_use, - left_point, right_point)) + logger.debug( + " Interp %s %s (%s-%s)" + % (self._interp, n_use, left_point, right_point) + ) interp_start = self._position - left_point assert interp_start >= 0 if self._use_interp is None: this_interp = None else: - this_interp = \ - self._use_interp[interp_start:interp_start + n_use] + this_interp = self._use_interp[interp_start : interp_start + n_use] assert this_interp.size == n_use this_sl = slice(n_used, n_used + n_use) assert used[this_sl].size == n_use @@ -167,7 +174,7 @@ def feed_generator(self, n_pts): if self.control_points[self._left_idx] <= self._position: n_use = stop - self._position if n_use > 0: - logger.debug(' Right ZOH %s' % n_use) + logger.debug(" Right ZOH %s" % n_use) this_sl = slice(n_pts - n_use, None) assert not used[this_sl].any() used[this_sl] = True @@ -187,16 +194,18 @@ def feed(self, n_pts): out_arrays = None for o in self.feed_generator(n_pts): if out_arrays is None: - out_arrays = 
[np.empty(v.shape + (n_pts,)) - if v is not None else None for v in o[1]] + out_arrays = [ + np.empty(v.shape + (n_pts,)) if v is not None else None + for v in o[1] + ] for ai, arr in enumerate(out_arrays): if arr is not None: if o[3] is None: arr[..., o[0]] = o[1][ai][..., np.newaxis] else: - arr[..., o[0]] = ( - o[1][ai][..., np.newaxis] * o[3] + - o[2][ai][..., np.newaxis] * (1. - o[3])) + arr[..., o[0]] = o[1][ai][..., np.newaxis] * o[3] + o[2][ai][ + ..., np.newaxis + ] * (1.0 - o[3]) assert out_arrays is not None return out_arrays @@ -208,12 +217,12 @@ def feed(self, n_pts): def _check_store(store): if isinstance(store, np.ndarray): store = [store] - if isinstance(store, (list, tuple)) and all(isinstance(s, np.ndarray) - for s in store): + if isinstance(store, (list, tuple)) and all( + isinstance(s, np.ndarray) for s in store + ): store = _Storer(*store) if not callable(store): - raise TypeError('store must be callable, got type %s' - % (type(store),)) + raise TypeError("store must be callable, got type %s" % (type(store),)) return store @@ -261,28 +270,40 @@ class _COLA: """ @verbose - def __init__(self, process, store, n_total, n_samples, n_overlap, - sfreq, window='hann', tol=1e-10, *, verbose=None): + def __init__( + self, + process, + store, + n_total, + n_samples, + n_overlap, + sfreq, + window="hann", + tol=1e-10, + *, + verbose=None + ): from scipy.signal import get_window - n_samples = _ensure_int(n_samples, 'n_samples') - n_overlap = _ensure_int(n_overlap, 'n_overlap') - n_total = _ensure_int(n_total, 'n_total') + + n_samples = _ensure_int(n_samples, "n_samples") + n_overlap = _ensure_int(n_overlap, "n_overlap") + n_total = _ensure_int(n_total, "n_total") if n_samples <= 0: - raise ValueError('n_samples must be > 0, got %s' % (n_samples,)) + raise ValueError("n_samples must be > 0, got %s" % (n_samples,)) if n_overlap < 0: - raise ValueError('n_overlap must be >= 0, got %s' % (n_overlap,)) + raise ValueError("n_overlap must be >= 0, got %s" % (n_overlap,)) if n_total < 0: - raise ValueError('n_total must be >= 0, got %s' % (n_total,)) + raise ValueError("n_total must be >= 0, got %s" % (n_total,)) self._n_samples = int(n_samples) self._n_overlap = int(n_overlap) del n_samples, n_overlap if n_total < self._n_samples: - raise ValueError('Number of samples per window (%d) must be at ' - 'most the total number of samples (%s)' - % (self._n_samples, n_total)) + raise ValueError( + "Number of samples per window (%d) must be at " + "most the total number of samples (%s)" % (self._n_samples, n_total) + ) if not callable(process): - raise TypeError('process must be callable, got type %s' - % (type(process),)) + raise TypeError("process must be callable, got type %s" % (type(process),)) self._process = process self._step = self._n_samples - self._n_overlap self._store = _check_store(store) @@ -290,25 +311,36 @@ def __init__(self, process, store, n_total, n_samples, n_overlap, self._in_buffers = self._out_buffers = None # Create our window boundaries - window_name = window if isinstance(window, str) else 'custom' - self._window = get_window(window, self._n_samples, - fftbins=(self._n_samples - 1) % 2) - self._window /= _check_cola(self._window, self._n_samples, self._step, - window_name, tol=tol) + window_name = window if isinstance(window, str) else "custom" + self._window = get_window( + window, self._n_samples, fftbins=(self._n_samples - 1) % 2 + ) + self._window /= _check_cola( + self._window, self._n_samples, self._step, window_name, tol=tol + ) self.starts = np.arange(0, n_total 
- self._n_samples + 1, self._step) self.stops = self.starts + self._n_samples delta = n_total - self.stops[-1] self.stops[-1] = n_total sfreq = float(sfreq) - pl = 's' if len(self.starts) != 1 else '' - logger.info(' Processing %4d data chunk%s of (at least) %0.1f s ' - 'with %0.1f s overlap and %s windowing' - % (len(self.starts), pl, self._n_samples / sfreq, - self._n_overlap / sfreq, window_name)) + pl = "s" if len(self.starts) != 1 else "" + logger.info( + " Processing %4d data chunk%s of (at least) %0.1f s " + "with %0.1f s overlap and %s windowing" + % ( + len(self.starts), + pl, + self._n_samples / sfreq, + self._n_overlap / sfreq, + window_name, + ) + ) del window, window_name if delta > 0: - logger.info(' The final %0.3f s will be lumped into the ' - 'final window' % (delta / sfreq,)) + logger.info( + " The final %0.3f s will be lumped into the " + "final window" % (delta / sfreq,) + ) @property def _in_offset(self): @@ -322,65 +354,79 @@ def feed(self, *datas, verbose=None, **kwargs): if self._in_buffers is None: self._in_buffers = [None] * len(datas) if len(datas) != len(self._in_buffers): - raise ValueError('Got %d array(s), needed %d' - % (len(datas), len(self._in_buffers))) + raise ValueError( + "Got %d array(s), needed %d" % (len(datas), len(self._in_buffers)) + ) for di, data in enumerate(datas): if not isinstance(data, np.ndarray) or data.ndim < 1: - raise TypeError('data entry %d must be an 2D ndarray, got %s' - % (di, type(data),)) + raise TypeError( + "data entry %d must be an 2D ndarray, got %s" + % ( + di, + type(data), + ) + ) if self._in_buffers[di] is None: # In practice, users can give large chunks, so we use # dynamic allocation of the in buffer. We could save some # memory allocation by only ever processing max_len at once, # but this would increase code complexity. 
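                # _in_buffers[di] grows along the last axis as chunks arrive;
                # the while-loop below consumes one window each time
                # _in_offset reaches the next precomputed stop sample.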
- self._in_buffers[di] = np.empty( - data.shape[:-1] + (0,), data.dtype) - if data.shape[:-1] != self._in_buffers[di].shape[:-1] or \ - self._in_buffers[di].dtype != data.dtype: - raise TypeError('data must dtype %s and shape[:-1]==%s, ' - 'got dtype %s shape[:-1]=%s' - % (self._in_buffers[di].dtype, - self._in_buffers[di].shape[:-1], - data.dtype, data.shape[:-1])) - logger.debug(' + Appending %d->%d' - % (self._in_offset, self._in_offset + data.shape[-1])) - self._in_buffers[di] = np.concatenate( - [self._in_buffers[di], data], -1) + self._in_buffers[di] = np.empty(data.shape[:-1] + (0,), data.dtype) + if ( + data.shape[:-1] != self._in_buffers[di].shape[:-1] + or self._in_buffers[di].dtype != data.dtype + ): + raise TypeError( + "data must dtype %s and shape[:-1]==%s, " + "got dtype %s shape[:-1]=%s" + % ( + self._in_buffers[di].dtype, + self._in_buffers[di].shape[:-1], + data.dtype, + data.shape[:-1], + ) + ) + logger.debug( + " + Appending %d->%d" + % (self._in_offset, self._in_offset + data.shape[-1]) + ) + self._in_buffers[di] = np.concatenate([self._in_buffers[di], data], -1) if self._in_offset > self.stops[-1]: - raise ValueError('data (shape %s) exceeded expected total ' - 'buffer size (%s > %s)' - % (data.shape, self._in_offset, - self.stops[-1])) + raise ValueError( + "data (shape %s) exceeded expected total " + "buffer size (%s > %s)" + % (data.shape, self._in_offset, self.stops[-1]) + ) # Check to see if we can process the next chunk and dump outputs - while self._idx < len(self.starts) and \ - self._in_offset >= self.stops[self._idx]: + while self._idx < len(self.starts) and self._in_offset >= self.stops[self._idx]: start, stop = self.starts[self._idx], self.stops[self._idx] this_len = stop - start this_window = self._window.copy() if self._idx == len(self.starts) - 1: this_window = np.pad( - self._window, (0, this_len - len(this_window)), 'constant') + self._window, (0, this_len - len(this_window)), "constant" + ) for offset in range(self._step, len(this_window), self._step): n_use = len(this_window) - offset this_window[offset:] += self._window[:n_use] if self._idx == 0: - for offset in range(self._n_samples - self._step, 0, - -self._step): + for offset in range(self._n_samples - self._step, 0, -self._step): this_window[:offset] += self._window[-offset:] - logger.debug(' * Processing %d->%d' % (start, stop)) - this_proc = [in_[..., :this_len].copy() - for in_ in self._in_buffers] - if not all(proc.shape[-1] == this_len == this_window.size - for proc in this_proc): - raise RuntimeError('internal indexing error') + logger.debug(" * Processing %d->%d" % (start, stop)) + this_proc = [in_[..., :this_len].copy() for in_ in self._in_buffers] + if not all( + proc.shape[-1] == this_len == this_window.size for proc in this_proc + ): + raise RuntimeError("internal indexing error") outs = self._process(*this_proc, **kwargs) if self._out_buffers is None: max_len = np.max(self.stops - self.starts) - self._out_buffers = [np.zeros(o.shape[:-1] + (max_len,), - o.dtype) for o in outs] + self._out_buffers = [ + np.zeros(o.shape[:-1] + (max_len,), o.dtype) for o in outs + ] for oi, out in enumerate(outs): out *= this_window - self._out_buffers[oi][..., :stop - start] += out + self._out_buffers[oi][..., : stop - start] += out self._idx += 1 if self._idx < len(self.starts): next_start = self.starts[self._idx] @@ -389,29 +435,29 @@ def feed(self, *datas, verbose=None, **kwargs): delta = next_start - self.starts[self._idx - 1] for di in range(len(self._in_buffers)): self._in_buffers[di] = 
self._in_buffers[di][..., delta:] - logger.debug(' - Shifting input/output buffers by %d samples' - % (delta,)) + logger.debug(" - Shifting input/output buffers by %d samples" % (delta,)) self._store(*[o[..., :delta] for o in self._out_buffers]) for ob in self._out_buffers: ob[..., :-delta] = ob[..., delta:] - ob[..., -delta:] = 0. + ob[..., -delta:] = 0.0 def _check_cola(win, nperseg, step, window_name, tol=1e-10): """Check whether the Constant OverLap Add (COLA) constraint is met.""" # adapted from SciPy - binsums = np.sum([win[ii * step:(ii + 1) * step] - for ii in range(nperseg // step)], axis=0) + binsums = np.sum( + [win[ii * step : (ii + 1) * step] for ii in range(nperseg // step)], axis=0 + ) if nperseg % step != 0: - binsums[:nperseg % step] += win[-(nperseg % step):] + binsums[: nperseg % step] += win[-(nperseg % step) :] const = np.median(binsums) deviation = np.max(np.abs(binsums - const)) if deviation > tol: - raise ValueError('segment length %d with step %d for %s window ' - 'type does not provide a constant output ' - '(%g%% deviation)' - % (nperseg, step, window_name, - 100 * deviation / const)) + raise ValueError( + "segment length %d with step %d for %s window " + "type does not provide a constant output " + "(%g%% deviation)" % (nperseg, step, window_name, 100 * deviation / const) + ) return const @@ -421,16 +467,16 @@ class _Storer: def __init__(self, *outs, picks=None): for oi, out in enumerate(outs): if not isinstance(out, np.ndarray) or out.ndim < 1: - raise TypeError('outs[oi] must be >= 1D ndarray, got %s' - % (out,)) + raise TypeError("outs[oi] must be >= 1D ndarray, got %s" % (out,)) self.outs = outs self.idx = 0 self.picks = picks def __call__(self, *outs): - if (len(outs) != len(self.outs) or - not all(out.shape[-1] == outs[0].shape[-1] for out in outs)): - raise ValueError('Bad outs') + if len(outs) != len(self.outs) or not all( + out.shape[-1] == outs[0].shape[-1] for out in outs + ): + raise ValueError("Bad outs") idx = (Ellipsis,) if self.picks is not None: idx += (self.picks,) diff --git a/mne/annotations.py b/mne/annotations.py index 00de96d32e4..9df381bfc87 100644 --- a/mne/annotations.py +++ b/mne/annotations.py @@ -16,15 +16,38 @@ from textwrap import shorten import numpy as np -from .utils import (_pl, check_fname, _validate_type, verbose, warn, logger, - _check_pandas_installed, _mask_to_onsets_offsets, - _DefaultEventParser, _check_dt, _stamp_to_dt, _dt_to_stamp, - _check_fname, int_like, _check_option, fill_doc, - _on_missing, _is_numeric, _check_dict_keys) - -from .io.write import (start_block, end_block, write_float, - write_name_list_sanitized, _safe_name_list, - write_double, start_file, write_string) +from .utils import ( + _pl, + check_fname, + _validate_type, + verbose, + warn, + logger, + _check_pandas_installed, + _mask_to_onsets_offsets, + _DefaultEventParser, + _check_dt, + _stamp_to_dt, + _dt_to_stamp, + _check_fname, + int_like, + _check_option, + fill_doc, + _on_missing, + _is_numeric, + _check_dict_keys, +) + +from .io.write import ( + start_block, + end_block, + write_float, + write_name_list_sanitized, + _safe_name_list, + write_double, + start_file, + write_string, +) from .io.constants import FIFF from .io.open import fiff_open from .io.tree import dir_tree_find @@ -38,41 +61,46 @@ def _check_o_d_s_c(onset, duration, description, ch_names): onset = np.atleast_1d(np.array(onset, dtype=float)) if onset.ndim != 1: - raise ValueError('Onset must be a one dimensional array, got %s ' - '(shape %s).' 
- % (onset.ndim, onset.shape)) + raise ValueError( + "Onset must be a one dimensional array, got %s " + "(shape %s)." % (onset.ndim, onset.shape) + ) duration = np.array(duration, dtype=float) if duration.ndim == 0 or duration.shape == (1,): duration = np.repeat(duration, len(onset)) if duration.ndim != 1: - raise ValueError('Duration must be a one dimensional array, ' - 'got %d.' % (duration.ndim,)) + raise ValueError( + "Duration must be a one dimensional array, " "got %d." % (duration.ndim,) + ) description = np.array(description, dtype=str) if description.ndim == 0 or description.shape == (1,): description = np.repeat(description, len(onset)) if description.ndim != 1: - raise ValueError('Description must be a one dimensional array, ' - 'got %d.' % (description.ndim,)) - _safe_name_list(description, 'write', 'description') + raise ValueError( + "Description must be a one dimensional array, " + "got %d." % (description.ndim,) + ) + _safe_name_list(description, "write", "description") # ch_names: convert to ndarray of tuples - _validate_type(ch_names, (None, tuple, list, np.ndarray), 'ch_names') + _validate_type(ch_names, (None, tuple, list, np.ndarray), "ch_names") if ch_names is None: ch_names = [()] * len(onset) ch_names = list(ch_names) for ai, ch in enumerate(ch_names): - _validate_type(ch, (list, tuple, np.ndarray), f'ch_names[{ai}]') + _validate_type(ch, (list, tuple, np.ndarray), f"ch_names[{ai}]") ch_names[ai] = tuple(ch) for ci, name in enumerate(ch_names[ai]): - _validate_type(name, str, f'ch_names[{ai}][{ci}]') + _validate_type(name, str, f"ch_names[{ai}][{ci}]") ch_names = _ndarray_ch_names(ch_names) if not (len(onset) == len(duration) == len(description) == len(ch_names)): raise ValueError( - 'Onset, duration, description, and ch_names must be ' - f'equal in sizes, got {len(onset)}, {len(duration)}, ' - f'{len(description)}, and {len(ch_names)}.') + "Onset, duration, description, and ch_names must be " + f"equal in sizes, got {len(onset)}, {len(duration)}, " + f"{len(description)}, and {len(ch_names)}." + ) return onset, duration, description, ch_names @@ -247,11 +275,13 @@ class Annotations: :meth:`Raw.save() ` notes for details. 
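    Examples
    --------
    A minimal sketch (onsets, durations, and descriptions are illustrative)::

        annot = Annotations(onset=[1.0, 5.0], duration=[0.5, 0.5],
                            description=["BAD_blink", "stimulus"])
        print(annot)  # roughly: <Annotations | 2 segments: BAD_blink (1), stimulus (1)>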
""" # noqa: E501 - def __init__(self, onset, duration, description, - orig_time=None, ch_names=None): # noqa: D102 + def __init__( + self, onset, duration, description, orig_time=None, ch_names=None + ): # noqa: D102 self._orig_time = _handle_meas_date(orig_time) - self.onset, self.duration, self.description, self.ch_names = \ - _check_o_d_s_c(onset, duration, description, ch_names) + self.onset, self.duration, self.description, self.ch_names = _check_o_d_s_c( + onset, duration, description, ch_names + ) self._sort() # ensure we're sorted @property @@ -263,21 +293,27 @@ def __eq__(self, other): """Compare to another Annotations instance.""" if not isinstance(other, Annotations): return False - return (np.array_equal(self.onset, other.onset) and - np.array_equal(self.duration, other.duration) and - np.array_equal(self.description, other.description) and - np.array_equal(self.ch_names, other.ch_names) and - self.orig_time == other.orig_time) + return ( + np.array_equal(self.onset, other.onset) + and np.array_equal(self.duration, other.duration) + and np.array_equal(self.description, other.description) + and np.array_equal(self.ch_names, other.ch_names) + and self.orig_time == other.orig_time + ) def __repr__(self): """Show the representation.""" counter = Counter(self.description) - kinds = ', '.join(['%s (%s)' % k for k in sorted(counter.items())]) - kinds = (': ' if len(kinds) > 0 else '') + kinds - ch_specific = ', channel-specific' if self._any_ch_names() else '' - s = ('Annotations | %s segment%s%s%s' % - (len(self.onset), _pl(len(self.onset)), ch_specific, kinds)) - return '<' + shorten(s, width=77, placeholder=' ...') + '>' + kinds = ", ".join(["%s (%s)" % k for k in sorted(counter.items())]) + kinds = (": " if len(kinds) > 0 else "") + kinds + ch_specific = ", channel-specific" if self._any_ch_names() else "" + s = "Annotations | %s segment%s%s%s" % ( + len(self.onset), + _pl(len(self.onset)), + ch_specific, + kinds, + ) + return "<" + shorten(s, width=77, placeholder=" ...") + ">" def __len__(self): """Return the number of annotations. 
@@ -303,12 +339,14 @@ def __iadd__(self, other): if len(self) == 0: self._orig_time = other.orig_time if self.orig_time != other.orig_time: - raise ValueError("orig_time should be the same to " - "add/concatenate 2 annotations " - "(got %s != %s)" % (self.orig_time, - other.orig_time)) - return self.append(other.onset, other.duration, other.description, - other.ch_names) + raise ValueError( + "orig_time should be the same to " + "add/concatenate 2 annotations " + "(got %s != %s)" % (self.orig_time, other.orig_time) + ) + return self.append( + other.onset, other.duration, other.description, other.ch_names + ) def __iter__(self): """Iterate over the annotations.""" @@ -321,21 +359,26 @@ def __iter__(self): def __getitem__(self, key, *, with_ch_names=None): """Propagate indexing and slicing to the underlying numpy structure.""" if isinstance(key, int_like): - out_keys = ('onset', 'duration', 'description', 'orig_time') - out_vals = (self.onset[key], self.duration[key], - self.description[key], self.orig_time) - if with_ch_names or (with_ch_names is None and - self._any_ch_names()): - out_keys += ('ch_names',) + out_keys = ("onset", "duration", "description", "orig_time") + out_vals = ( + self.onset[key], + self.duration[key], + self.description[key], + self.orig_time, + ) + if with_ch_names or (with_ch_names is None and self._any_ch_names()): + out_keys += ("ch_names",) out_vals += (self.ch_names[key],) return OrderedDict(zip(out_keys, out_vals)) else: key = list(key) if isinstance(key, tuple) else key - return Annotations(onset=self.onset[key], - duration=self.duration[key], - description=self.description[key], - orig_time=self.orig_time, - ch_names=self.ch_names[key]) + return Annotations( + onset=self.onset[key], + duration=self.duration[key], + description=self.description[key], + orig_time=self.orig_time, + ch_names=self.ch_names[key], + ) @fill_doc def append(self, onset, duration, description, ch_names=None): @@ -367,7 +410,8 @@ def append(self, onset, duration, description, ch_names=None): `list.extend `__. 
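        Examples
        --------
        A minimal sketch (onsets and durations are illustrative)::

            annot = Annotations(onset=[1.0], duration=[0.5],
                                description=["BAD_seg"])
            annot.append(onset=[2.0, 3.0], duration=[0.5, 0.5],
                         description=["BAD_seg", "BAD_seg"])
            len(annot)  # 3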
""" # noqa: E501 onset, duration, description, ch_names = _check_o_d_s_c( - onset, duration, description, ch_names) + onset, duration, description, ch_names + ) self.onset = np.append(self.onset, onset) self.duration = np.append(self.duration, duration) self.description = np.append(self.description, description) @@ -415,8 +459,7 @@ def to_data_frame(self): dt = _handle_meas_date(0) dt = dt.replace(tzinfo=None) onsets_dt = [dt + timedelta(seconds=o) for o in self.onset] - df = dict(onset=onsets_dt, duration=self.duration, - description=self.description) + df = dict(onset=onsets_dt, duration=self.duration, description=self.description) if self._any_ch_names(): df.update(ch_names=self.ch_names) df = pd.DataFrame(df) @@ -428,7 +471,7 @@ def _any_ch_names(self): def _prune_ch_names(self, info, on_missing): # this prunes channel names and if a given channel-specific annotation # no longer has any channels left, it gets dropped - keep = set(info['ch_names']) + keep = set(info["ch_names"]) ch_names = self.ch_names warned = False drop_idx = list() @@ -439,8 +482,10 @@ def _prune_ch_names(self, info, on_missing): if name not in keep: if not warned: _on_missing( - on_missing, 'At least one channel name in ' - f'annotations missing from info: {name}') + on_missing, + "At least one channel name in " + f"annotations missing from info: {name}", + ) warned = True else: names.append(name) @@ -477,9 +522,18 @@ def save(self, fname, *, overwrite=False, verbose=None): whereas :file:`.txt` files store onset as seconds since start of the recording (e.g., ``45.95597082905339``). """ - check_fname(fname, 'annotations', ('-annot.fif', '-annot.fif.gz', - '_annot.fif', '_annot.fif.gz', - '.txt', '.csv')) + check_fname( + fname, + "annotations", + ( + "-annot.fif", + "-annot.fif.gz", + "_annot.fif", + "_annot.fif.gz", + ".txt", + ".csv", + ), + ) fname = _check_fname(fname, overwrite=overwrite) if fname.suffix == ".txt": _write_annotations_txt(fname, self) @@ -501,8 +555,9 @@ def _sort(self): self.ch_names = self.ch_names[order] @verbose - def crop(self, tmin=None, tmax=None, emit_warning=False, - use_orig_time=True, verbose=None): + def crop( + self, tmin=None, tmax=None, emit_warning=False, use_orig_time=True, verbose=None + ): """Remove all annotation that are outside of [tmin, tmax]. The method operates inplace. @@ -535,39 +590,42 @@ def crop(self, tmin=None, tmax=None, emit_warning=False, if tmin is None: tmin = timedelta(seconds=self.onset.min()) + offset if tmax is None: - tmax = timedelta( - seconds=(self.onset + self.duration).max()) + offset - for key, val in [('tmin', tmin), ('tmax', tmax)]: - _validate_type(val, ('numeric', _datetime), key, - 'numeric, datetime, or None') + tmax = timedelta(seconds=(self.onset + self.duration).max()) + offset + for key, val in [("tmin", tmin), ("tmax", tmax)]: + _validate_type( + val, ("numeric", _datetime), key, "numeric, datetime, or None" + ) absolute_tmin = _handle_meas_date(tmin) absolute_tmax = _handle_meas_date(tmax) del tmin, tmax if absolute_tmin > absolute_tmax: - raise ValueError('tmax should be greater than or equal to tmin ' - '(%s < %s).' % (absolute_tmin, absolute_tmax)) - logger.debug('Cropping annotations %s - %s' % (absolute_tmin, - absolute_tmax)) + raise ValueError( + "tmax should be greater than or equal to tmin " + "(%s < %s)." 
% (absolute_tmin, absolute_tmax) + ) + logger.debug("Cropping annotations %s - %s" % (absolute_tmin, absolute_tmax)) onsets, durations, descriptions, ch_names = [], [], [], [] out_of_bounds, clip_left_elem, clip_right_elem = [], [], [] - for idx, (onset, duration, description, ch) in enumerate(zip( - self.onset, self.duration, self.description, self.ch_names)): + for idx, (onset, duration, description, ch) in enumerate( + zip(self.onset, self.duration, self.description, self.ch_names) + ): # if duration is NaN behave like a zero if np.isnan(duration): - duration = 0. + duration = 0.0 # convert to absolute times absolute_onset = timedelta(seconds=onset) + offset absolute_offset = absolute_onset + timedelta(seconds=duration) out_of_bounds.append( - absolute_onset > absolute_tmax or - absolute_offset < absolute_tmin) + absolute_onset > absolute_tmax or absolute_offset < absolute_tmin + ) if out_of_bounds[-1]: clip_left_elem.append(False) clip_right_elem.append(False) logger.debug( - f' [{idx}] Dropping ' - f'({absolute_onset} - {absolute_offset}: {description})') + f" [{idx}] Dropping " + f"({absolute_onset} - {absolute_offset}: {description})" + ) else: # clip the left side clip_left_elem.append(absolute_onset < absolute_tmin) @@ -577,19 +635,18 @@ def crop(self, tmin=None, tmax=None, emit_warning=False, if clip_right_elem[-1]: absolute_offset = absolute_tmax if clip_left_elem[-1] or clip_right_elem[-1]: - durations.append( - (absolute_offset - absolute_onset).total_seconds()) + durations.append((absolute_offset - absolute_onset).total_seconds()) else: durations.append(duration) - onsets.append( - (absolute_onset - offset).total_seconds()) + onsets.append((absolute_onset - offset).total_seconds()) logger.debug( - f' [{idx}] Keeping ' - f'({absolute_onset} - {absolute_offset} -> ' - f'{onset} - {onset + duration})') + f" [{idx}] Keeping " + f"({absolute_onset} - {absolute_offset} -> " + f"{onset} - {onset + duration})" + ) descriptions.append(description) ch_names.append(ch) - logger.debug(f'Cropping complete (kept {len(onsets)})') + logger.debug(f"Cropping complete (kept {len(onsets)})") self.onset = np.array(onsets, float) self.duration = np.array(durations, float) assert (self.duration >= 0).all() @@ -599,13 +656,16 @@ def crop(self, tmin=None, tmax=None, emit_warning=False, if emit_warning: omitted = np.array(out_of_bounds).sum() if omitted > 0: - warn('Omitted %s annotation(s) that were outside data' - ' range.' % omitted) - limited = (np.array(clip_left_elem) | - np.array(clip_right_elem)).sum() + warn( + "Omitted %s annotation(s) that were outside data" + " range." % omitted + ) + limited = (np.array(clip_left_elem) | np.array(clip_right_elem)).sum() if limited > 0: - warn('Limited %s annotation(s) that were expanding outside the' - ' data range.' % limited) + warn( + "Limited %s annotation(s) that were expanding outside the" + " data range." 
% limited + ) return self @@ -634,9 +694,12 @@ def set_durations(self, mapping, verbose=None): _validate_type(mapping, (int, float, dict)) if isinstance(mapping, dict): - _check_dict_keys(mapping, self.description, - valid_key_source="data", - key_description="Annotation description(s)") + _check_dict_keys( + mapping, + self.description, + valid_key_source="data", + key_description="Annotation description(s)", + ) for stim in mapping: map_idx = [desc == stim for desc in self.description] self.duration[map_idx] = mapping[stim] @@ -645,9 +708,11 @@ def set_durations(self, mapping, verbose=None): self.duration = np.ones(self.description.shape) * mapping else: - raise ValueError("Setting durations requires the mapping of " - "descriptions to times to be provided as a dict. " - f"Instead {type(mapping)} was provided.") + raise ValueError( + "Setting durations requires the mapping of " + "descriptions to times to be provided as a dict. " + f"Instead {type(mapping)} was provided." + ) return self @@ -672,10 +737,13 @@ def rename(self, mapping, verbose=None): .. versionadded:: 0.24.0 """ _validate_type(mapping, dict) - _check_dict_keys(mapping, self.description, valid_key_source="data", - key_description="Annotation description(s)") - self.description = np.array( - [str(mapping.get(d, d)) for d in self.description]) + _check_dict_keys( + mapping, + self.description, + valid_key_source="data", + key_description="Annotation description(s)", + ) + self.description = np.array([str(mapping.get(d, d)) for d in self.description]) return self @@ -687,8 +755,7 @@ def annotations(self): # noqa: D102 return self._annotations @verbose - def set_annotations(self, annotations, on_missing='raise', *, - verbose=None): + def set_annotations(self, annotations, on_missing="raise", *, verbose=None): """Setter for Epoch annotations from Raw. This method does not handle offsetting the times based @@ -728,16 +795,18 @@ def set_annotations(self, annotations, on_missing='raise', *, .. versionadded:: 1.0 """ - _validate_type(annotations, (Annotations, None), 'annotations') + _validate_type(annotations, (Annotations, None), "annotations") if annotations is None: self._annotations = None else: - if getattr(self, '_unsafe_annot_add', False): - warn('Adding annotations to Epochs created (and saved to ' - 'disk) before 1.0 will yield incorrect results if ' - 'decimation or resampling was performed on the instance, ' - 'we recommend regenerating the Epochs and re-saving them ' - 'to disk') + if getattr(self, "_unsafe_annot_add", False): + warn( + "Adding annotations to Epochs created (and saved to " + "disk) before 1.0 will yield incorrect results if " + "decimation or resampling was performed on the instance, " + "we recommend regenerating the Epochs and re-saving them " + "to disk" + ) new_annotations = annotations.copy() new_annotations._prune_ch_names(self.info, on_missing) self._annotations = new_annotations @@ -766,8 +835,9 @@ def get_annotations_per_epoch(self): # when each epoch and annotation starts/stops # no need to account for first_samp here... epoch_tzeros = self.events[:, 0] / self._raw_sfreq - epoch_starts, epoch_stops = np.atleast_2d( - epoch_tzeros) + np.atleast_2d(self.times[[0, -1]]).T + epoch_starts, epoch_stops = ( + np.atleast_2d(epoch_tzeros) + np.atleast_2d(self.times[[0, -1]]).T + ) # ... 
because first_samp isn't accounted for here either annot_starts = self._annotations.onset annot_stops = annot_starts + self._annotations.duration @@ -779,33 +849,40 @@ def get_annotations_per_epoch(self): # we care about is presence/absence of overlap). annot_straddles_epoch_start = np.logical_and( np.atleast_2d(epoch_starts) >= np.atleast_2d(annot_starts).T, - np.atleast_2d(epoch_starts) < np.atleast_2d(annot_stops).T) + np.atleast_2d(epoch_starts) < np.atleast_2d(annot_stops).T, + ) annot_straddles_epoch_end = np.logical_and( np.atleast_2d(epoch_stops) > np.atleast_2d(annot_starts).T, - np.atleast_2d(epoch_stops) <= np.atleast_2d(annot_stops).T) + np.atleast_2d(epoch_stops) <= np.atleast_2d(annot_stops).T, + ) # this captures the only remaining case we care about: annotations # fully contained within an epoch (or exactly coextensive with it). annot_fully_within_epoch = np.logical_and( np.atleast_2d(epoch_starts) <= np.atleast_2d(annot_starts).T, - np.atleast_2d(epoch_stops) >= np.atleast_2d(annot_stops).T) + np.atleast_2d(epoch_stops) >= np.atleast_2d(annot_stops).T, + ) # combine all cases to get array of shape (n_annotations, n_epochs). # Nonzero entries indicate overlap between the corresponding # annotation (row index) and epoch (column index). - all_cases = (annot_straddles_epoch_start + - annot_straddles_epoch_end + - annot_fully_within_epoch) + all_cases = ( + annot_straddles_epoch_start + + annot_straddles_epoch_end + + annot_fully_within_epoch + ) # for each Epoch-Annotation overlap occurrence: for annot_ix, epo_ix in zip(*np.nonzero(all_cases)): this_annot = self._annotations[annot_ix] this_tzero = epoch_tzeros[epo_ix] # adjust annotation onset to be relative to epoch tzero... - annot = (this_annot['onset'] - this_tzero, - this_annot['duration'], - this_annot['description']) + annot = ( + this_annot["onset"] - this_tzero, + this_annot["duration"], + this_annot["description"], + ) # ...then add it to the correct sublist of `epoch_annot_list` epoch_annot_list[epo_ix].append(annot) return epoch_annot_list @@ -841,8 +918,10 @@ def add_annotations_to_metadata(self, overwrite=False): # check if annotations exist if self.annotations is None: - warn(f'There were no Annotations stored in {self}, so ' - 'metadata was not modified.') + warn( + f"There were no Annotations stored in {self}, so " + "metadata was not modified." + ) return self # get existing metadata DataFrame or instantiate an empty one @@ -852,12 +931,17 @@ def add_annotations_to_metadata(self, overwrite=False): data = np.empty((len(self.events), 0)) metadata = pd.DataFrame(data=data) - if any(name in metadata.columns for name in - ['annot_onset', 'annot_duration', 'annot_description']) and \ - not overwrite: + if ( + any( + name in metadata.columns + for name in ["annot_onset", "annot_duration", "annot_description"] + ) + and not overwrite + ): raise RuntimeError( - 'Metadata for Epochs already contains columns ' - '"annot_onset", "annot_duration", or "annot_description".') + "Metadata for Epochs already contains columns " + '"annot_onset", "annot_duration", or "annot_description".' + ) # get the Epoch annotations, then convert to separate lists for # onsets, durations, and descriptions @@ -875,17 +959,18 @@ def add_annotations_to_metadata(self, overwrite=False): # Create a new Annotations column that is instantiated as an empty # list per Epoch. 
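        # (three metadata columns are added below: annot_onset,
        # annot_duration, and annot_description, each holding one list per
        # epoch)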
- metadata['annot_onset'] = pd.Series(onset) - metadata['annot_duration'] = pd.Series(duration) - metadata['annot_description'] = pd.Series(description) + metadata["annot_onset"] = pd.Series(onset) + metadata["annot_duration"] = pd.Series(duration) + metadata["annot_description"] = pd.Series(description) # reset the metadata self.metadata = metadata return self -def _combine_annotations(one, two, one_n_samples, one_first_samp, - two_first_samp, sfreq): +def _combine_annotations( + one, two, one_n_samples, one_first_samp, two_first_samp, sfreq +): """Combine a tuple of annotations.""" assert one is not None assert two is not None @@ -909,7 +994,7 @@ def _handle_meas_date(meas_date): time. """ if isinstance(meas_date, str): - ACCEPTED_ISO8601 = '%Y-%m-%d %H:%M:%S.%f' + ACCEPTED_ISO8601 = "%Y-%m-%d %H:%M:%S.%f" try: meas_date = datetime.strptime(meas_date, ACCEPTED_ISO8601) except ValueError: @@ -937,13 +1022,12 @@ def _handle_meas_date(meas_date): def _sync_onset(raw, onset, inverse=False): """Adjust onsets in relation to raw data.""" offset = (-1 if inverse else 1) * raw._first_time - assert raw.info['meas_date'] == raw.annotations.orig_time + assert raw.info["meas_date"] == raw.annotations.orig_time annot_start = onset - offset return annot_start -def _annotations_starts_stops(raw, kinds, name='skip_by_annotation', - invert=False): +def _annotations_starts_stops(raw, kinds, name="skip_by_annotation", invert=False): """Get starts and stops from given kinds. onsets and ends are inclusive. @@ -953,14 +1037,16 @@ def _annotations_starts_stops(raw, kinds, name='skip_by_annotation', kinds = [kinds] else: for kind in kinds: - _validate_type(kind, 'str', "All entries") + _validate_type(kind, "str", "All entries") if len(raw.annotations) == 0: onsets, ends = np.array([], int), np.array([], int) else: - idxs = [idx for idx, desc in enumerate(raw.annotations.description) - if any(desc.upper().startswith(kind.upper()) - for kind in kinds)] + idxs = [ + idx + for idx, desc in enumerate(raw.annotations.description) + if any(desc.upper().startswith(kind.upper()) for kind in kinds) + ] # onsets are already sorted onsets = raw.annotations.onset[idxs] onsets = _sync_onset(raw, onsets) @@ -975,7 +1061,7 @@ def _annotations_starts_stops(raw, kinds, name='skip_by_annotation', for onset, end in zip(onsets, ends): mask[onset:end] = True mask = ~mask - extras = (onsets == ends) + extras = onsets == ends extra_onsets, extra_ends = onsets[extras], ends[extras] onsets, ends = _mask_to_onsets_offsets(mask) # Keep ones where things were exactly equal @@ -992,25 +1078,28 @@ def _write_annotations(fid, annotations): """Write annotations.""" start_block(fid, FIFF.FIFFB_MNE_ANNOTATIONS) write_float(fid, FIFF.FIFF_MNE_BASELINE_MIN, annotations.onset) - write_float(fid, FIFF.FIFF_MNE_BASELINE_MAX, - annotations.duration + annotations.onset) + write_float( + fid, FIFF.FIFF_MNE_BASELINE_MAX, annotations.duration + annotations.onset + ) write_name_list_sanitized( - fid, FIFF.FIFF_COMMENT, annotations.description, name='description') + fid, FIFF.FIFF_COMMENT, annotations.description, name="description" + ) if annotations.orig_time is not None: - write_double(fid, FIFF.FIFF_MEAS_DATE, - _dt_to_stamp(annotations.orig_time)) + write_double(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(annotations.orig_time)) if annotations._any_ch_names(): - write_string(fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG, - json.dumps(tuple(annotations.ch_names))) + write_string( + fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG, json.dumps(tuple(annotations.ch_names)) + ) 
end_block(fid, FIFF.FIFFB_MNE_ANNOTATIONS) def _write_annotations_csv(fname, annot): annot = annot.to_data_frame() - if 'ch_names' in annot: - annot['ch_names'] = [ - _safe_name_list(ch, 'write', name=f'annot["ch_names"][{ci}') - for ci, ch in enumerate(annot['ch_names'])] + if "ch_names" in annot: + annot["ch_names"] = [ + _safe_name_list(ch, "write", name=f'annot["ch_names"][{ci}') + for ci, ch in enumerate(annot["ch_names"]) + ] annot.to_csv(fname, index=False) @@ -1022,21 +1111,24 @@ def _write_annotations_txt(fname, annot): content += "# onset, duration, description" data = [annot.onset, annot.duration, annot.description] if annot._any_ch_names(): - content += ', ch_names' - data.append([ - _safe_name_list(ch, 'write', f'annot.ch_names[{ci}]') - for ci, ch in enumerate(annot.ch_names)]) - content += '\n' + content += ", ch_names" + data.append( + [ + _safe_name_list(ch, "write", f"annot.ch_names[{ci}]") + for ci, ch in enumerate(annot.ch_names) + ] + ) + content += "\n" data = np.array(data, dtype=str).T assert data.ndim == 2 assert data.shape[0] == len(annot.onset) assert data.shape[1] in (3, 4) - with open(fname, 'wb') as fid: + with open(fname, "wb") as fid: fid.write(content.encode()) - np.savetxt(fid, data, delimiter=',', fmt="%s") + np.savetxt(fid, data, delimiter=",", fmt="%s") -def read_annotations(fname, sfreq='auto', uint16_codec=None): +def read_annotations(fname, sfreq="auto", uint16_codec=None): r"""Read annotations from a file. This function reads a ``.fif``, ``.fif.gz``, ``.vmrk``, ``.amrk``, @@ -1093,46 +1185,49 @@ def read_annotations(fname, sfreq='auto', uint16_codec=None): ) ) name = op.basename(fname) - if name.endswith(('fif', 'fif.gz')): + if name.endswith(("fif", "fif.gz")): # Read FiF files ff, tree, _ = fiff_open(fname, preload=False) with ff as fid: annotations = _read_annotations_fif(fid, tree) - elif name.endswith('txt'): + elif name.endswith("txt"): orig_time = _read_annotations_txt_parse_header(fname) onset, duration, description, ch_names = _read_annotations_txt(fname) - annotations = Annotations(onset=onset, duration=duration, - description=description, orig_time=orig_time, - ch_names=ch_names) + annotations = Annotations( + onset=onset, + duration=duration, + description=description, + orig_time=orig_time, + ch_names=ch_names, + ) - elif name.endswith(('vmrk', 'amrk')): + elif name.endswith(("vmrk", "amrk")): annotations = _read_annotations_brainvision(fname, sfreq=sfreq) - elif name.endswith('csv'): + elif name.endswith("csv"): annotations = _read_annotations_csv(fname) - elif name.endswith('cnt'): + elif name.endswith("cnt"): annotations = _read_annotations_cnt(fname) - elif name.endswith('ds'): + elif name.endswith("ds"): annotations = _read_annotations_ctf(fname) - elif name.endswith('cef'): + elif name.endswith("cef"): annotations = _read_annotations_curry(fname, sfreq=sfreq) - elif name.endswith('set'): - annotations = _read_annotations_eeglab(fname, - uint16_codec=uint16_codec) + elif name.endswith("set"): + annotations = _read_annotations_eeglab(fname, uint16_codec=uint16_codec) - elif name.endswith(('edf', 'bdf', 'gdf')): + elif name.endswith(("edf", "bdf", "gdf")): onset, duration, description = _read_annotations_edf(fname) onset = np.array(onset, dtype=float) duration = np.array(duration, dtype=float) - annotations = Annotations(onset=onset, duration=duration, - description=description, - orig_time=None) + annotations = Annotations( + onset=onset, duration=duration, description=description, orig_time=None + ) - elif 
name.startswith('events_') and fname.endswith('mat'): + elif name.startswith("events_") and fname.endswith("mat"): annotations = _read_brainstorm_annotations(fname) else: raise OSError('Unknown annotation file format "%s"' % fname) @@ -1157,23 +1252,27 @@ def _read_annotations_csv(fname): """ pd = _check_pandas_installed(strict=True) df = pd.read_csv(fname, keep_default_na=False) - orig_time = df['onset'].values[0] + orig_time = df["onset"].values[0] try: float(orig_time) - warn('It looks like you have provided annotation onsets as floats. ' - 'These will be interpreted as MILLISECONDS. If that is not what ' - 'you want, save your CSV as a TXT file; the TXT reader accepts ' - 'onsets in seconds.') + warn( + "It looks like you have provided annotation onsets as floats. " + "These will be interpreted as MILLISECONDS. If that is not what " + "you want, save your CSV as a TXT file; the TXT reader accepts " + "onsets in seconds." + ) except ValueError: pass - onset_dt = pd.to_datetime(df['onset']) + onset_dt = pd.to_datetime(df["onset"]) onset = (onset_dt - onset_dt[0]).dt.total_seconds() - duration = df['duration'].values.astype(float) - description = df['description'].values + duration = df["duration"].values.astype(float) + description = df["description"].values ch_names = None - if 'ch_names' in df.columns: - ch_names = [_safe_name_list(val, 'read', 'annotation channel name') - for val in df['ch_names'].values] + if "ch_names" in df.columns: + ch_names = [ + _safe_name_list(val, "read", "annotation channel name") + for val in df["ch_names"].values + ] return Annotations(onset, duration, description, orig_time, ch_names) @@ -1205,33 +1304,34 @@ def get_duration_from_times(t): annot_data = io.loadmat(fname) onsets, durations, descriptions = (list(), list(), list()) - for label, _, _, _, times, _, _ in annot_data['events'][0]: + for label, _, _, _, times, _, _ in annot_data["events"][0]: onsets.append(times[0]) durations.append(get_duration_from_times(times)) n_annot = len(times[0]) descriptions += [str(label[0])] * n_annot - return Annotations(onset=np.concatenate(onsets), - duration=np.concatenate(durations), - description=descriptions, - orig_time=orig_time) + return Annotations( + onset=np.concatenate(onsets), + duration=np.concatenate(durations), + description=descriptions, + orig_time=orig_time, + ) def _is_iso8601(candidate_str): - ISO8601 = r'^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}\.\d{6}$' + ISO8601 = r"^\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2}\.\d{6}$" return re.compile(ISO8601).match(candidate_str) is not None def _read_annotations_txt_parse_header(fname): def is_orig_time(x): - return x.startswith('# orig_time :') + return x.startswith("# orig_time :") with open(fname) as fid: - header = list(takewhile(lambda x: x.startswith('#'), fid)) + header = list(takewhile(lambda x: x.startswith("#"), fid)) orig_values = [h[13:].strip() for h in header if is_orig_time(h)] - orig_values = [_handle_meas_date(orig) for orig in orig_values - if _is_iso8601(orig)] + orig_values = [_handle_meas_date(orig) for orig in orig_values if _is_iso8601(orig)] return None if not orig_values else orig_values[0] @@ -1239,13 +1339,12 @@ def is_orig_time(x): def _read_annotations_txt(fname): with warnings.catch_warnings(record=True): warnings.simplefilter("ignore") - out = np.loadtxt(fname, delimiter=',', - dtype=np.bytes_, unpack=True) + out = np.loadtxt(fname, delimiter=",", dtype=np.bytes_, unpack=True) ch_names = None if len(out) == 0: onset, duration, desc = [], [], [] else: - _check_option('text header', 
len(out), (3, 4)) + _check_option("text header", len(out), (3, 4)) if len(out) == 3: onset, duration, desc = out else: @@ -1256,8 +1355,9 @@ def _read_annotations_txt(fname): desc = [str(d.decode()).strip() for d in np.atleast_1d(desc)] if ch_names is not None: ch_names = [ - _safe_name_list(ch.decode().strip(), 'read', f'ch_names[{ci}]') - for ci, ch in enumerate(ch_names)] + _safe_name_list(ch.decode().strip(), "read", f"ch_names[{ci}]") + for ci, ch in enumerate(ch_names) + ] return onset, duration, desc, ch_names @@ -1270,7 +1370,7 @@ def _read_annotations_fif(fid, tree): annot_data = annot_data[0] orig_time = ch_names = None onset, duration, description = list(), list(), list() - for ent in annot_data['directory']: + for ent in annot_data["directory"]: kind = ent.kind pos = ent.pos tag = read_tag(fid, pos) @@ -1281,7 +1381,7 @@ def _read_annotations_fif(fid, tree): duration = tag.data duration = list() if duration is None else duration - onset elif kind == FIFF.FIFF_COMMENT: - description = _safe_name_list(tag.data, 'read', 'description') + description = _safe_name_list(tag.data, "read", "description") elif kind == FIFF.FIFF_MEAS_DATE: orig_time = tag.data try: @@ -1291,14 +1391,13 @@ def _read_annotations_fif(fid, tree): elif kind == FIFF.FIFF_MNE_EPOCHS_DROP_LOG: ch_names = tuple(tuple(x) for x in json.loads(tag.data)) assert len(onset) == len(duration) == len(description) - annotations = Annotations(onset, duration, description, - orig_time, ch_names) + annotations = Annotations(onset, duration, description, orig_time, ch_names) return annotations def _select_annotations_based_on_description(descriptions, event_id, regexp): """Get a collection of descriptions and returns index of selected.""" - regexp_comp = re.compile('.*' if regexp is None else regexp) + regexp_comp = re.compile(".*" if regexp is None else regexp) event_id_ = dict() dropped = [] @@ -1323,11 +1422,10 @@ def _select_annotations_based_on_description(descriptions, event_id, regexp): else: dropped.append(desc) - event_sel = [ii for ii, kk in enumerate(descriptions) - if kk in event_id_] + event_sel = [ii for ii, kk in enumerate(descriptions) if kk in event_id_] if len(event_sel) == 0 and regexp is not None: - raise ValueError('Could not find any of the events you specified.') + raise ValueError("Could not find any of the events you specified.") return event_sel, event_id_ @@ -1345,7 +1443,7 @@ def _select_events_based_on_id(events, event_desc): event_sel = [ii for ii, e in enumerate(events) if e[2] in event_desc_] if len(event_sel) == 0: - raise ValueError('Could not find any of the events you specified.') + raise ValueError("Could not find any of the events you specified.") return event_sel, event_desc_ @@ -1358,20 +1456,23 @@ def _check_event_id(event_id, raw): if event_id is None: return _DefaultEventParser() - elif event_id == 'auto': + elif event_id == "auto": if isinstance(raw, RawBrainVision): return _BVEventParser() - elif (isinstance(raw, (RawFIF, RawArray)) and - _check_bv_annot(raw.annotations.description)): - logger.info('Non-RawBrainVision raw using branvision markers') + elif isinstance(raw, (RawFIF, RawArray)) and _check_bv_annot( + raw.annotations.description + ): + logger.info("Non-RawBrainVision raw using branvision markers") return _BVEventParser() else: return _DefaultEventParser() elif callable(event_id) or isinstance(event_id, dict): return event_id else: - raise ValueError('Invalid type for event_id (should be None, str, ' - 'dict or callable). 
Got {}'.format(type(event_id))) + raise ValueError( + "Invalid type for event_id (should be None, str, " + "dict or callable). Got {}".format(type(event_id)) + ) def _check_event_description(event_desc, events): @@ -1381,28 +1482,34 @@ def _check_event_description(event_desc, events): if isinstance(event_desc, dict): for val in event_desc.values(): - _validate_type(val, (str, None), 'Event names') + _validate_type(val, (str, None), "Event names") elif isinstance(event_desc, Iterable): event_desc = np.asarray(event_desc) if event_desc.ndim != 1: - raise ValueError('event_desc must be 1D, got shape {}'.format( - event_desc.shape)) + raise ValueError( + "event_desc must be 1D, got shape {}".format(event_desc.shape) + ) event_desc = dict(zip(event_desc, map(str, event_desc))) elif callable(event_desc): pass else: - raise ValueError('Invalid type for event_desc (should be None, list, ' - '1darray, dict or callable). Got {}'.format( - type(event_desc))) + raise ValueError( + "Invalid type for event_desc (should be None, list, " + "1darray, dict or callable). Got {}".format(type(event_desc)) + ) return event_desc @verbose -def events_from_annotations(raw, event_id="auto", - regexp=r'^(?![Bb][Aa][Dd]|[Ee][Dd][Gg][Ee]).*$', - use_rounding=True, chunk_duration=None, - verbose=None): +def events_from_annotations( + raw, + event_id="auto", + regexp=r"^(?![Bb][Aa][Dd]|[Ee][Dd][Gg][Ee]).*$", + use_rounding=True, + chunk_duration=None, + verbose=None, +): """Get :term:`events` and ``event_id`` from an Annotations object. Parameters @@ -1473,11 +1580,13 @@ def events_from_annotations(raw, event_id="auto", event_id = _check_event_id(event_id, raw) event_sel, event_id_ = _select_annotations_based_on_description( - annotations.description, event_id=event_id, regexp=regexp) + annotations.description, event_id=event_id, regexp=regexp + ) if chunk_duration is None: - inds = raw.time_as_index(annotations.onset, use_rounding=use_rounding, - origin=annotations.orig_time) + inds = raw.time_as_index( + annotations.onset, use_rounding=use_rounding, origin=annotations.orig_time + ) if annotations.orig_time is not None: inds += raw.first_samp values = [event_id_[kk] for kk in annotations.description[event_sel]] @@ -1485,33 +1594,36 @@ def events_from_annotations(raw, event_id="auto", else: inds = values = np.array([]).astype(int) for annot in annotations[event_sel]: - annot_offset = annot['onset'] + annot['duration'] - _onsets = np.arange(start=annot['onset'], stop=annot_offset, - step=chunk_duration) + annot_offset = annot["onset"] + annot["duration"] + _onsets = np.arange( + start=annot["onset"], stop=annot_offset, step=chunk_duration + ) good_events = annot_offset - _onsets >= chunk_duration if good_events.any(): _onsets = _onsets[good_events] - _inds = raw.time_as_index(_onsets, - use_rounding=use_rounding, - origin=annotations.orig_time) + _inds = raw.time_as_index( + _onsets, use_rounding=use_rounding, origin=annotations.orig_time + ) _inds += raw.first_samp inds = np.append(inds, _inds) - _values = np.full(shape=len(_inds), - fill_value=event_id_[annot['description']], - dtype=int) + _values = np.full( + shape=len(_inds), + fill_value=event_id_[annot["description"]], + dtype=int, + ) values = np.append(values, _values) events = np.c_[inds, np.zeros(len(inds)), values].astype(int) - logger.info('Used Annotations descriptions: %s' % - (list(event_id_.keys()),)) + logger.info("Used Annotations descriptions: %s" % (list(event_id_.keys()),)) return events, event_id_ @verbose -def annotations_from_events(events, 
sfreq, event_desc=None, first_samp=0, - orig_time=None, verbose=None): +def annotations_from_events( + events, sfreq, event_desc=None, first_samp=0, orig_time=None, verbose=None +): """Convert an event array to an Annotations object. Parameters @@ -1569,10 +1681,9 @@ def annotations_from_events(events, sfreq, event_desc=None, first_samp=0, durations = np.zeros(len(events_sel)) # dummy durations # Create annotations - annots = Annotations(onset=onsets, - duration=durations, - description=descriptions, - orig_time=orig_time) + annots = Annotations( + onset=onsets, duration=durations, description=descriptions, orig_time=orig_time + ) return annots @@ -1581,5 +1692,5 @@ def _adjust_onset_meas_date(annot, raw): """Adjust the annotation onsets based on raw meas_date.""" # If there is a non-None meas date, then the onset should take into # account the first_samp / first_time. - if raw.info['meas_date'] is not None: + if raw.info["meas_date"] is not None: annot.onset += raw.first_time diff --git a/mne/baseline.py b/mne/baseline.py index 10b868b46f9..21aebdde807 100644 --- a/mne/baseline.py +++ b/mne/baseline.py @@ -9,20 +9,22 @@ from .utils import logger, verbose, _check_option -def _log_rescale(baseline, mode='mean'): +def _log_rescale(baseline, mode="mean"): """Log the rescaling method.""" if baseline is not None: - _check_option('mode', mode, ['logratio', 'ratio', 'zscore', 'mean', - 'percent', 'zlogratio']) - msg = 'Applying baseline correction (mode: %s)' % mode + _check_option( + "mode", + mode, + ["logratio", "ratio", "zscore", "mean", "percent", "zlogratio"], + ) + msg = "Applying baseline correction (mode: %s)" % mode else: - msg = 'No baseline correction applied' + msg = "No baseline correction applied" return msg @verbose -def rescale(data, times, baseline, mode='mean', copy=True, picks=None, - verbose=None): +def rescale(data, times, baseline, mode="mean", copy=True, picks=None, verbose=None): """Rescale (baseline correct) data. 
Parameters @@ -73,44 +75,60 @@ def rescale(data, times, baseline, mode='mean', copy=True, picks=None, else: imin = np.where(times >= bmin)[0] if len(imin) == 0: - raise ValueError('bmin is too large (%s), it exceeds the largest ' - 'time value' % (bmin,)) + raise ValueError( + "bmin is too large (%s), it exceeds the largest " "time value" % (bmin,) + ) imin = int(imin[0]) if bmax is None: imax = len(times) else: imax = np.where(times <= bmax)[0] if len(imax) == 0: - raise ValueError('bmax is too small (%s), it is smaller than the ' - 'smallest time value' % (bmax,)) + raise ValueError( + "bmax is too small (%s), it is smaller than the " + "smallest time value" % (bmax,) + ) imax = int(imax[-1]) + 1 if imin >= imax: - raise ValueError('Bad rescaling slice (%s:%s) from time values %s, %s' - % (imin, imax, bmin, bmax)) + raise ValueError( + "Bad rescaling slice (%s:%s) from time values %s, %s" + % (imin, imax, bmin, bmax) + ) # technically this is inefficient when `picks` is given, but assuming # that we generally pick most channels for rescaling, it's not so bad mean = np.mean(data[..., imin:imax], axis=-1, keepdims=True) - if mode == 'mean': + if mode == "mean": + def fun(d, m): d -= m - elif mode == 'ratio': + + elif mode == "ratio": + def fun(d, m): d /= m - elif mode == 'logratio': + + elif mode == "logratio": + def fun(d, m): d /= m np.log10(d, out=d) - elif mode == 'percent': + + elif mode == "percent": + def fun(d, m): d -= m d /= m - elif mode == 'zscore': + + elif mode == "zscore": + def fun(d, m): d -= m d /= np.std(d[..., imin:imax], axis=-1, keepdims=True) - elif mode == 'zlogratio': + + elif mode == "zlogratio": + def fun(d, m): d /= m np.log10(d, out=d) @@ -124,7 +142,7 @@ def fun(d, m): return data -def _check_baseline(baseline, times, sfreq, on_baseline_outside_data='raise'): +def _check_baseline(baseline, times, sfreq, on_baseline_outside_data="raise"): """Check if the baseline is valid, and adjust it if requested. ``None`` values inside the baseline parameter will be replaced with @@ -158,16 +176,20 @@ def _check_baseline(baseline, times, sfreq, on_baseline_outside_data='raise'): return None if not isinstance(baseline, tuple) or len(baseline) != 2: - raise ValueError(f'`baseline={baseline}` is an invalid argument, must ' - f'be a tuple of length 2 or None') + raise ValueError( + f"`baseline={baseline}` is an invalid argument, must " + f"be a tuple of length 2 or None" + ) tmin, tmax = times[0], times[-1] - tstep = 1. / float(sfreq) + tstep = 1.0 / float(sfreq) # check default value of baseline and `tmin=0` if baseline == (None, 0) and tmin == 0: - raise ValueError('Baseline interval is only one sample. Use ' - '`baseline=(0, 0)` if this is desired.') + raise ValueError( + "Baseline interval is only one sample. Use " + "`baseline=(0, 0)` if this is desired." + ) baseline_tmin, baseline_tmax = baseline @@ -182,17 +204,20 @@ def _check_baseline(baseline, times, sfreq, on_baseline_outside_data='raise'): if baseline_tmin > baseline_tmax: raise ValueError( "Baseline min (%s) must be less than baseline max (%s)" - % (baseline_tmin, baseline_tmax)) + % (baseline_tmin, baseline_tmax) + ) if (baseline_tmin < tmin - tstep) or (baseline_tmax > tmax + tstep): - msg = (f"Baseline interval [{baseline_tmin}, {baseline_tmax}] s " - f"is outside of epochs data [{tmin}, {tmax}] s. Epochs were " - f"probably cropped.") - if on_baseline_outside_data == 'raise': + msg = ( + f"Baseline interval [{baseline_tmin}, {baseline_tmax}] s " + f"is outside of epochs data [{tmin}, {tmax}] s. 
Epochs were " + f"probably cropped." + ) + if on_baseline_outside_data == "raise": raise ValueError(msg) - elif on_baseline_outside_data == 'info': + elif on_baseline_outside_data == "info": logger.info(msg) - elif on_baseline_outside_data == 'adjust': + elif on_baseline_outside_data == "adjust": if baseline_tmin < tmin - tstep: baseline_tmin = tmin if baseline_tmax > tmax + tstep: diff --git a/mne/beamformer/__init__.py b/mne/beamformer/__init__.py index b82add2a7cc..a1e233686c4 100644 --- a/mne/beamformer/__init__.py +++ b/mne/beamformer/__init__.py @@ -1,9 +1,19 @@ """Beamformers for source localization.""" -from ._lcmv import (make_lcmv, apply_lcmv, apply_lcmv_epochs, apply_lcmv_raw, - apply_lcmv_cov) -from ._dics import (make_dics, apply_dics, apply_dics_epochs, - apply_dics_tfr_epochs, apply_dics_csd) +from ._lcmv import ( + make_lcmv, + apply_lcmv, + apply_lcmv_epochs, + apply_lcmv_raw, + apply_lcmv_cov, +) +from ._dics import ( + make_dics, + apply_dics, + apply_dics_epochs, + apply_dics_tfr_epochs, + apply_dics_csd, +) from ._rap_music import rap_music from ._compute_beamformer import Beamformer, read_beamformer from .resolution_matrix import make_lcmv_resolution_matrix diff --git a/mne/beamformer/_compute_beamformer.py b/mne/beamformer/_compute_beamformer.py index bfb547e9712..adc0c14ce40 100644 --- a/mne/beamformer/_compute_beamformer.py +++ b/mne/beamformer/_compute_beamformer.py @@ -15,78 +15,124 @@ from ..io.proj import make_projector, Projection from ..minimum_norm.inverse import _get_vertno, _prepare_forward from ..source_space import label_src_vertno_sel -from ..utils import (verbose, check_fname, _reg_pinv, _check_option, logger, - _pl, _check_src_normal, _sym_mat_pow, warn, - _import_h5io_funcs) +from ..utils import ( + verbose, + check_fname, + _reg_pinv, + _check_option, + logger, + _pl, + _check_src_normal, + _sym_mat_pow, + warn, + _import_h5io_funcs, +) from ..time_frequency.csd import CrossSpectralDensity def _check_proj_match(proj, filters): """Check whether SSP projections in data and spatial filter match.""" - proj_data, _, _ = make_projector(proj, filters['ch_names']) - if not np.allclose(proj_data, filters['proj'], - atol=np.finfo(float).eps, rtol=1e-13): - raise ValueError('The SSP projections present in the data ' - 'do not match the projections used when ' - 'calculating the spatial filter.') + proj_data, _, _ = make_projector(proj, filters["ch_names"]) + if not np.allclose( + proj_data, filters["proj"], atol=np.finfo(float).eps, rtol=1e-13 + ): + raise ValueError( + "The SSP projections present in the data " + "do not match the projections used when " + "calculating the spatial filter." + ) def _check_src_type(filters): """Check whether src_type is in filters and set custom warning.""" - if 'src_type' not in filters: - filters['src_type'] = None - warn_text = ('The spatial filter does not contain src_type and a robust ' - 'guess of src_type is not possible without src. Consider ' - 'recomputing the filter.') + if "src_type" not in filters: + filters["src_type"] = None + warn_text = ( + "The spatial filter does not contain src_type and a robust " + "guess of src_type is not possible without src. Consider " + "recomputing the filter." 
+ ) return filters, warn_text -def _prepare_beamformer_input(info, forward, label=None, pick_ori=None, - noise_cov=None, rank=None, pca=False, loose=None, - combine_xyz='fro', exp=None, limit=None, - allow_fixed_depth=True, limit_depth_chs=False): +def _prepare_beamformer_input( + info, + forward, + label=None, + pick_ori=None, + noise_cov=None, + rank=None, + pca=False, + loose=None, + combine_xyz="fro", + exp=None, + limit=None, + allow_fixed_depth=True, + limit_depth_chs=False, +): """Input preparation common for LCMV, DICS, and RAP-MUSIC.""" - _check_option('pick_ori', pick_ori, - ('normal', 'max-power', 'vector', None)) + _check_option("pick_ori", pick_ori, ("normal", "max-power", "vector", None)) # Restrict forward solution to selected vertices if label is not None: - _, src_sel = label_src_vertno_sel(label, forward['src']) + _, src_sel = label_src_vertno_sel(label, forward["src"]) forward = _restrict_forward_to_src_sel(forward, src_sel) if loose is None: - loose = 0. if is_fixed_orient(forward) else 1. + loose = 0.0 if is_fixed_orient(forward) else 1.0 # TODO: Deduplicate with _check_one_ch_type, should not be necessary # (DICS hits this code path, LCMV does not) if noise_cov is None: - noise_cov = make_ad_hoc_cov(info, std=1.) - forward, info_picked, gain, _, orient_prior, _, trace_GRGT, noise_cov, \ - whitener = _prepare_forward( - forward, info, noise_cov, 'auto', loose, rank=rank, pca=pca, - use_cps=True, exp=exp, limit_depth_chs=limit_depth_chs, - combine_xyz=combine_xyz, limit=limit, - allow_fixed_depth=allow_fixed_depth) + noise_cov = make_ad_hoc_cov(info, std=1.0) + ( + forward, + info_picked, + gain, + _, + orient_prior, + _, + trace_GRGT, + noise_cov, + whitener, + ) = _prepare_forward( + forward, + info, + noise_cov, + "auto", + loose, + rank=rank, + pca=pca, + use_cps=True, + exp=exp, + limit_depth_chs=limit_depth_chs, + combine_xyz=combine_xyz, + limit=limit, + allow_fixed_depth=allow_fixed_depth, + ) is_free_ori = not is_fixed_orient(forward) # could have been changed - nn = forward['source_nn'] + nn = forward["source_nn"] if is_free_ori: # take Z coordinate nn = nn[2::3] nn = nn.copy() - vertno = _get_vertno(forward['src']) - if forward['surf_ori']: + vertno = _get_vertno(forward["src"]) + if forward["surf_ori"]: nn[...] = [0, 0, 1] # align to local +Z coordinate if pick_ori is not None and not is_free_ori: raise ValueError( - 'Normal or max-power orientation (got %r) can only be picked when ' - 'a forward operator with free orientation is used.' % (pick_ori,)) - if pick_ori == 'normal' and not forward['surf_ori']: - raise ValueError('Normal orientation can only be picked when a ' - 'forward operator oriented in surface coordinates is ' - 'used.') - _check_src_normal(pick_ori, forward['src']) + "Normal or max-power orientation (got %r) can only be picked when " + "a forward operator with free orientation is used." % (pick_ori,) + ) + if pick_ori == "normal" and not forward["surf_ori"]: + raise ValueError( + "Normal orientation can only be picked when a " + "forward operator oriented in surface coordinates is " + "used." 
+ ) + _check_src_normal(pick_ori, forward["src"]) del forward, info # Undo the scaling that MNE prefers - scale = np.sqrt((noise_cov['eig'] > 0).sum() / trace_GRGT) + scale = np.sqrt((noise_cov["eig"] > 0).sum() / trace_GRGT) gain /= scale if orient_prior is not None: orient_std = np.sqrt(orient_prior) @@ -94,10 +140,8 @@ def _prepare_beamformer_input(info, forward, label=None, pick_ori=None, orient_std = np.ones(gain.shape[1]) # Get the projector - proj, _, _ = make_projector( - info_picked['projs'], info_picked['ch_names']) - return (is_free_ori, info_picked, proj, vertno, gain, whitener, nn, - orient_std) + proj, _, _ = make_projector(info_picked["projs"], info_picked["ch_names"]) + return (is_free_ori, info_picked, proj, vertno, gain, whitener, nn, orient_std) def _reduce_leadfield_rank(G): @@ -115,12 +159,12 @@ def _reduce_leadfield_rank(G): def _sym_inv_sm(x, reduce_rank, inversion, sk): """Symmetric inversion with single- or matrix-style inversion.""" if x.shape[1:] == (1, 1): - with np.errstate(divide='ignore', invalid='ignore'): - x_inv = 1. / x - x_inv[~np.isfinite(x_inv)] = 1. + with np.errstate(divide="ignore", invalid="ignore"): + x_inv = 1.0 / x + x_inv[~np.isfinite(x_inv)] = 1.0 else: assert x.shape[1:] == (3, 3) - if inversion == 'matrix': + if inversion == "matrix": x_inv = _sym_mat_pow(x, -1, reduce_rank=reduce_rank) # Reapply source covariance after inversion x_inv *= sk[:, :, np.newaxis] @@ -128,22 +172,33 @@ def _sym_inv_sm(x, reduce_rank, inversion, sk): else: # Invert for each dipole separately using plain division diags = np.diagonal(x, axis1=1, axis2=2) - assert not reduce_rank # guaranteed earlier - with np.errstate(divide='ignore'): - diags = 1. / diags + assert not reduce_rank # guaranteed earlier + with np.errstate(divide="ignore"): + diags = 1.0 / diags # set the diagonal of each 3x3 x_inv = np.zeros_like(x) for k in range(x.shape[0]): this = diags[k] # Reapply source covariance after inversion - this *= (sk[k] * sk[k]) + this *= sk[k] * sk[k] x_inv[k].flat[::4] = this return x_inv -def _compute_beamformer(G, Cm, reg, n_orient, weight_norm, pick_ori, - reduce_rank, rank, inversion, nn, orient_std, - whitener): +def _compute_beamformer( + G, + Cm, + reg, + n_orient, + weight_norm, + pick_ori, + reduce_rank, + rank, + inversion, + nn, + orient_std, + whitener, +): """Compute a spatial beamformer filter (LCMV or DICS). For more detailed information on the parameters, see the docstrings of @@ -181,22 +236,26 @@ def _compute_beamformer(G, Cm, reg, n_orient, weight_norm, pick_ori, W : ndarray, shape (n_dipoles, n_channels) The beamformer filter weights. """ - _check_option('weight_norm', weight_norm, - ['unit-noise-gain-invariant', 'unit-noise-gain', - 'nai', None]) + _check_option( + "weight_norm", + weight_norm, + ["unit-noise-gain-invariant", "unit-noise-gain", "nai", None], + ) # Whiten the data covariance Cm = whitener @ Cm @ whitener.T.conj() # Restore to properly Hermitian as large whitening coefs can have bad # rounding error - Cm[:] = (Cm + Cm.T.conj()) / 2. 
+ Cm[:] = (Cm + Cm.T.conj()) / 2.0 assert Cm.shape == (G.shape[0],) * 2 s, _ = np.linalg.eigh(Cm) if not (s >= -s.max() * 1e-7).all(): # This shouldn't ever happen, but just in case - warn('data covariance does not appear to be positive semidefinite, ' - 'results will likely be incorrect') + warn( + "data covariance does not appear to be positive semidefinite, " + "results will likely be incorrect" + ) # Tikhonov regularization using reg parameter to control for # trade-off between spatial resolution and noise sensitivity # eq. 25 in Gross and Ioannides, 1999 Phys. Med. Biol. 44 2081 @@ -206,8 +265,9 @@ def _compute_beamformer(G, Cm, reg, n_orient, weight_norm, pick_ori, n_sources = G.shape[1] // n_orient assert nn.shape == (n_sources, 3) - logger.info('Computing beamformer filters for %d source%s' - % (n_sources, _pl(n_sources))) + logger.info( + "Computing beamformer filters for %d source%s" % (n_sources, _pl(n_sources)) + ) n_channels = G.shape[0] assert n_orient in (3, 1) Gk = np.reshape(G.T, (n_sources, n_orient, n_channels)).transpose(0, 2, 1) @@ -215,29 +275,37 @@ def _compute_beamformer(G, Cm, reg, n_orient, weight_norm, pick_ori, sk = np.reshape(orient_std, (n_sources, n_orient)) del G, orient_std - _check_option('reduce_rank', reduce_rank, (True, False)) + _check_option("reduce_rank", reduce_rank, (True, False)) # inversion of the denominator - _check_option('inversion', inversion, ('matrix', 'single')) - if inversion == 'single' and n_orient > 1 and pick_ori == 'vector' and \ - weight_norm == 'unit-noise-gain-invariant': + _check_option("inversion", inversion, ("matrix", "single")) + if ( + inversion == "single" + and n_orient > 1 + and pick_ori == "vector" + and weight_norm == "unit-noise-gain-invariant" + ): raise ValueError( 'Cannot use pick_ori="vector" with inversion="single" and ' - 'weight_norm="unit-noise-gain-invariant"') - if reduce_rank and inversion == 'single': - raise ValueError('reduce_rank cannot be used with inversion="single"; ' - 'consider using inversion="matrix" if you have a ' - 'rank-deficient forward model (i.e., from a sphere ' - 'model with MEG channels), otherwise consider using ' - 'reduce_rank=False') + 'weight_norm="unit-noise-gain-invariant"' + ) + if reduce_rank and inversion == "single": + raise ValueError( + 'reduce_rank cannot be used with inversion="single"; ' + 'consider using inversion="matrix" if you have a ' + "rank-deficient forward model (i.e., from a sphere " + "model with MEG channels), otherwise consider using " + "reduce_rank=False" + ) if n_orient > 1: _, Gk_s, _ = np.linalg.svd(Gk, full_matrices=False) assert Gk_s.shape == (n_sources, n_orient) if not reduce_rank and (Gk_s[:, 0] > 1e6 * Gk_s[:, 2]).any(): raise ValueError( - 'Singular matrix detected when estimating spatial filters. ' - 'Consider reducing the rank of the forward operator by using ' - 'reduce_rank=True.') + "Singular matrix detected when estimating spatial filters. " + "Consider reducing the rank of the forward operator by using " + "reduce_rank=True." + ) del Gk_s # @@ -254,7 +322,7 @@ def _compute_bf_terms(Gk, Cm_inv): # # 2. 
Reorient lead field in direction of max power or normal # - if pick_ori == 'max-power': + if pick_ori == "max-power": assert n_orient == 3 _, bf_denom = _compute_bf_terms(Gk, Cm_inv) if weight_norm is None: @@ -265,7 +333,8 @@ def _compute_bf_terms(Gk, Cm_inv): ori_numer = bf_denom # Cm_inv should be Hermitian so no need for .T.conj() ori_denom = np.matmul( - np.matmul(Gk.swapaxes(-2, -1).conj(), Cm_inv @ Cm_inv), Gk) + np.matmul(Gk.swapaxes(-2, -1).conj(), Cm_inv @ Cm_inv), Gk + ) ori_denom_inv = _sym_inv_sm(ori_denom, reduce_rank, inversion, sk) ori_pick = np.matmul(ori_denom_inv, ori_numer) assert ori_pick.shape == (n_sources, n_orient, n_orient) @@ -280,7 +349,7 @@ def _compute_bf_terms(Gk, Cm_inv): # set the (otherwise arbitrary) sign to match the normal signs = np.sign(np.sum(max_power_ori * nn, axis=1, keepdims=True)) - signs[signs == 0] = 1. + signs[signs == 0] = 1.0 max_power_ori *= signs # Compute the lead field for the optimal orientation, @@ -289,7 +358,7 @@ def _compute_bf_terms(Gk, Cm_inv): n_orient = 1 else: max_power_ori = None - if pick_ori == 'normal': + if pick_ori == "normal": Gk = Gk[..., 2:3] n_orient = 1 @@ -338,16 +407,17 @@ def _compute_bf_terms(Gk, Cm_inv): # # Sekihara 2008 says to use sqrt(diag(W_ug @ W_ug.T)), which is not # rotation invariant: - if weight_norm in ('unit-noise-gain', 'nai'): + if weight_norm in ("unit-noise-gain", "nai"): noise_norm = np.matmul(W, W.swapaxes(-2, -1).conj()).real noise_norm = np.reshape( # np.diag operation over last two axes - noise_norm, (n_sources, -1, 1))[:, ::n_orient + 1] + noise_norm, (n_sources, -1, 1) + )[:, :: n_orient + 1] np.sqrt(noise_norm, out=noise_norm) noise_norm[noise_norm == 0] = np.inf assert noise_norm.shape == (n_sources, n_orient, 1) W /= noise_norm else: - assert weight_norm == 'unit-noise-gain-invariant' + assert weight_norm == "unit-noise-gain-invariant" # Here we use sqrtm. The shortcut: # # use = W @@ -357,9 +427,9 @@ def _compute_bf_terms(Gk, Cm_inv): use = bf_numer inner = np.matmul(use, use.swapaxes(-2, -1).conj()) W = np.matmul(_sym_mat_pow(inner, -0.5), use) - noise_norm = 1. + noise_norm = 1.0 - if weight_norm == 'nai': + if weight_norm == "nai": # Estimate noise level based on covariance matrix, taking the # first eigenvalue that falls outside the signal subspace or the # loading factor used during regularization, whichever is largest. @@ -368,10 +438,11 @@ def _compute_bf_terms(Gk, Cm_inv): # Use the loading factor as noise ceiling. if loading_factor == 0: raise RuntimeError( - 'Cannot compute noise subspace with a full-rank ' - 'covariance matrix and no regularization. Try ' - 'manually specifying the rank of the covariance ' - 'matrix or using regularization.') + "Cannot compute noise subspace with a full-rank " + "covariance matrix and no regularization. Try " + "manually specifying the rank of the covariance " + "matrix or using regularization." 
+ ) noise = loading_factor else: noise, _ = np.linalg.eigh(Cm) @@ -380,7 +451,7 @@ def _compute_bf_terms(Gk, Cm_inv): W /= np.sqrt(noise) W = W.reshape(n_sources * n_orient, n_channels) - logger.info('Filter computation complete') + logger.info("Filter computation complete") return W, max_power_ori @@ -402,8 +473,9 @@ def _compute_power(Cm, W, n_orient): n_sources = W.shape[0] // n_orient Wk = W.reshape(n_sources, n_orient, W.shape[1]) - source_power = np.trace((Wk @ Cm @ Wk.conj().transpose(0, 2, 1)).real, - axis1=1, axis2=2) + source_power = np.trace( + (Wk @ Cm @ Wk.conj().transpose(0, 2, 1)).real, axis1=1, axis2=2 + ) return source_power @@ -427,23 +499,27 @@ def copy(self): return deepcopy(self) def __repr__(self): # noqa: D105 - n_verts = sum(len(v) for v in self['vertices']) - n_channels = len(self['ch_names']) - if self['subject'] is None: - subject = 'unknown' + n_verts = sum(len(v) for v in self["vertices"]) + n_channels = len(self["ch_names"]) + if self["subject"] is None: + subject = "unknown" else: - subject = '"%s"' % (self['subject'],) - out = (' 1: - logger.info(' computing DICS spatial filter at ' - f'{round(freq, 2)} Hz ({i + 1}/{n_freqs})') + logger.info( + " computing DICS spatial filter at " + f"{round(freq, 2)} Hz ({i + 1}/{n_freqs})" + ) Cm = csd.get_data(index=i) @@ -228,29 +268,51 @@ def make_dics(info, forward, csd, reg=0.05, noise_csd=None, label=None, # compute spatial filter n_orient = 3 if is_free_ori else 1 W, max_power_ori = _compute_beamformer( - G, Cm, reg, n_orient, weight_norm, pick_ori, reduce_rank, - rank=csd_int_rank[i], inversion=inversion, nn=nn, - orient_std=orient_std, whitener=whitener) + G, + Cm, + reg, + n_orient, + weight_norm, + pick_ori, + reduce_rank, + rank=csd_int_rank[i], + inversion=inversion, + nn=nn, + orient_std=orient_std, + whitener=whitener, + ) Ws.append(W) max_oris.append(max_power_ori) Ws = np.array(Ws) - if pick_ori == 'max-power': + if pick_ori == "max-power": max_oris = np.array(max_oris) else: max_oris = None - src_type = _get_src_type(forward['src'], vertices) + src_type = _get_src_type(forward["src"], vertices) subject = _subject_from_forward(forward) - is_free_ori = is_free_ori if pick_ori in [None, 'vector'] else False + is_free_ori = is_free_ori if pick_ori in [None, "vector"] else False n_sources = np.sum([len(v) for v in vertices]) filters = Beamformer( - kind='DICS', weights=Ws, csd=csd, ch_names=ch_names, proj=proj, - vertices=vertices, n_sources=n_sources, subject=subject, - pick_ori=pick_ori, inversion=inversion, weight_norm=weight_norm, - src_type=src_type, source_nn=forward['source_nn'].copy(), - is_free_ori=is_free_ori, whitener=whitener, max_power_ori=max_oris) + kind="DICS", + weights=Ws, + csd=csd, + ch_names=ch_names, + proj=proj, + vertices=vertices, + n_sources=n_sources, + subject=subject, + pick_ori=pick_ori, + inversion=inversion, + weight_norm=weight_norm, + src_type=src_type, + source_nn=forward["source_nn"].copy(), + is_free_ori=is_free_ori, + whitener=whitener, + max_power_ori=max_oris, + ) return filters @@ -263,7 +325,7 @@ def _prepare_noise_csd(csd, noise_csd, real_filter): noise_csd = noise_csd.mean() noise_csd = noise_csd.get_data(as_cov=True) if real_filter: - noise_csd['data'] = noise_csd['data'].real + noise_csd["data"] = noise_csd["data"].real return csd, noise_csd @@ -275,10 +337,10 @@ def _apply_dics(data, filters, info, tmin, tfr=False): else: one_epoch = False - Ws = filters['weights'] + Ws = filters["weights"] one_freq = len(Ws) == 1 - subject = filters['subject'] + subject = 
filters["subject"]
 
     # compatibility with 0.16, add src_type as None if not present:
     filters, warn_text = _check_src_type(filters)
@@ -288,35 +350,41 @@ def _apply_dics(data, filters, info, tmin, tfr=False):
 
     # Apply SSPs
     if not tfr:  # save computation, only compute once
-        M_w = _proj_whiten_data(M, info['projs'], filters)
+        M_w = _proj_whiten_data(M, info["projs"], filters)
 
     stcs = []
     for j, W in enumerate(Ws):
-
         if tfr:  # must compute for each frequency
-            M_w = _proj_whiten_data(M[:, j], info['projs'], filters)
+            M_w = _proj_whiten_data(M[:, j], info["projs"], filters)
 
         # project to source space using beamformer weights
         sol = np.dot(W, M_w)
 
-        if filters['is_free_ori'] and filters['pick_ori'] != 'vector':
-            logger.info('combining the current components...')
+        if filters["is_free_ori"] and filters["pick_ori"] != "vector":
+            logger.info("combining the current components...")
             sol = combine_xyz(sol)
 
-        tstep = 1.0 / info['sfreq']
-
-        stcs.append(_make_stc(sol, vertices=filters['vertices'],
-                              src_type=filters['src_type'], tmin=tmin,
-                              tstep=tstep, subject=subject,
-                              vector=(filters['pick_ori'] == 'vector'),
-                              source_nn=filters['source_nn'],
-                              warn_text=warn_text))
+        tstep = 1.0 / info["sfreq"]
+
+        stcs.append(
+            _make_stc(
+                sol,
+                vertices=filters["vertices"],
+                src_type=filters["src_type"],
+                tmin=tmin,
+                tstep=tstep,
+                subject=subject,
+                vector=(filters["pick_ori"] == "vector"),
+                source_nn=filters["source_nn"],
+                warn_text=warn_text,
+            )
+        )
     if one_freq:
         yield stcs[0]
     else:
         yield stcs
 
-    logger.info('[done]')
+    logger.info("[done]")
 
 
 @verbose
@@ -413,12 +481,12 @@ def apply_dics_epochs(epochs, filters, return_generator=False, verbose=None):
     """
     _check_reference(epochs)
 
-    if len(filters['weights']) > 1:
+    if len(filters["weights"]) > 1:
         raise ValueError(
-            'This function only works on DICS beamformer weights that have '
-            'been computed for a single frequency. When calling make_dics(), '
-            'make sure to use a CSD object with only a single frequency (or '
-            'frequency-bin) defined.'
+            "This function only works on DICS beamformer weights that have "
+            "been computed for a single frequency. When calling make_dics(), "
+            "make sure to use a CSD object with only a single frequency (or "
+            "frequency-bin) defined."
         )
 
     info = epochs.info
@@ -436,8 +504,7 @@ def apply_dics_epochs(epochs, filters, return_generator=False, verbose=None):
 
 
 @verbose
-def apply_dics_tfr_epochs(epochs_tfr, filters, return_generator=False,
-                          verbose=None):
+def apply_dics_tfr_epochs(epochs_tfr, filters, return_generator=False, verbose=None):
     """Apply Dynamic Imaging of Coherent Sources (DICS) beamformer weights.
 
     Apply Dynamic Imaging of Coherent Sources (DICS) beamformer weights
@@ -466,22 +533,23 @@ def apply_dics_tfr_epochs(epochs_tfr, filters, return_generator=False,
     apply_dics
     apply_dics_epochs
     apply_dics_csd
-    """ # noqa E501
+    """  # noqa E501
     _validate_type(epochs_tfr, EpochsTFR)
     _check_tfr_complex(epochs_tfr)
 
-    if filters['pick_ori'] == 'vector':
-        warn('Using a vector solution to compute power will lead to '
-             'inaccurate directions (only in the first quadrent) '
-             'because power is a strictly positive (squared) metric. '
-             'Using singular value decomposition (SVD) to determine '
-             'the direction is not yet supported in MNE.')
+    if filters["pick_ori"] == "vector":
+        warn(
+            "Using a vector solution to compute power will lead to "
+            "inaccurate directions (only in the first quadrant) "
+            "because power is a strictly positive (squared) metric. "
+            "Using singular value decomposition (SVD) to determine "
+            "the direction is not yet supported in MNE."
+        )
 
     sel = _check_channels_spatial_filter(epochs_tfr.ch_names, filters)
     data = epochs_tfr.data[:, sel, :, :]
 
-    stcs = _apply_dics(data, filters, epochs_tfr.info,
-                       epochs_tfr.tmin, tfr=True)
+    stcs = _apply_dics(data, filters, epochs_tfr.info, epochs_tfr.tmin, tfr=True)
     if not return_generator:
         stcs = [[stc for stc in tfr_stcs] for tfr_stcs in stcs]
     return stcs
@@ -531,12 +599,12 @@ def apply_dics_csd(csd, filters, verbose=None):
     ----------
     .. footbibliography::
     """  # noqa: E501
-    ch_names = filters['ch_names']
-    vertices = filters['vertices']
-    n_orient = 3 if filters['is_free_ori'] else 1
-    subject = filters['subject']
-    whitener = filters['whitener']
-    n_sources = filters['n_sources']
+    ch_names = filters["ch_names"]
+    vertices = filters["vertices"]
+    n_orient = 3 if filters["is_free_ori"] else 1
+    subject = filters["subject"]
+    whitener = filters["whitener"]
+    n_sources = filters["n_sources"]
 
     # If CSD is summed over multiple frequencies, take the average frequency
     frequencies = [np.mean(dfreq) for dfreq in csd.frequencies]
@@ -547,27 +615,37 @@ def apply_dics_csd(csd, filters, verbose=None):
 
     # Ensure the CSD is in the same order as the weights
     csd_picks = [csd.ch_names.index(ch) for ch in ch_names]
 
-    logger.info('Computing DICS source power...')
+    logger.info("Computing DICS source power...")
 
     for i, freq in enumerate(frequencies):
         if n_freqs > 1:
-            logger.info('    applying DICS spatial filter at '
-                        f'{round(freq, 2)} Hz ({i + 1}/{n_freqs})')
+            logger.info(
+                "    applying DICS spatial filter at "
+                f"{round(freq, 2)} Hz ({i + 1}/{n_freqs})"
+            )
 
         Cm = csd.get_data(index=i)
         Cm = Cm[csd_picks, :][:, csd_picks]
-        W = filters['weights'][i]
+        W = filters["weights"][i]
 
         # Whiten the CSD
         Cm = np.dot(whitener, np.dot(Cm, whitener.conj().T))
 
         source_power[:, i] = _compute_power(Cm, W, n_orient)
 
-    logger.info('[done]')
+    logger.info("[done]")
 
     # compatibility with 0.16, add src_type as None if not present:
     filters, warn_text = _check_src_type(filters)
 
-    return (_make_stc(source_power, vertices=vertices,
-                      src_type=filters['src_type'], tmin=0., tstep=1.,
-                      subject=subject, warn_text=warn_text),
-            frequencies)
+    return (
+        _make_stc(
+            source_power,
+            vertices=vertices,
+            src_type=filters["src_type"],
+            tmin=0.0,
+            tstep=1.0,
+            subject=subject,
+            warn_text=warn_text,
+        ),
+        frequencies,
+    )
diff --git a/mne/beamformer/_lcmv.py b/mne/beamformer/_lcmv.py
index 61c45a8ec66..3e67890da65 100644
--- a/mne/beamformer/_lcmv.py
+++ b/mne/beamformer/_lcmv.py
@@ -13,18 +13,39 @@
 from ..forward import _subject_from_forward
 from ..minimum_norm.inverse import combine_xyz, _check_reference, _check_depth
 from ..source_estimate import _make_stc, _get_src_type
-from ..utils import (logger, verbose, _check_channels_spatial_filter,
-                     _check_one_ch_type, _check_info_inv)
+from ..utils import (
+    logger,
+    verbose,
+    _check_channels_spatial_filter,
+    _check_one_ch_type,
+    _check_info_inv,
+)
 from ._compute_beamformer import (
-    _prepare_beamformer_input, _compute_power,
-    _compute_beamformer, _check_src_type, Beamformer, _proj_whiten_data)
+    _prepare_beamformer_input,
+    _compute_power,
+    _compute_beamformer,
+    _check_src_type,
+    Beamformer,
+    _proj_whiten_data,
+)
 
 
 @verbose
-def make_lcmv(info, forward, data_cov, reg=0.05, noise_cov=None, label=None,
-              pick_ori=None, rank='info',
-              weight_norm='unit-noise-gain-invariant',
-              reduce_rank=False, depth=None, inversion='matrix', verbose=None):
+def make_lcmv(
+    info,
+    forward,
+ data_cov, + reg=0.05, + noise_cov=None, + label=None, + pick_ori=None, + rank="info", + weight_norm="unit-noise-gain-invariant", + reduce_rank=False, + depth=None, + inversion="matrix", + verbose=None, +): """Compute LCMV spatial filter. Parameters @@ -144,7 +165,8 @@ def make_lcmv(info, forward, data_cov, reg=0.05, noise_cov=None, label=None, # check number of sensor types present in the data and ensure a noise cov info = _simplify_info(info) noise_cov, _, allow_mismatch = _check_one_ch_type( - 'lcmv', info, forward, data_cov, noise_cov) + "lcmv", info, forward, data_cov, noise_cov + ) # XXX we need this extra picking step (can't just rely on minimum norm's # because there can be a mismatch. Should probably add an extra arg to # _prepare_beamformer_input at some point (later) @@ -153,58 +175,97 @@ def make_lcmv(info, forward, data_cov, reg=0.05, noise_cov=None, label=None, data_rank = compute_rank(data_cov, rank=rank, info=info) noise_rank = compute_rank(noise_cov, rank=rank, info=info) for key in data_rank: - if (key not in noise_rank or data_rank[key] != noise_rank[key]) and \ - not allow_mismatch: - raise ValueError('%s data rank (%s) did not match the noise ' - 'rank (%s)' - % (key, data_rank[key], - noise_rank.get(key, None))) + if ( + key not in noise_rank or data_rank[key] != noise_rank[key] + ) and not allow_mismatch: + raise ValueError( + "%s data rank (%s) did not match the noise " + "rank (%s)" % (key, data_rank[key], noise_rank.get(key, None)) + ) del noise_rank rank = data_rank - logger.info('Making LCMV beamformer with rank %s' % (rank,)) + logger.info("Making LCMV beamformer with rank %s" % (rank,)) del data_rank - depth = _check_depth(depth, 'depth_sparse') - if inversion == 'single': - depth['combine_xyz'] = False - - is_free_ori, info, proj, vertno, G, whitener, nn, orient_std = \ - _prepare_beamformer_input( - info, forward, label, pick_ori, noise_cov=noise_cov, rank=rank, - pca=False, **depth) - ch_names = list(info['ch_names']) + depth = _check_depth(depth, "depth_sparse") + if inversion == "single": + depth["combine_xyz"] = False + + ( + is_free_ori, + info, + proj, + vertno, + G, + whitener, + nn, + orient_std, + ) = _prepare_beamformer_input( + info, + forward, + label, + pick_ori, + noise_cov=noise_cov, + rank=rank, + pca=False, + **depth + ) + ch_names = list(info["ch_names"]) data_cov = pick_channels_cov(data_cov, include=ch_names) Cm = data_cov._get_square() - if 'estimator' in data_cov: - del data_cov['estimator'] + if "estimator" in data_cov: + del data_cov["estimator"] rank_int = sum(rank.values()) del rank # compute spatial filter n_orient = 3 if is_free_ori else 1 W, max_power_ori = _compute_beamformer( - G, Cm, reg, n_orient, weight_norm, pick_ori, reduce_rank, rank_int, - inversion=inversion, nn=nn, orient_std=orient_std, - whitener=whitener) + G, + Cm, + reg, + n_orient, + weight_norm, + pick_ori, + reduce_rank, + rank_int, + inversion=inversion, + nn=nn, + orient_std=orient_std, + whitener=whitener, + ) # get src type to store with filters for _make_stc - src_type = _get_src_type(forward['src'], vertno) + src_type = _get_src_type(forward["src"], vertno) # get subject to store with filters subject_from = _subject_from_forward(forward) # Is the computed beamformer a scalar or vector beamformer? 
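     # For example, a minimal hypothetical sketch (``info``, ``fwd``, and
     # ``data_cov`` are assumed to already exist in the caller's scope):
     #
     #     filters = make_lcmv(info, fwd, data_cov, pick_ori="vector")
     #
     # keeps a free-orientation (vector) filter, whereas pick_ori="max-power"
     # or a fixed-orientation forward yields a scalar one.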
- is_free_ori = is_free_ori if pick_ori in [None, 'vector'] else False - is_ssp = bool(info['projs']) + is_free_ori = is_free_ori if pick_ori in [None, "vector"] else False + is_ssp = bool(info["projs"]) filters = Beamformer( - kind='LCMV', weights=W, data_cov=data_cov, noise_cov=noise_cov, - whitener=whitener, weight_norm=weight_norm, pick_ori=pick_ori, - ch_names=ch_names, proj=proj, is_ssp=is_ssp, vertices=vertno, - is_free_ori=is_free_ori, n_sources=forward['nsource'], - src_type=src_type, source_nn=forward['source_nn'].copy(), - subject=subject_from, rank=rank_int, max_power_ori=max_power_ori, - inversion=inversion) + kind="LCMV", + weights=W, + data_cov=data_cov, + noise_cov=noise_cov, + whitener=whitener, + weight_norm=weight_norm, + pick_ori=pick_ori, + ch_names=ch_names, + proj=proj, + is_ssp=is_ssp, + vertices=vertno, + is_free_ori=is_free_ori, + n_sources=forward["nsource"], + src_type=src_type, + source_nn=forward["source_nn"].copy(), + subject=subject_from, + rank=rank_int, + max_power_ori=max_power_ori, + inversion=inversion, + ) return filters @@ -217,45 +278,51 @@ def _apply_lcmv(data, filters, info, tmin): else: return_single = False - W = filters['weights'] + W = filters["weights"] for i, M in enumerate(data): - if len(M) != len(filters['ch_names']): - raise ValueError('data and picks must have the same length') + if len(M) != len(filters["ch_names"]): + raise ValueError("data and picks must have the same length") if not return_single: logger.info("Processing epoch : %d" % (i + 1)) - M = _proj_whiten_data(M, info['projs'], filters) + M = _proj_whiten_data(M, info["projs"], filters) # project to source space using beamformer weights vector = False - if filters['is_free_ori']: + if filters["is_free_ori"]: sol = np.dot(W, M) - if filters['pick_ori'] == 'vector': + if filters["pick_ori"] == "vector": vector = True else: - logger.info('combining the current components...') + logger.info("combining the current components...") sol = combine_xyz(sol) else: # Linear inverse: do computation here or delayed - if (M.shape[0] < W.shape[0] and - filters['pick_ori'] != 'max-power'): + if M.shape[0] < W.shape[0] and filters["pick_ori"] != "max-power": sol = (W, M) else: sol = np.dot(W, M) - tstep = 1.0 / info['sfreq'] + tstep = 1.0 / info["sfreq"] # compatibility with 0.16, add src_type as None if not present: filters, warn_text = _check_src_type(filters) - yield _make_stc(sol, vertices=filters['vertices'], tmin=tmin, - tstep=tstep, subject=filters['subject'], - vector=vector, source_nn=filters['source_nn'], - src_type=filters['src_type'], warn_text=warn_text) + yield _make_stc( + sol, + vertices=filters["vertices"], + tmin=tmin, + tstep=tstep, + subject=filters["subject"], + vector=vector, + source_nn=filters["source_nn"], + src_type=filters["src_type"], + warn_text=warn_text, + ) - logger.info('[done]') + logger.info("[done]") @verbose @@ -296,15 +363,13 @@ def apply_lcmv(evoked, filters, *, verbose=None): sel = _check_channels_spatial_filter(evoked.ch_names, filters) data = data[sel] - stc = _apply_lcmv(data=data, filters=filters, info=info, - tmin=tmin) + stc = _apply_lcmv(data=data, filters=filters, info=info, tmin=tmin) return next(stc) @verbose -def apply_lcmv_epochs(epochs, filters, *, return_generator=False, - verbose=None): +def apply_lcmv_epochs(epochs, filters, *, return_generator=False, verbose=None): """Apply Linearly Constrained Minimum Variance (LCMV) beamformer weights. 
Apply Linearly Constrained Minimum Variance (LCMV) beamformer weights @@ -338,8 +403,7 @@ def apply_lcmv_epochs(epochs, filters, *, return_generator=False, sel = _check_channels_spatial_filter(epochs.ch_names, filters) data = epochs.get_data()[:, sel, :] - stcs = _apply_lcmv(data=data, filters=filters, info=info, - tmin=tmin) + stcs = _apply_lcmv(data=data, filters=filters, info=info, tmin=tmin) if not return_generator: stcs = [s for s in stcs] @@ -418,17 +482,23 @@ def apply_lcmv_cov(data_cov, filters, verbose=None): sel_names = [data_cov.ch_names[ii] for ii in sel] data_cov = pick_channels_cov(data_cov, sel_names) - n_orient = filters['weights'].shape[0] // filters['n_sources'] + n_orient = filters["weights"].shape[0] // filters["n_sources"] # Need to project and whiten along both dimensions - data = _proj_whiten_data(data_cov['data'].T, data_cov['projs'], filters) - data = _proj_whiten_data(data.T, data_cov['projs'], filters) + data = _proj_whiten_data(data_cov["data"].T, data_cov["projs"], filters) + data = _proj_whiten_data(data.T, data_cov["projs"], filters) del data_cov - source_power = _compute_power(data, filters['weights'], n_orient) + source_power = _compute_power(data, filters["weights"], n_orient) # compatibility with 0.16, add src_type as None if not present: filters, warn_text = _check_src_type(filters) - return _make_stc(source_power, vertices=filters['vertices'], - src_type=filters['src_type'], tmin=0., tstep=1., - subject=filters['subject'], - source_nn=filters['source_nn'], warn_text=warn_text) + return _make_stc( + source_power, + vertices=filters["vertices"], + src_type=filters["src_type"], + tmin=0.0, + tstep=1.0, + subject=filters["subject"], + source_nn=filters["source_nn"], + warn_text=warn_text, + ) diff --git a/mne/beamformer/_rap_music.py b/mne/beamformer/_rap_music.py index 3b59fa90c46..d58de523b2a 100644 --- a/mne/beamformer/_rap_music.py +++ b/mne/beamformer/_rap_music.py @@ -17,8 +17,7 @@ @fill_doc -def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, - picks=None): +def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, picks=None): """RAP-MUSIC for evoked data. Parameters @@ -47,15 +46,17 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, Computed only if return_explained_data is True. 
""" from scipy import linalg + info = pick_info(info, picks) del picks # things are much simpler if we avoid surface orientation - align = forward['source_nn'].copy() - if forward['surf_ori'] and not is_fixed_orient(forward): + align = forward["source_nn"].copy() + if forward["surf_ori"] and not is_fixed_orient(forward): forward = convert_forward_solution(forward, surf_ori=False) is_free_ori, info, _, _, G, whitener, _, _ = _prepare_beamformer_input( - info, forward, noise_cov=noise_cov, rank=None) - forward = pick_channels_forward(forward, info['ch_names'], ordered=True) + info, forward, noise_cov=noise_cov, rank=None + ) + forward = pick_channels_forward(forward, info["ch_names"], ordered=True) del info # whiten the data (leadfield already whitened) @@ -67,7 +68,7 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, n_orient = 3 if is_free_ori else 1 G.shape = (G.shape[0], -1, n_orient) - gain = forward['sol']['data'].copy() + gain = forward["sol"]["data"].copy() gain.shape = G.shape n_channels = G.shape[0] A = np.empty((n_channels, n_dipoles)) @@ -80,7 +81,7 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, idxs = list() for k in range(n_dipoles): - subcorr_max = -1. + subcorr_max = -1.0 source_idx, source_ori, source_pos = 0, [0, 0, 0], [0, 0, 0] for i_source in range(G.shape[1]): Gk = G_proj[:, i_source] @@ -89,13 +90,13 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, subcorr_max = subcorr source_idx = i_source source_ori = ori - source_pos = forward['source_rr'][i_source] + source_pos = forward["source_rr"][i_source] if n_orient == 3 and align is not None: - surf_normal = forward['source_nn'][3 * i_source + 2] + surf_normal = forward["source_nn"][3 * i_source + 2] # make sure ori is aligned to the surface orientation - source_ori *= np.sign(source_ori @ surf_normal) or 1. + source_ori *= np.sign(source_ori @ surf_normal) or 1.0 if n_orient == 1: - source_ori = forward['source_nn'][i_source] + source_ori = forward["source_nn"][i_source] idxs.append(source_idx) if n_orient == 3: @@ -110,8 +111,8 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, if n_orient == 3: logger.info("ori = %s %s %s" % tuple(oris[k])) - projection = _compute_proj(A[:, :k + 1]) - G_proj = np.einsum('ab,bso->aso', projection, G) + projection = _compute_proj(A[:, : k + 1]) + G_proj = np.einsum("ab,bso->aso", projection, G) phi_sig_proj = np.dot(projection, phi_sig) del G, G_proj @@ -126,8 +127,7 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, if n_orient == 3: gain_dip = (oris * gain_active).sum(-1) idxs = np.array(idxs) - active_set = np.array( - [[3 * idxs, 3 * idxs + 1, 3 * idxs + 2]]).T.ravel() + active_set = np.array([[3 * idxs, 3 * idxs + 1, 3 * idxs + 2]]).T.ravel() else: gain_dip = gain_active[:, :, 0] active_set = idxs @@ -137,15 +137,15 @@ def _apply_rap_music(data, info, times, forward, noise_cov, n_dipoles=2, explained_data = gain_dip @ sol M_estimate = whitener @ explained_data _log_exp_var(M, M_estimate) - tstep = np.median(np.diff(times)) if len(times) > 1 else 1. 
+    tstep = np.median(np.diff(times)) if len(times) > 1 else 1.0
     dipoles = _make_dipoles_sparse(
-        X, active_set, forward, times[0], tstep, M,
-        gain_active, active_is_idx=True)
+        X, active_set, forward, times[0], tstep, M, gain_active, active_is_idx=True
+    )
     for dipole, ori in zip(dipoles, oris):
         signs = np.sign((dipole.ori * ori).sum(-1, keepdims=True))
         dipole.ori *= signs
         dipole.amplitude *= signs[:, 0]
-    logger.info('[done]')
+    logger.info("[done]")
     return dipoles, explained_data
 
 
@@ -185,6 +185,7 @@ def _make_dipoles(times, poss, oris, sol, gof):
 
 def _compute_subcorr(G, phi_sig):
     """Compute the subspace correlation."""
     from scipy import linalg
+
     Ug, Sg, Vg = linalg.svd(G, full_matrices=False)
     # Now we look at the actual rank of the forward fields
     # in G and handle the fact that it might be rank deficient
@@ -204,13 +205,15 @@ def _compute_subcorr(G, phi_sig):
 def _compute_proj(A):
     """Compute the orthogonal projection operation for a manifold vector A."""
     from scipy import linalg
+
     U, _, _ = linalg.svd(A, full_matrices=False)
     return np.identity(A.shape[0]) - np.dot(U, U.T.conjugate())
 
 
 @verbose
-def rap_music(evoked, forward, noise_cov, n_dipoles=5, return_residual=False,
-              verbose=None):
+def rap_music(
+    evoked, forward, noise_cov, n_dipoles=5, return_residual=False, verbose=None
+):
     """RAP-MUSIC source localization method.
 
     Compute Recursively Applied and Projected MUltiple SIgnal Classification
@@ -269,16 +272,16 @@ def rap_music(evoked, forward, noise_cov, n_dipoles=5, return_residual=False,
 
     data = data[picks]
 
-    dipoles, explained_data = _apply_rap_music(data, info, times, forward,
-                                               noise_cov, n_dipoles,
-                                               picks)
+    dipoles, explained_data = _apply_rap_music(
+        data, info, times, forward, noise_cov, n_dipoles, picks
+    )
 
     if return_residual:
-        residual = evoked.copy().pick([info['ch_names'][p] for p in picks])
+        residual = evoked.copy().pick([info["ch_names"][p] for p in picks])
         residual.data -= explained_data
-        active_projs = [p for p in residual.info['projs'] if p['active']]
+        active_projs = [p for p in residual.info["projs"] if p["active"]]
         for p in active_projs:
-            p['active'] = False
+            p["active"] = False
         residual.add_proj(active_projs, remove_existing=True)
         residual.apply_proj()
     return dipoles, residual
diff --git a/mne/beamformer/resolution_matrix.py b/mne/beamformer/resolution_matrix.py
index 5294de5a621..278ae65692a 100644
--- a/mne/beamformer/resolution_matrix.py
+++ b/mne/beamformer/resolution_matrix.py
@@ -33,8 +33,8 @@ def make_lcmv_resolution_matrix(filters, forward, info):
         for free dipole orientation versus factor 1 for scalar beamformers).
     """
     # don't include bad channels from noise covariance matrix
-    bads_filt = filters['noise_cov']['bads']
-    ch_names = filters['noise_cov']['names']
+    bads_filt = filters["noise_cov"]["bads"]
+    ch_names = filters["noise_cov"]["names"]
 
     # good channels
     ch_names = [c for c in ch_names if (c not in bads_filt)]
@@ -43,7 +43,7 @@ def make_lcmv_resolution_matrix(filters, forward, info):
     forward = pick_channels_forward(forward, ch_names, ordered=True)
 
     # get leadfield matrix from forward solution
-    leadfield = forward['sol']['data']
+    leadfield = forward["sol"]["data"]
 
     # get the filter weights for beamformer as matrix
     filtmat = _get_matrix_from_lcmv(filters, forward, info)
@@ -53,7 +53,7 @@ def make_lcmv_resolution_matrix(filters, forward, info):
 
     shape = resmat.shape
 
-    logger.info('Dimensions of LCMV resolution matrix: %d by %d.' % shape)
+    logger.info("Dimensions of LCMV resolution matrix: %d by %d." % shape)
 
     return resmat
 
@@ -67,16 +67,15 @@ def _get_matrix_from_lcmv(filters, forward, info, verbose=None):
         Inverse matrix associated with LCMV beamformer filters.
     """
     # number of channels for identity matrix
-    info = pick_info(
-        info, pick_channels(info['ch_names'], filters['ch_names']))
-    n_chs = len(info['ch_names'])
+    info = pick_info(info, pick_channels(info["ch_names"], filters["ch_names"]))
+    n_chs = len(info["ch_names"])
 
     # create identity matrix as input for inverse operator
     # set elements to zero for non-selected channels
     id_mat = np.eye(n_chs)
 
     # convert identity matrix to evoked data type (pretending it's an epoch)
-    evo_ident = EvokedArray(id_mat, info=info, tmin=0.)
+    evo_ident = EvokedArray(id_mat, info=info, tmin=0.0)
 
     # apply beamformer to identity matrix
     stc_lcmv = apply_lcmv(evo_ident, filters, verbose=verbose)
diff --git a/mne/beamformer/tests/test_dics.py b/mne/beamformer/tests/test_dics.py
index 74d273a0b66..6bc18d81e3e 100644
--- a/mne/beamformer/tests/test_dics.py
+++ b/mne/beamformer/tests/test_dics.py
@@ -6,15 +6,20 @@
 import copy as cp
 
 import pytest
-from numpy.testing import (assert_array_equal, assert_allclose,
-                           assert_array_less)
+from numpy.testing import assert_array_equal, assert_allclose, assert_array_less
 import numpy as np
 
 import mne
 from mne import pick_types
-from mne.beamformer import (make_dics, apply_dics, apply_dics_epochs,
-                            apply_dics_tfr_epochs, apply_dics_csd,
-                            read_beamformer, Beamformer)
+from mne.beamformer import (
+    make_dics,
+    apply_dics,
+    apply_dics_epochs,
+    apply_dics_tfr_epochs,
+    apply_dics_csd,
+    read_beamformer,
+    Beamformer,
+)
 from mne.beamformer._compute_beamformer import _prepare_beamformer_input
 from mne.beamformer._dics import _prepare_noise_csd
 from mne.beamformer.tests.test_lcmv import _assert_weight_norm
@@ -24,47 +29,40 @@
 from mne.io.pick import pick_info
 from mne.proj import compute_proj_evoked, make_projector
 from mne.surface import _compute_nearest
-from mne.time_frequency import (CrossSpectralDensity, csd_morlet, EpochsTFR,
-                                csd_tfr)
+from mne.time_frequency import CrossSpectralDensity, csd_morlet, EpochsTFR, csd_tfr
 from mne.time_frequency.csd import _sym_mat_to_vector
 from mne.transforms import invert_transform, apply_trans
 from mne.utils import object_diff, requires_version, catch_logging
 
 data_path = testing.data_path(download=False)
 fname_raw = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif"
-fname_fwd = (
-    data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif"
-)
-fname_fwd_vol = (
-    data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif"
-)
+fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif"
+fname_fwd_vol = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif"
 fname_event = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw-eve.fif"
 subjects_dir = data_path / "subjects"
 
 
-@pytest.fixture(scope='module', params=[testing._pytest_param()])
+@pytest.fixture(scope="module", params=[testing._pytest_param()])
 def _load_forward():
     """Load forward models."""
     fwd_free = mne.read_forward_solution(fname_fwd)
     fwd_free = mne.pick_types_forward(fwd_free, meg=True, eeg=False)
     fwd_free = mne.convert_forward_solution(fwd_free, surf_ori=False)
-    fwd_surf = mne.convert_forward_solution(fwd_free, surf_ori=True,
-                                            use_cps=False)
-    fwd_fixed = mne.convert_forward_solution(fwd_free, force_fixed=True,
-                                             use_cps=False)
+    fwd_surf = mne.convert_forward_solution(fwd_free, surf_ori=True, use_cps=False)
+    fwd_fixed = 
mne.convert_forward_solution(fwd_free, force_fixed=True, use_cps=False) fwd_vol = mne.read_forward_solution(fname_fwd_vol) return fwd_free, fwd_surf, fwd_fixed, fwd_vol def _simulate_data(fwd, idx): # Somewhere on the frontal lobe by default """Simulate an oscillator on the cortex.""" - pytest.importorskip('nibabel') - source_vertno = fwd['src'][0]['vertno'][idx] + pytest.importorskip("nibabel") + source_vertno = fwd["src"][0]["vertno"][idx] - sfreq = 50. # Hz. + sfreq = 50.0 # Hz. times = np.arange(10 * sfreq) / sfreq # 10 seconds of data signal = np.sin(20 * 2 * np.pi * times) # 20 Hz oscillator - signal[:len(times) // 2] *= 2 # Make signal louder at the beginning + signal[: len(times) // 2] *= 2 # Make signal louder at the beginning signal *= 1e-9 # Scale to be in the ballpark of MEG data # Construct a SourceEstimate object that describes the signal at the @@ -74,16 +72,16 @@ def _simulate_data(fwd, idx): # Somewhere on the frontal lobe by default vertices=[[source_vertno], []], tmin=0, tstep=1 / sfreq, - subject='sample', + subject="sample", ) # Create an info object that holds information about the sensors - info = mne.create_info(fwd['info']['ch_names'], sfreq, ch_types='grad') + info = mne.create_info(fwd["info"]["ch_names"], sfreq, ch_types="grad") with info._unlock(): - info.update(fwd['info']) # Merge in sensor position information + info.update(fwd["info"]) # Merge in sensor position information # heavily decimate sensors to make it much faster - info = mne.pick_info(info, np.arange(info['nchan'])[::5]) - fwd = mne.pick_channels_forward(fwd, info['ch_names']) + info = mne.pick_info(info, np.arange(info["nchan"])[::5]) + fwd = mne.pick_channels_forward(fwd, info["ch_names"]) # Run the simulated signal through the forward model, obtaining # simulated sensor data. 
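     # (One possible sketch of that projection step, not necessarily what
     # this helper uses:
     #
     #     raw = mne.simulation.simulate_raw(info, stc, forward=fwd)
     #
     # where ``info``, ``stc``, and ``fwd`` are the objects built above.)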
@@ -95,31 +93,39 @@ def _simulate_data(fwd, idx): # Somewhere on the frontal lobe by default raw._data += noise # Define a single epoch (weird baseline but shouldn't matter) - epochs = mne.Epochs(raw, [[0, 0, 1]], event_id=1, tmin=0, - tmax=raw.times[-1], baseline=(0., 0.), preload=True) + epochs = mne.Epochs( + raw, + [[0, 0, 1]], + event_id=1, + tmin=0, + tmax=raw.times[-1], + baseline=(0.0, 0.0), + preload=True, + ) evoked = epochs.average() # Compute the cross-spectral density matrix csd = csd_morlet(epochs, frequencies=[10, 20], n_cycles=[5, 10], decim=5) - labels = mne.read_labels_from_annot( - 'sample', hemi='lh', subjects_dir=subjects_dir) - label = [ - label for label in labels if np.in1d(source_vertno, label.vertices)[0]] + labels = mne.read_labels_from_annot("sample", hemi="lh", subjects_dir=subjects_dir) + label = [label for label in labels if np.in1d(source_vertno, label.vertices)[0]] assert len(label) == 1 label = label[0] - vertices = np.intersect1d(label.vertices, fwd['src'][0]['vertno']) + vertices = np.intersect1d(label.vertices, fwd["src"][0]["vertno"]) source_ind = vertices.tolist().index(source_vertno) assert vertices[source_ind] == source_vertno return epochs, evoked, csd, source_vertno, label, vertices, source_ind -idx_param = pytest.mark.parametrize('idx', [ - 0, - pytest.param(100, marks=pytest.mark.slowtest), - 200, - pytest.param(233, marks=pytest.mark.slowtest), -]) +idx_param = pytest.mark.parametrize( + "idx", + [ + 0, + pytest.param(100, marks=pytest.mark.slowtest), + 200, + pytest.param(233, marks=pytest.mark.slowtest), + ], +) def _rand_csd(rng, info): @@ -130,7 +136,7 @@ def _rand_csd(rng, info): data = data @ data.conj().T data *= scales data *= scales[:, np.newaxis] - data.flat[::n + 1] = scales + data.flat[:: n + 1] = scales return data @@ -141,67 +147,74 @@ def _make_rand_csd(info, csd): s, u = np.linalg.eigh(csd.get_data(csd.frequencies[0])) mask = np.abs(s) >= s[-1] * 1e-7 rank = mask.sum() - assert rank == len(data) == len(info['ch_names']) + assert rank == len(data) == len(info["ch_names"]) noise_csd = CrossSpectralDensity( - _sym_mat_to_vector(data), info['ch_names'], 0., csd.n_fft) + _sym_mat_to_vector(data), info["ch_names"], 0.0, csd.n_fft + ) return noise_csd, rank @pytest.mark.slowtest @testing.requires_testing_data -@requires_version('h5io') +@requires_version("h5io") @idx_param -@pytest.mark.parametrize('whiten', [ - pytest.param(False, marks=pytest.mark.slowtest), - True, -]) +@pytest.mark.parametrize( + "whiten", + [ + pytest.param(False, marks=pytest.mark.slowtest), + True, + ], +) def test_make_dics(tmp_path, _load_forward, idx, whiten): """Test making DICS beamformer filters.""" # We only test proper handling of parameters here. Testing the results is # done in test_apply_dics_timeseries and test_apply_dics_csd. 
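     # A representative (illustrative) call of the API under test, assuming
     # ``epochs``, ``fwd_surf``, ``csd``, and ``noise_csd`` are in scope:
     #
     #     filters = make_dics(epochs.info, fwd_surf, csd, noise_csd=noise_csd)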
fwd_free, fwd_surf, fwd_fixed, fwd_vol = _load_forward - epochs, _, csd, _, label, vertices, source_ind = \ - _simulate_data(fwd_fixed, idx) - with pytest.raises(ValueError, match='several sensor types'): + epochs, _, csd, _, label, vertices, source_ind = _simulate_data(fwd_fixed, idx) + with pytest.raises(ValueError, match="several sensor types"): make_dics(epochs.info, fwd_surf, csd, label=label, pick_ori=None) if whiten: noise_csd, rank = _make_rand_csd(epochs.info, csd) - assert rank == len(epochs.info['ch_names']) == 62 + assert rank == len(epochs.info["ch_names"]) == 62 else: noise_csd = None - epochs.pick_types(meg='grad') + epochs.pick_types(meg="grad") with pytest.raises(ValueError, match="Invalid value for the 'pick_ori'"): - make_dics(epochs.info, fwd_fixed, csd, pick_ori="notexistent", - noise_csd=noise_csd) - with pytest.raises(ValueError, match='rank, if str'): - make_dics(epochs.info, fwd_fixed, csd, rank='foo', noise_csd=noise_csd) - with pytest.raises(TypeError, match='rank must be'): - make_dics(epochs.info, fwd_fixed, csd, rank=1., noise_csd=noise_csd) + make_dics( + epochs.info, fwd_fixed, csd, pick_ori="notexistent", noise_csd=noise_csd + ) + with pytest.raises(ValueError, match="rank, if str"): + make_dics(epochs.info, fwd_fixed, csd, rank="foo", noise_csd=noise_csd) + with pytest.raises(TypeError, match="rank must be"): + make_dics(epochs.info, fwd_fixed, csd, rank=1.0, noise_csd=noise_csd) # Test if fixed forward operator is detected when picking normal # orientation - with pytest.raises(ValueError, match='forward operator with free ori'): - make_dics(epochs.info, fwd_fixed, csd, pick_ori="normal", - noise_csd=noise_csd) + with pytest.raises(ValueError, match="forward operator with free ori"): + make_dics(epochs.info, fwd_fixed, csd, pick_ori="normal", noise_csd=noise_csd) # Test if non-surface oriented forward operator is detected when picking # normal orientation - with pytest.raises(ValueError, match='oriented in surface coordinates'): - make_dics(epochs.info, fwd_free, csd, pick_ori="normal", - noise_csd=noise_csd) + with pytest.raises(ValueError, match="oriented in surface coordinates"): + make_dics(epochs.info, fwd_free, csd, pick_ori="normal", noise_csd=noise_csd) # Test if volume forward operator is detected when picking normal # orientation - with pytest.raises(ValueError, match='oriented in surface coordinates'): - make_dics(epochs.info, fwd_vol, csd, pick_ori="normal", - noise_csd=noise_csd) + with pytest.raises(ValueError, match="oriented in surface coordinates"): + make_dics(epochs.info, fwd_vol, csd, pick_ori="normal", noise_csd=noise_csd) # Test invalid combinations of parameters - with pytest.raises(ValueError, match='reduce_rank cannot be used with'): - make_dics(epochs.info, fwd_free, csd, inversion='single', - reduce_rank=True, noise_csd=noise_csd) + with pytest.raises(ValueError, match="reduce_rank cannot be used with"): + make_dics( + epochs.info, + fwd_free, + csd, + inversion="single", + reduce_rank=True, + noise_csd=noise_csd, + ) # TODO: Restore this? 
# with pytest.raises(ValueError, match='not stable with depth'): # make_dics(epochs.info, fwd_free, csd, weight_norm='unit-noise-gain', @@ -209,83 +222,136 @@ def test_make_dics(tmp_path, _load_forward, idx, whiten): # Sanity checks on the returned filters n_freq = len(csd.frequencies) - vertices = np.intersect1d(label.vertices, fwd_free['src'][0]['vertno']) + vertices = np.intersect1d(label.vertices, fwd_free["src"][0]["vertno"]) n_verts = len(vertices) n_orient = 3 n_channels = len(epochs.ch_names) # Test return values - weight_norm = 'unit-noise-gain' - inversion = 'single' - filters = make_dics(epochs.info, fwd_surf, csd, label=label, pick_ori=None, - weight_norm=weight_norm, depth=None, real_filter=False, - noise_csd=noise_csd, inversion=inversion) - assert filters['weights'].shape == (n_freq, n_verts * n_orient, n_channels) - assert np.iscomplexobj(filters['weights']) - assert filters['csd'].ch_names == epochs.ch_names - assert isinstance(filters['csd'], CrossSpectralDensity) - assert filters['ch_names'] == epochs.ch_names - assert_array_equal(filters['proj'], np.eye(n_channels)) - assert_array_equal(filters['vertices'][0], vertices) - assert_array_equal(filters['vertices'][1], []) # Label was on the LH - assert filters['subject'] == fwd_free['src']._subject - assert filters['pick_ori'] is None - assert filters['is_free_ori'] - assert filters['inversion'] == inversion - assert filters['weight_norm'] == weight_norm - assert 'DICS' in repr(filters) + weight_norm = "unit-noise-gain" + inversion = "single" + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori=None, + weight_norm=weight_norm, + depth=None, + real_filter=False, + noise_csd=noise_csd, + inversion=inversion, + ) + assert filters["weights"].shape == (n_freq, n_verts * n_orient, n_channels) + assert np.iscomplexobj(filters["weights"]) + assert filters["csd"].ch_names == epochs.ch_names + assert isinstance(filters["csd"], CrossSpectralDensity) + assert filters["ch_names"] == epochs.ch_names + assert_array_equal(filters["proj"], np.eye(n_channels)) + assert_array_equal(filters["vertices"][0], vertices) + assert_array_equal(filters["vertices"][1], []) # Label was on the LH + assert filters["subject"] == fwd_free["src"]._subject + assert filters["pick_ori"] is None + assert filters["is_free_ori"] + assert filters["inversion"] == inversion + assert filters["weight_norm"] == weight_norm + assert "DICS" in repr(filters) assert 'subject "sample"' in repr(filters) assert str(len(vertices)) in repr(filters) assert str(n_channels) in repr(filters) - assert 'rank' not in repr(filters) + assert "rank" not in repr(filters) _, noise_cov = _prepare_noise_csd(csd, noise_csd, real_filter=False) _, _, _, _, G, _, _, _ = _prepare_beamformer_input( - epochs.info, fwd_surf, label, 'vector', combine_xyz=False, exp=None, - noise_cov=noise_cov) + epochs.info, + fwd_surf, + label, + "vector", + combine_xyz=False, + exp=None, + noise_cov=noise_cov, + ) G.shape = (n_channels, n_verts, n_orient) G = G.transpose(1, 2, 0).conj() # verts, orient, ch _assert_weight_norm(filters, G) - inversion = 'matrix' - filters = make_dics(epochs.info, fwd_surf, csd, label=label, pick_ori=None, - weight_norm=weight_norm, depth=None, - noise_csd=noise_csd, inversion=inversion) + inversion = "matrix" + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori=None, + weight_norm=weight_norm, + depth=None, + noise_csd=noise_csd, + inversion=inversion, + ) _assert_weight_norm(filters, G) - weight_norm = 
'unit-noise-gain-invariant' - inversion = 'single' - filters = make_dics(epochs.info, fwd_surf, csd, label=label, pick_ori=None, - weight_norm=weight_norm, depth=None, - noise_csd=noise_csd, inversion=inversion) + weight_norm = "unit-noise-gain-invariant" + inversion = "single" + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori=None, + weight_norm=weight_norm, + depth=None, + noise_csd=noise_csd, + inversion=inversion, + ) _assert_weight_norm(filters, G) # Test picking orientations. Also test weight norming under these different # conditions. - weight_norm = 'unit-noise-gain' - filters = make_dics(epochs.info, fwd_surf, csd, label=label, - pick_ori='normal', weight_norm=weight_norm, - depth=None, noise_csd=noise_csd, inversion=inversion) + weight_norm = "unit-noise-gain" + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori="normal", + weight_norm=weight_norm, + depth=None, + noise_csd=noise_csd, + inversion=inversion, + ) n_orient = 1 - assert filters['weights'].shape == (n_freq, n_verts * n_orient, n_channels) - assert not filters['is_free_ori'] + assert filters["weights"].shape == (n_freq, n_verts * n_orient, n_channels) + assert not filters["is_free_ori"] _assert_weight_norm(filters, G) - filters = make_dics(epochs.info, fwd_surf, csd, label=label, - pick_ori='max-power', weight_norm=weight_norm, - depth=None, noise_csd=noise_csd, inversion=inversion) + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori="max-power", + weight_norm=weight_norm, + depth=None, + noise_csd=noise_csd, + inversion=inversion, + ) n_orient = 1 - assert filters['weights'].shape == (n_freq, n_verts * n_orient, n_channels) - assert not filters['is_free_ori'] + assert filters["weights"].shape == (n_freq, n_verts * n_orient, n_channels) + assert not filters["is_free_ori"] _assert_weight_norm(filters, G) # From here on, only work on a single frequency csd = csd[0] # Test using a real-valued filter - filters = make_dics(epochs.info, fwd_surf, csd, label=label, - pick_ori='normal', real_filter=True, - noise_csd=noise_csd) - assert not np.iscomplexobj(filters['weights']) + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori="normal", + real_filter=True, + noise_csd=noise_csd, + ) + assert not np.iscomplexobj(filters["weights"]) # Test forward normalization. When inversion='single', the power of a # unit-noise CSD should be 1, even without weight normalization. 
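# --- editor's note: illustrative sketch, not part of the patch ---
# The mechanical change running through this hunk is Black's line
# splitting: a call that overflows the 88-column limit is exploded to
# one argument per line, ending in a "magic" trailing comma. Assuming
# the black package is installed, the transformation can be previewed
# with its public format_str API (the snippet below only needs to parse;
# the names are taken from the test above):
import black

src = (
    'filters = make_dics(epochs.info, fwd_surf, csd, label=label, pick_ori="max-power", '
    "weight_norm=weight_norm, depth=None, noise_csd=noise_csd, inversion=inversion)\n"
)
print(black.format_str(src, mode=black.Mode()))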
@@ -294,105 +360,151 @@ def test_make_dics(tmp_path, _load_forward, idx, whiten): inds = np.triu_indices(csd.n_channels) # Using [:, :] syntax for in-place broadcasting csd_noise._data[:, :] = np.eye(csd.n_channels)[inds][:, np.newaxis] - filters = make_dics(epochs.info, fwd_surf, csd_noise, label=label, - weight_norm=None, depth=1., noise_csd=noise_csd, - inversion='single') - w = filters['weights'][0][:3] - assert_allclose(np.diag(w.dot(w.conjugate().T)), 1.0, rtol=1e-6, - atol=0) + filters = make_dics( + epochs.info, + fwd_surf, + csd_noise, + label=label, + weight_norm=None, + depth=1.0, + noise_csd=noise_csd, + inversion="single", + ) + w = filters["weights"][0][:3] + assert_allclose(np.diag(w.dot(w.conjugate().T)), 1.0, rtol=1e-6, atol=0) # Test turning off both forward and weight normalization - filters = make_dics(epochs.info, fwd_surf, csd, label=label, - weight_norm=None, depth=None, noise_csd=noise_csd) - w = filters['weights'][0][:3] - assert not np.allclose(np.diag(w.dot(w.conjugate().T)), 1.0, - rtol=1e-2, atol=0) + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + weight_norm=None, + depth=None, + noise_csd=noise_csd, + ) + w = filters["weights"][0][:3] + assert not np.allclose(np.diag(w.dot(w.conjugate().T)), 1.0, rtol=1e-2, atol=0) # Test neural-activity-index weight normalization. It should be a scaled # version of the unit-noise-gain beamformer. filters_nai = make_dics( - epochs.info, fwd_surf, csd, label=label, pick_ori='max-power', - weight_norm='nai', depth=None, noise_csd=noise_csd) - w_nai = filters_nai['weights'][0] + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori="max-power", + weight_norm="nai", + depth=None, + noise_csd=noise_csd, + ) + w_nai = filters_nai["weights"][0] filters_ung = make_dics( - epochs.info, fwd_surf, csd, label=label, pick_ori='max-power', - weight_norm='unit-noise-gain', depth=None, noise_csd=noise_csd) - w_ung = filters_ung['weights'][0] - assert_allclose(np.corrcoef(np.abs(w_nai).ravel(), - np.abs(w_ung).ravel()), 1, atol=1e-7) + epochs.info, + fwd_surf, + csd, + label=label, + pick_ori="max-power", + weight_norm="unit-noise-gain", + depth=None, + noise_csd=noise_csd, + ) + w_ung = filters_ung["weights"][0] + assert_allclose( + np.corrcoef(np.abs(w_nai).ravel(), np.abs(w_ung).ravel()), 1, atol=1e-7 + ) # Test whether spatial filter contains src_type - assert 'src_type' in filters + assert "src_type" in filters fname = tmp_path / "filters-dics.h5" filters.save(fname) filters_read = read_beamformer(fname) assert isinstance(filters, Beamformer) assert isinstance(filters_read, Beamformer) - for key in ['tmin', 'tmax']: # deal with strictness of object_diff - setattr(filters['csd'], key, np.float64(getattr(filters['csd'], key))) - assert object_diff(filters, filters_read) == '' + for key in ["tmin", "tmax"]: # deal with strictness of object_diff + setattr(filters["csd"], key, np.float64(getattr(filters["csd"], key))) + assert object_diff(filters, filters_read) == "" def _fwd_dist(power, fwd, vertices, source_ind, tidx=1): idx = np.argmax(power.data[:, tidx]) - rr_got = fwd['src'][0]['rr'][vertices[idx]] - rr_want = fwd['src'][0]['rr'][vertices[source_ind]] + rr_got = fwd["src"][0]["rr"][vertices[idx]] + rr_want = fwd["src"][0]["rr"][vertices[source_ind]] return np.linalg.norm(rr_got - rr_want) @idx_param -@pytest.mark.parametrize('inversion, weight_norm', [ - ('single', None), - ('matrix', 'unit-noise-gain'), -]) +@pytest.mark.parametrize( + "inversion, weight_norm", + [ + ("single", None), + ("matrix", 
"unit-noise-gain"), + ], +) def test_apply_dics_csd(_load_forward, idx, inversion, weight_norm): """Test applying a DICS beamformer to a CSD matrix.""" fwd_free, fwd_surf, fwd_fixed, _ = _load_forward - epochs, _, csd, source_vertno, label, vertices, source_ind = \ - _simulate_data(fwd_fixed, idx) + epochs, _, csd, source_vertno, label, vertices, source_ind = _simulate_data( + fwd_fixed, idx + ) reg = 1 # Lots of regularization for our toy dataset - with pytest.raises(ValueError, match='several sensor types'): + with pytest.raises(ValueError, match="several sensor types"): make_dics(epochs.info, fwd_free, csd) - epochs.pick_types(meg='grad') + epochs.pick_types(meg="grad") # Try different types of forward models - assert label.hemi == 'lh' + assert label.hemi == "lh" for fwd in [fwd_free, fwd_surf, fwd_fixed]: - filters = make_dics(epochs.info, fwd, csd, label=label, reg=reg, - inversion=inversion, weight_norm=weight_norm) + filters = make_dics( + epochs.info, + fwd, + csd, + label=label, + reg=reg, + inversion=inversion, + weight_norm=weight_norm, + ) power, f = apply_dics_csd(csd, filters) assert f == [10, 20] # Did we find the true source at 20 Hz? dist = _fwd_dist(power, fwd_free, vertices, source_ind) - assert dist == 0. + assert dist == 0.0 # Is the signal stronger at 20 Hz than 10? assert power.data[source_ind, 1] > power.data[source_ind, 0] -@pytest.mark.parametrize('pick_ori', [None, 'normal', 'max-power', 'vector']) -@pytest.mark.parametrize('inversion', ['single', 'matrix']) +@pytest.mark.parametrize("pick_ori", [None, "normal", "max-power", "vector"]) +@pytest.mark.parametrize("inversion", ["single", "matrix"]) @idx_param def test_apply_dics_ori_inv(_load_forward, pick_ori, inversion, idx): """Test picking different orientations and inversion modes.""" fwd_free, fwd_surf, fwd_fixed, fwd_vol = _load_forward - epochs, _, csd, source_vertno, label, vertices, source_ind = \ - _simulate_data(fwd_fixed, idx) - epochs.pick_types(meg='grad') - - reg_ = 5 if inversion == 'matrix' else 1 - filters = make_dics(epochs.info, fwd_surf, csd, label=label, - reg=reg_, pick_ori=pick_ori, - inversion=inversion, depth=None, - weight_norm='unit-noise-gain') + epochs, _, csd, source_vertno, label, vertices, source_ind = _simulate_data( + fwd_fixed, idx + ) + epochs.pick_types(meg="grad") + + reg_ = 5 if inversion == "matrix" else 1 + filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + reg=reg_, + pick_ori=pick_ori, + inversion=inversion, + depth=None, + weight_norm="unit-noise-gain", + ) power, f = apply_dics_csd(csd, filters) assert f == [10, 20] dist = _fwd_dist(power, fwd_surf, vertices, source_ind) # This is 0. for unit-noise-gain-invariant: - assert dist <= (0.02 if inversion == 'matrix' else 0.) + assert dist <= (0.02 if inversion == "matrix" else 0.0) assert power.data[source_ind, 1] > power.data[source_ind, 0] # Test unit-noise-gain weighting @@ -400,40 +512,55 @@ def test_apply_dics_ori_inv(_load_forward, pick_ori, inversion, idx): inds = np.triu_indices(csd.n_channels) csd_noise._data[...] = np.eye(csd.n_channels)[inds][:, np.newaxis] noise_power, f = apply_dics_csd(csd_noise, filters) - want_norm = 3 if pick_ori in (None, 'vector') else 1 + want_norm = 3 if pick_ori in (None, "vector") else 1 assert_allclose(noise_power.data, want_norm, atol=1e-7) # Test filter with forward normalization instead of weight # normalization - filters = make_dics(epochs.info, fwd_surf, csd, label=label, - reg=reg_, pick_ori=pick_ori, - inversion=inversion, weight_norm=None, - depth=1.) 
+ filters = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + reg=reg_, + pick_ori=pick_ori, + inversion=inversion, + weight_norm=None, + depth=1.0, + ) power, f = apply_dics_csd(csd, filters) assert f == [10, 20] dist = _fwd_dist(power, fwd_surf, vertices, source_ind) mat_tol = {0: 0.055, 100: 0.20, 200: 0.015, 233: 0.035}[idx] - max_ = (mat_tol if inversion == 'matrix' else 0.) + max_ = mat_tol if inversion == "matrix" else 0.0 assert 0 <= dist <= max_ assert power.data[source_ind, 1] > power.data[source_ind, 0] def _nearest_vol_ind(fwd_vol, fwd, vertices, source_ind): return _compute_nearest( - fwd_vol['source_rr'], - fwd['src'][0]['rr'][vertices][source_ind][np.newaxis])[0] + fwd_vol["source_rr"], fwd["src"][0]["rr"][vertices][source_ind][np.newaxis] + )[0] @idx_param def test_real(_load_forward, idx): """Test using a real-valued filter.""" fwd_free, fwd_surf, fwd_fixed, fwd_vol = _load_forward - epochs, _, csd, source_vertno, label, vertices, source_ind = \ - _simulate_data(fwd_fixed, idx) - epochs.pick_types(meg='grad') + epochs, _, csd, source_vertno, label, vertices, source_ind = _simulate_data( + fwd_fixed, idx + ) + epochs.pick_types(meg="grad") reg = 1 # Lots of regularization for our toy dataset - filters_real = make_dics(epochs.info, fwd_surf, csd, label=label, reg=reg, - real_filter=True, inversion='single') + filters_real = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + reg=reg, + real_filter=True, + inversion="single", + ) # Also test here that no warnings are thrown - implemented to check whether # src should not be None warning occurs: power, f = apply_dics_csd(csd, filters_real) @@ -444,9 +571,16 @@ def test_real(_load_forward, idx): assert power.data[source_ind, 1] > power.data[source_ind, 0] # Test rank reduction - filters_real = make_dics(epochs.info, fwd_surf, csd, label=label, reg=5, - pick_ori='max-power', inversion='matrix', - reduce_rank=True) + filters_real = make_dics( + epochs.info, + fwd_surf, + csd, + label=label, + reg=5, + pick_ori="max-power", + inversion="matrix", + reduce_rank=True, + ) power, f = apply_dics_csd(csd, filters_real) assert f == [10, 20] dist = _fwd_dist(power, fwd_surf, vertices, source_ind) @@ -454,57 +588,58 @@ def test_real(_load_forward, idx): assert power.data[source_ind, 1] > power.data[source_ind, 0] # Test computing source power on a volume source space - filters_vol = make_dics(epochs.info, fwd_vol, csd, reg=reg, - inversion='single') + filters_vol = make_dics(epochs.info, fwd_vol, csd, reg=reg, inversion="single") power, f = apply_dics_csd(csd, filters_vol) vol_source_ind = _nearest_vol_ind(fwd_vol, fwd_surf, vertices, source_ind) assert f == [10, 20] - dist = _fwd_dist( - power, fwd_vol, fwd_vol['src'][0]['vertno'], vol_source_ind) + dist = _fwd_dist(power, fwd_vol, fwd_vol["src"][0]["vertno"], vol_source_ind) vol_tols = {100: 0.008, 200: 0.008} - assert dist <= vol_tols.get(idx, 0.) 
+ assert dist <= vol_tols.get(idx, 0.0) assert power.data[vol_source_ind, 1] > power.data[vol_source_ind, 0] # check whether a filters object without src_type throws expected warning - del filters_vol['src_type'] # emulate 0.16 behaviour to cause warning - with pytest.warns(RuntimeWarning, match='spatial filter does not contain ' - 'src_type'): + del filters_vol["src_type"] # emulate 0.16 behaviour to cause warning + with pytest.warns( + RuntimeWarning, match="spatial filter does not contain " "src_type" + ): apply_dics_csd(csd, filters_vol) -@pytest.mark.filterwarnings("ignore:The use of several sensor types with the" - ":RuntimeWarning") +@pytest.mark.filterwarnings( + "ignore:The use of several sensor types with the" ":RuntimeWarning" +) @idx_param def test_apply_dics_timeseries(_load_forward, idx): """Test DICS applied to timeseries data.""" fwd_free, fwd_surf, fwd_fixed, fwd_vol = _load_forward - epochs, evoked, csd, source_vertno, label, vertices, source_ind = \ - _simulate_data(fwd_fixed, idx) + epochs, evoked, csd, source_vertno, label, vertices, source_ind = _simulate_data( + fwd_fixed, idx + ) reg = 5 # Lots of regularization for our toy dataset - with pytest.raises(ValueError, match='several sensor types'): + with pytest.raises(ValueError, match="several sensor types"): make_dics(evoked.info, fwd_surf, csd) - evoked.pick_types(meg='grad') + evoked.pick_types(meg="grad") - multiple_filters = make_dics(evoked.info, fwd_surf, csd, label=label, - reg=reg) + multiple_filters = make_dics(evoked.info, fwd_surf, csd, label=label, reg=reg) # Sanity checks on the resulting STC after applying DICS on evoked stcs = apply_dics(evoked, multiple_filters) assert isinstance(stcs, list) - assert len(stcs) == len(multiple_filters['weights']) - assert_array_equal(stcs[0].vertices[0], multiple_filters['vertices'][0]) - assert_array_equal(stcs[0].vertices[1], multiple_filters['vertices'][1]) + assert len(stcs) == len(multiple_filters["weights"]) + assert_array_equal(stcs[0].vertices[0], multiple_filters["vertices"][0]) + assert_array_equal(stcs[0].vertices[1], multiple_filters["vertices"][1]) assert_allclose(stcs[0].times, evoked.times) # Applying filters for multiple frequencies on epoch data should fail - with pytest.raises(ValueError, match='computed for a single frequency'): + with pytest.raises(ValueError, match="computed for a single frequency"): apply_dics_epochs(epochs, multiple_filters) # From now on, only apply filters with a single frequency (20 Hz). csd20 = csd.pick_frequency(20) - filters = make_dics(evoked.info, fwd_surf, csd20, label=label, reg=reg, - inversion='single') + filters = make_dics( + evoked.info, fwd_surf, csd20, label=label, reg=reg, inversion="single" + ) # Sanity checks on the resulting STC after applying DICS on epochs. # Also test here that no warnings are thrown - implemented to check whether @@ -513,8 +648,8 @@ def test_apply_dics_timeseries(_load_forward, idx): assert isinstance(stcs, list) assert len(stcs) == 1 - assert_array_equal(stcs[0].vertices[0], filters['vertices'][0]) - assert_array_equal(stcs[0].vertices[1], filters['vertices'][1]) + assert_array_equal(stcs[0].vertices[0], filters["vertices"][0]) + assert_array_equal(stcs[0].vertices[1], filters["vertices"][1]) assert_allclose(stcs[0].times, epochs.times) # Did we find the source? 
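# --- editor's note: illustrative sketch, not part of the patch ---
# _nearest_vol_ind above maps a surface source onto the closest volume
# grid point via mne's private _compute_nearest; a dependency-light
# equivalent of that nearest-neighbor lookup (the positions in meters
# below are made up for illustration):
import numpy as np
from scipy.spatial.distance import cdist

grid = np.array([[0.0, 0.0, 0.0], [0.05, 0.0, 0.0]])  # volume grid points
point = np.array([[0.04, 0.0, 0.0]])                  # surface source
nearest = int(np.argmin(cdist(grid, point)))
assert nearest == 1  # the grid point at 5 cm is closest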
@@ -524,14 +659,14 @@ def test_apply_dics_timeseries(_load_forward, idx): # Apply filters to evoked stc = apply_dics(evoked, filters) - stc = (stc ** 2).mean() + stc = (stc**2).mean() dist = _fwd_dist(stc, fwd_surf, vertices, source_ind, tidx=0) assert dist == 0 # Test if wrong channel selection is detected in application of filter evoked_ch = cp.deepcopy(evoked) evoked_ch.pick_channels(evoked_ch.ch_names[:-1]) - with pytest.raises(ValueError, match='MEG 2633 which is not present'): + with pytest.raises(ValueError, match="MEG 2633 which is not present"): apply_dics(evoked_ch, filters) # Test whether projections are applied, by adding a custom projection @@ -542,13 +677,13 @@ def test_apply_dics_timeseries(_load_forward, idx): proj_matrix = make_projector(p, evoked_proj.ch_names)[0] evoked_proj.add_proj(p) filters_proj = make_dics(evoked_proj.info, fwd_surf, csd20, label=label) - assert_array_equal(filters_proj['proj'], proj_matrix) + assert_array_equal(filters_proj["proj"], proj_matrix) stc_proj = apply_dics(evoked_proj, filters_proj) assert np.any(np.not_equal(stc_noproj.data, stc_proj.data)) # Test detecting incompatible projections - filters_proj['proj'] = filters_proj['proj'][:-1, :-1] - with pytest.raises(ValueError, match='operands could not be broadcast'): + filters_proj["proj"] = filters_proj["proj"][:-1, :-1] + with pytest.raises(ValueError, match="operands could not be broadcast"): apply_dics(evoked_proj, filters_proj) # Test returning a generator @@ -557,30 +692,28 @@ def test_apply_dics_timeseries(_load_forward, idx): assert_array_equal(stcs[0].data, next(stcs_gen).data) # Test computing timecourses on a volume source space - filters_vol = make_dics(evoked.info, fwd_vol, csd20, reg=reg, - inversion='single') + filters_vol = make_dics(evoked.info, fwd_vol, csd20, reg=reg, inversion="single") stc = apply_dics(evoked, filters_vol) - stc = (stc ** 2).mean() + stc = (stc**2).mean() assert stc.data.shape[1] == 1 vol_source_ind = _nearest_vol_ind(fwd_vol, fwd_surf, vertices, source_ind) - dist = _fwd_dist(stc, fwd_vol, fwd_vol['src'][0]['vertno'], vol_source_ind, - tidx=0) + dist = _fwd_dist(stc, fwd_vol, fwd_vol["src"][0]["vertno"], vol_source_ind, tidx=0) vol_tols = {100: 0.008, 200: 0.015} - vol_tol = vol_tols.get(idx, 0.) + vol_tol = vol_tols.get(idx, 0.0) assert dist <= vol_tol # check whether a filters object without src_type throws expected warning - del filters_vol['src_type'] # emulate 0.16 behaviour to cause warning - with pytest.warns(RuntimeWarning, match='filter does not contain src_typ'): + del filters_vol["src_type"] # emulate 0.16 behaviour to cause warning + with pytest.warns(RuntimeWarning, match="filter does not contain src_typ"): apply_dics_epochs(epochs, filters_vol) @testing.requires_testing_data -@pytest.mark.parametrize('return_generator', (True, False)) +@pytest.mark.parametrize("return_generator", (True, False)) def test_apply_dics_tfr(return_generator): """Test DICS applied to time-frequency objects.""" info = read_info(fname_raw) - info = pick_info(info, pick_types(info, meg='grad')) + info = pick_info(info, pick_types(info, meg="grad")) forward = mne.read_forward_solution(fname_fwd) rng = np.random.default_rng(11) @@ -589,7 +722,7 @@ def test_apply_dics_tfr(return_generator): n_chans = len(info.ch_names) freqs = [8, 9] n_times = 300 - times = np.arange(n_times) / info['sfreq'] + times = np.arange(n_times) / info["sfreq"] data = rng.random((n_epochs, n_chans, len(freqs), n_times)) data *= 1e-6 data = data + data * 1j # add imag. 
component to simulate phase @@ -606,18 +739,23 @@ def test_apply_dics_tfr(return_generator): assert_allclose(stcs[0][0].times, times) assert len(stcs) == len(epochs_tfr) # check same number of epochs assert all([len(s) == len(freqs) for s in stcs]) # check nested freqs - assert all([s.data.shape == (forward['nsource'], n_times) - for these_stcs in stcs for s in these_stcs]) + assert all( + [ + s.data.shape == (forward["nsource"], n_times) + for these_stcs in stcs + for s in these_stcs + ] + ) # Compute power from the source space TFR. This should yield the same # result as the apply_dics_csd function. - source_power = np.zeros((forward['nsource'], len(freqs))) + source_power = np.zeros((forward["nsource"], len(freqs))) for stcs_epoch in stcs: for i, stc_freq in enumerate(stcs_epoch): power = (stc_freq.data * np.conj(stc_freq.data)).real power = power.mean(axis=-1) # mean over time # Scaling by sampling frequency for compatibility with Matlab - power /= epochs_tfr.info['sfreq'] + power /= epochs_tfr.info["sfreq"] source_power[:, i] += power.T source_power /= n_epochs @@ -628,86 +766,111 @@ def test_apply_dics_tfr(return_generator): # Test that real-value only data fails, due to non-linearity of computing # power, it is recommended to transform to source-space first before # converting to power. - with pytest.raises(RuntimeError, - match='Time-frequency data must be complex'): + with pytest.raises(RuntimeError, match="Time-frequency data must be complex"): epochs_tfr_real = epochs_tfr.copy() epochs_tfr_real.data = epochs_tfr_real.data.real stcs = apply_dics_tfr_epochs(epochs_tfr_real, filters) filters_vector = filters.copy() - filters_vector['pick_ori'] = 'vector' - with pytest.warns(match='vector solution'): + filters_vector["pick_ori"] = "vector" + with pytest.warns(match="vector solution"): apply_dics_tfr_epochs(epochs_tfr, filters_vector) def _cov_as_csd(cov, info): rng = np.random.RandomState(0) - assert cov['data'].ndim == 2 - assert len(cov['data']) == len(cov['names']) + assert cov["data"].ndim == 2 + assert len(cov["data"]) == len(cov["names"]) # we need to make this have at least some complex structure - data = cov['data'] + 1e-1 * _rand_csd(rng, info) + data = cov["data"] + 1e-1 * _rand_csd(rng, info) assert data.dtype == np.complex128 - return CrossSpectralDensity(_sym_mat_to_vector(data), cov['names'], 0., 16) + return CrossSpectralDensity(_sym_mat_to_vector(data), cov["names"], 0.0, 16) # Just test free ori here (assume fixed is same as LCMV if these are) # Changes here should be synced with test_lcmv.py @pytest.mark.slowtest @pytest.mark.parametrize( - 'reg, pick_ori, weight_norm, use_cov, depth, lower, upper, real_filter', [ - (0.05, 'vector', 'unit-noise-gain-invariant', - False, None, 26, 28, True), - (0.05, 'vector', 'unit-noise-gain', False, None, 13, 15, True), - (0.05, 'vector', 'nai', False, None, 13, 15, True), - (0.05, None, 'unit-noise-gain-invariant', False, None, 26, 28, False), - (0.05, None, 'unit-noise-gain-invariant', True, None, 40, 42, False), - (0.05, None, 'unit-noise-gain-invariant', True, None, 40, 42, True), - (0.05, None, 'unit-noise-gain', False, None, 13, 14, False), - (0.05, None, 'unit-noise-gain', True, None, 35, 37, False), - (0.05, None, 'nai', True, None, 35, 37, False), + "reg, pick_ori, weight_norm, use_cov, depth, lower, upper, real_filter", + [ + (0.05, "vector", "unit-noise-gain-invariant", False, None, 26, 28, True), + (0.05, "vector", "unit-noise-gain", False, None, 13, 15, True), + (0.05, "vector", "nai", False, None, 13, 15, True), + 
(0.05, None, "unit-noise-gain-invariant", False, None, 26, 28, False), + (0.05, None, "unit-noise-gain-invariant", True, None, 40, 42, False), + (0.05, None, "unit-noise-gain-invariant", True, None, 40, 42, True), + (0.05, None, "unit-noise-gain", False, None, 13, 14, False), + (0.05, None, "unit-noise-gain", True, None, 35, 37, False), + (0.05, None, "nai", True, None, 35, 37, False), (0.05, None, None, True, None, 12, 14, False), (0.05, None, None, True, 0.8, 39, 43, False), - (0.05, 'max-power', 'unit-noise-gain-invariant', False, None, 17, 20, - False), - (0.05, 'max-power', 'unit-noise-gain', False, None, 17, 20, False), - (0.05, 'max-power', 'unit-noise-gain', False, None, 17, 20, True), - (0.05, 'max-power', 'nai', True, None, 21, 24, False), - (0.05, 'max-power', None, True, None, 7, 10, False), - (0.05, 'max-power', None, True, 0.8, 15, 18, False), + (0.05, "max-power", "unit-noise-gain-invariant", False, None, 17, 20, False), + (0.05, "max-power", "unit-noise-gain", False, None, 17, 20, False), + (0.05, "max-power", "unit-noise-gain", False, None, 17, 20, True), + (0.05, "max-power", "nai", True, None, 21, 24, False), + (0.05, "max-power", None, True, None, 7, 10, False), + (0.05, "max-power", None, True, 0.8, 15, 18, False), # skip most no-reg tests, assume others are equal to LCMV if these are (0.00, None, None, True, None, 21, 32, False), - (0.00, 'max-power', None, True, None, 13, 19, False), - ]) -def test_localization_bias_free(bias_params_free, reg, pick_ori, weight_norm, - use_cov, depth, lower, upper, real_filter): + (0.00, "max-power", None, True, None, 13, 19, False), + ], +) +def test_localization_bias_free( + bias_params_free, + reg, + pick_ori, + weight_norm, + use_cov, + depth, + lower, + upper, + real_filter, +): """Test localization bias for free-orientation DICS.""" evoked, fwd, noise_cov, data_cov, want = bias_params_free noise_csd = _cov_as_csd(noise_cov, evoked.info) data_csd = _cov_as_csd(data_cov, evoked.info) del noise_cov, data_cov if not use_cov: - evoked.pick_types(meg='grad') + evoked.pick_types(meg="grad") noise_csd = None filters = make_dics( - evoked.info, fwd, data_csd, reg, noise_csd, pick_ori=pick_ori, - weight_norm=weight_norm, depth=depth, real_filter=real_filter) + evoked.info, + fwd, + data_csd, + reg, + noise_csd, + pick_ori=pick_ori, + weight_norm=weight_norm, + depth=depth, + real_filter=real_filter, + ) loc = apply_dics(evoked, filters).data - loc = np.linalg.norm(loc, axis=1) if pick_ori == 'vector' else np.abs(loc) + loc = np.linalg.norm(loc, axis=1) if pick_ori == "vector" else np.abs(loc) # Compute the percentage of sources for which there is no loc bias: perc = (want == np.argmax(loc, axis=0)).mean() * 100 assert lower <= perc <= upper @pytest.mark.parametrize( - 'weight_norm, lower, upper, lower_ori, upper_ori, real_filter', [ - ('unit-noise-gain-invariant', 57, 58, 0.60, 0.61, False), - ('unit-noise-gain', 57, 58, 0.60, 0.61, False), - ('unit-noise-gain', 57, 58, 0.60, 0.61, True), + "weight_norm, lower, upper, lower_ori, upper_ori, real_filter", + [ + ("unit-noise-gain-invariant", 57, 58, 0.60, 0.61, False), + ("unit-noise-gain", 57, 58, 0.60, 0.61, False), + ("unit-noise-gain", 57, 58, 0.60, 0.61, True), (None, 27, 28, 0.56, 0.57, False), - ]) -def test_orientation_max_power(bias_params_fixed, bias_params_free, - weight_norm, lower, upper, lower_ori, upper_ori, - real_filter): + ], +) +def test_orientation_max_power( + bias_params_fixed, + bias_params_free, + weight_norm, + lower, + upper, + lower_ori, + upper_ori, + real_filter, 
+): """Test orientation selection for bias for max-power DICS.""" # we simulate data for the fixed orientation forward and beamform using # the free orientation forward, and check the orientation match at the end @@ -716,11 +879,19 @@ def test_orientation_max_power(bias_params_fixed, bias_params_free, data_csd = _cov_as_csd(data_cov, evoked.info) del data_cov, noise_cov fwd = bias_params_free[1] - filters = make_dics(evoked.info, fwd, data_csd, 0.05, noise_csd, - pick_ori='max-power', weight_norm=weight_norm, - depth=None, real_filter=real_filter) + filters = make_dics( + evoked.info, + fwd, + data_csd, + 0.05, + noise_csd, + pick_ori="max-power", + weight_norm=weight_norm, + depth=None, + real_filter=real_filter, + ) loc = np.abs(apply_dics(evoked, filters).data) - ori = filters['max_power_ori'][0] + ori = filters["max_power_ori"][0] assert ori.shape == (246, 3) loc = np.abs(loc) # Compute the percentage of sources for which there is no loc bias: @@ -730,11 +901,10 @@ def test_orientation_max_power(bias_params_fixed, bias_params_free, assert lower <= perc <= upper # Compute the dot products of our forward normals and # assert we get some hopefully reasonable agreement - assert fwd['coord_frame'] == FIFF.FIFFV_COORD_HEAD - nn = np.concatenate( - [s['nn'][v] for s, v in zip(fwd['src'], filters['vertices'])]) + assert fwd["coord_frame"] == FIFF.FIFFV_COORD_HEAD + nn = np.concatenate([s["nn"][v] for s, v in zip(fwd["src"], filters["vertices"])]) nn = nn[want] - nn = apply_trans(invert_transform(fwd['mri_head_t']), nn, move=False) + nn = apply_trans(invert_transform(fwd["mri_head_t"]), nn, move=False) assert_allclose(np.linalg.norm(nn, axis=1), 1, atol=1e-6) assert_allclose(np.linalg.norm(ori, axis=1), 1, atol=1e-12) dots = np.abs((nn[mask] * ori[mask]).sum(-1)) @@ -746,40 +916,46 @@ def test_orientation_max_power(bias_params_fixed, bias_params_free, @testing.requires_testing_data @idx_param -@pytest.mark.parametrize('whiten', (False, True)) +@pytest.mark.parametrize("whiten", (False, True)) def test_make_dics_rank(_load_forward, idx, whiten): """Test making DICS beamformer filters with rank param.""" _, fwd_surf, fwd_fixed, _ = _load_forward epochs, _, csd, _, label, _, _ = _simulate_data(fwd_fixed, idx) if whiten: noise_csd, want_rank = _make_rand_csd(epochs.info, csd) - kind = 'mag + grad' + kind = "mag + grad" else: noise_csd = None - epochs.pick_types(meg='grad') + epochs.pick_types(meg="grad") want_rank = len(epochs.ch_names) assert want_rank == 41 - kind = 'grad' + kind = "grad" with catch_logging() as log: filters = make_dics( - epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd, - verbose=True) + epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd, verbose=True + ) log = log.getvalue() - assert f'Estimated rank ({kind}): {want_rank}' in log, log + assert f"Estimated rank ({kind}): {want_rank}" in log, log stc, _ = apply_dics_csd(csd, filters) other_rank = want_rank - 1 # shouldn't make a huge difference use_rank = dict(meg=other_rank) if not whiten: # XXX it's a bug that our rank functions don't treat "meg" # properly here... 
- use_rank['grad'] = use_rank.pop('meg') + use_rank["grad"] = use_rank.pop("meg") with catch_logging() as log: filters_2 = make_dics( - epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd, - rank=use_rank, verbose=True) + epochs.info, + fwd_surf, + csd, + label=label, + noise_csd=noise_csd, + rank=use_rank, + verbose=True, + ) log = log.getvalue() - assert f'Computing rank from covariance with rank={use_rank}' in log, log + assert f"Computing rank from covariance with rank={use_rank}" in log, log stc_2, _ = apply_dics_csd(csd, filters_2) corr = np.corrcoef(stc_2.data.ravel(), stc.data.ravel())[0, 1] assert 0.8 < corr < 0.999999 @@ -787,10 +963,15 @@ def test_make_dics_rank(_load_forward, idx, whiten): # degenerate conditions if whiten: # make rank deficient - data = noise_csd.get_data(0.) + data = noise_csd.get_data(0.0) data[0] = data[:0] = 0 noise_csd._data[:, 0] = _sym_mat_to_vector(data) - with pytest.raises(ValueError, match='meg data rank.*the noise rank'): + with pytest.raises(ValueError, match="meg data rank.*the noise rank"): filters = make_dics( - epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd, - verbose=True) + epochs.info, + fwd_surf, + csd, + label=label, + noise_csd=noise_csd, + verbose=True, + ) diff --git a/mne/beamformer/tests/test_external.py b/mne/beamformer/tests/test_external.py index a20cb3b3e79..6195f572bae 100644 --- a/mne/beamformer/tests/test_external.py +++ b/mne/beamformer/tests/test_external.py @@ -17,19 +17,15 @@ ft_data_path = data_path / "fieldtrip" / "beamformer" fname_raw = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis_trunc-cov.fif" -fname_fwd = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) -fname_fwd_vol = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif" -) +fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" +fname_fwd_vol = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif" fname_event = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw-eve.fif" fname_label = data_path / "MEG" / "sample" / "labels" / "Aud-lh.label" reject = dict(grad=4000e-13, mag=4e-12) -@pytest.fixture(scope='function', params=[testing._pytest_param()]) +@pytest.fixture(scope="function", params=[testing._pytest_param()]) def _get_bf_data(save_fieldtrip=False): raw, epochs, evoked, data_cov, _, _, _, _, _, fwd = _get_data(proj=False) @@ -38,28 +34,29 @@ def _get_bf_data(save_fieldtrip=False): raw.save(ft_data_path / "raw.fif", overwrite=True) # src (tris are not available in fwd['src'] once imported into MATLAB) - src = fwd['src'].copy() + src = fwd["src"].copy() mne.write_source_spaces( - ft_data_path / "src.fif", src, verbose='error', overwrite=True + ft_data_path / "src.fif", src, verbose="error", overwrite=True ) # pick gradiometers only: - epochs.pick_types(meg='grad') - evoked.pick_types(meg='grad') + epochs.pick_types(meg="grad") + evoked.pick_types(meg="grad") # compute covariance matrix (ignore false alarm about no baseline) - data_cov = mne.compute_covariance(epochs, tmin=0.04, tmax=0.145, - method='empirical', verbose='error') + data_cov = mne.compute_covariance( + epochs, tmin=0.04, tmax=0.145, method="empirical", verbose="error" + ) if save_fieldtrip is True: # if the covariance matrix and epochs need resaving: # data covariance: cov_savepath = ft_data_path / "sample_cov.mat" - sample_cov = {'sample_cov': data_cov['data']} + sample_cov = {"sample_cov": 
data_cov["data"]} savemat(cov_savepath, sample_cov) # evoked data: ev_savepath = ft_data_path / "sample_evoked.mat" - data_ev = {'sample_evoked': evoked.data} + data_ev = {"sample_evoked": evoked.data} savemat(ev_savepath, data_ev) return evoked, data_cov, fwd @@ -67,23 +64,33 @@ def _get_bf_data(save_fieldtrip=False): # beamformer types to be tested: unit-gain (vector and scalar) and # unit-noise-gain (time series and power output [apply_lcmv_cov]) -@requires_version('pymatreader') -@pytest.mark.parametrize('bf_type, weight_norm, pick_ori, pwr', [ - ['ug_scal', None, 'max-power', False], - ['ung', 'unit-noise-gain', 'max-power', False], - ['ung_pow', 'unit-noise-gain', 'max-power', True], - ['ug_vec', None, 'vector', False], - ['ung_vec', 'unit-noise-gain', 'vector', False], -]) +@requires_version("pymatreader") +@pytest.mark.parametrize( + "bf_type, weight_norm, pick_ori, pwr", + [ + ["ug_scal", None, "max-power", False], + ["ung", "unit-noise-gain", "max-power", False], + ["ung_pow", "unit-noise-gain", "max-power", True], + ["ug_vec", None, "vector", False], + ["ung_vec", "unit-noise-gain", "vector", False], + ], +) def test_lcmv_fieldtrip(_get_bf_data, bf_type, weight_norm, pick_ori, pwr): """Test LCMV vs fieldtrip output.""" from pymatreader import read_mat + evoked, data_cov, fwd = _get_bf_data # run the MNE-Python beamformer - filters = make_lcmv(evoked.info, fwd, data_cov=data_cov, - noise_cov=None, pick_ori=pick_ori, reg=0.05, - weight_norm=weight_norm) + filters = make_lcmv( + evoked.info, + fwd, + data_cov=data_cov, + noise_cov=None, + pick_ori=pick_ori, + reg=0.05, + weight_norm=weight_norm, + ) if pwr: stc_mne = apply_lcmv_cov(data_cov, filters) else: @@ -91,18 +98,21 @@ def test_lcmv_fieldtrip(_get_bf_data, bf_type, weight_norm, pick_ori, pwr): # load the FieldTrip output ft_fname = ft_data_path / ("ft_source_" + bf_type + "-vol.mat") - stc_ft_data = read_mat(ft_fname)['stc'] + stc_ft_data = read_mat(ft_fname)["stc"] if stc_ft_data.ndim == 1: stc_ft_data.shape = (stc_ft_data.size, 1) if stc_mne.data.ndim == 2: signs = np.sign((stc_mne.data * stc_ft_data).sum(-1, keepdims=True)) if pwr: - assert_array_equal(signs, 1.) 
+ assert_array_equal(signs, 1.0) stc_mne.data *= signs assert stc_ft_data.shape == stc_mne.data.shape - if pick_ori == 'vector': + if pick_ori == "vector": # compare norms first - assert_allclose(np.linalg.norm(stc_mne.data, axis=1), - np.linalg.norm(stc_ft_data, axis=1), rtol=1e-6) + assert_allclose( + np.linalg.norm(stc_mne.data, axis=1), + np.linalg.norm(stc_ft_data, axis=1), + rtol=1e-6, + ) assert_allclose(stc_mne.data, stc_ft_data, rtol=1e-6) diff --git a/mne/beamformer/tests/test_lcmv.py b/mne/beamformer/tests/test_lcmv.py index 7f8e654c9bf..ae7a64f844e 100644 --- a/mne/beamformer/tests/test_lcmv.py +++ b/mne/beamformer/tests/test_lcmv.py @@ -5,17 +5,35 @@ import numpy as np from scipy import linalg from scipy.spatial.distance import cdist -from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_allclose, assert_array_less) +from numpy.testing import ( + assert_array_almost_equal, + assert_array_equal, + assert_allclose, + assert_array_less, +) import mne from mne.transforms import apply_trans, invert_transform -from mne import (convert_forward_solution, read_forward_solution, compute_rank, - VolVectorSourceEstimate, VolSourceEstimate, EvokedArray, - pick_channels_cov, read_vectorview_selection) -from mne.beamformer import (make_lcmv, apply_lcmv, apply_lcmv_epochs, - apply_lcmv_raw, Beamformer, - read_beamformer, apply_lcmv_cov, make_dics) +from mne import ( + convert_forward_solution, + read_forward_solution, + compute_rank, + VolVectorSourceEstimate, + VolSourceEstimate, + EvokedArray, + pick_channels_cov, + read_vectorview_selection, +) +from mne.beamformer import ( + make_lcmv, + apply_lcmv, + apply_lcmv_epochs, + apply_lcmv_raw, + Beamformer, + read_beamformer, + apply_lcmv_cov, + make_dics, +) from mne.beamformer._compute_beamformer import _prepare_beamformer_input from mne.datasets import testing from mne.io.compensator import set_current_comp @@ -23,19 +41,14 @@ from mne.minimum_norm import make_inverse_operator, apply_inverse from mne.minimum_norm.tests.test_inverse import _assert_free_ori_match from mne.simulation import simulate_evoked -from mne.utils import (object_diff, requires_version, catch_logging, - _record_warnings) +from mne.utils import object_diff, requires_version, catch_logging, _record_warnings data_path = testing.data_path(download=False) fname_raw = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis_trunc-cov.fif" -fname_fwd = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) -fname_fwd_vol = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif" -) +fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" +fname_fwd_vol = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif" fname_event = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw-eve.fif" fname_label = data_path / "MEG" / "sample" / "labels" / "Aud-lh.label" ctf_fname = data_path / "CTF" / "somMDYO-18av.ds" @@ -49,18 +62,25 @@ def _read_forward_solution_meg(*args, **kwargs): return mne.pick_types_forward(fwd, meg=True, eeg=False) -def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True, - epochs_preload=True, data_cov=True, proj=True): +def _get_data( + tmin=-0.1, + tmax=0.15, + all_forward=True, + epochs=True, + epochs_preload=True, + data_cov=True, + proj=True, +): """Read in data used in tests.""" label = mne.read_label(fname_label) events = mne.read_events(fname_event) raw = 
mne.io.read_raw_fif(fname_raw, preload=True) forward = mne.read_forward_solution(fname_fwd) if all_forward: - forward_surf_ori = _read_forward_solution_meg( - fname_fwd, surf_ori=True) + forward_surf_ori = _read_forward_solution_meg(fname_fwd, surf_ori=True) forward_fixed = _read_forward_solution_meg( - fname_fwd, force_fixed=True, surf_ori=True, use_cps=False) + fname_fwd, force_fixed=True, surf_ori=True, use_cps=False + ) forward_vol = _read_forward_solution_meg(fname_fwd_vol) else: forward_surf_ori = None @@ -70,11 +90,10 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True, event_id, tmin, tmax = 1, tmin, tmax # Setup for reading the raw data - raw.info['bads'] = ['MEG 2443', 'EEG 053'] # 2 bad channels + raw.info["bads"] = ["MEG 2443", "EEG 053"] # 2 bad channels # Set up pick list: MEG - bad channels - left_temporal_channels = read_vectorview_selection('Left-temporal') - picks = mne.pick_types(raw.info, meg=True, - selection=left_temporal_channels) + left_temporal_channels = read_vectorview_selection("Left-temporal") + picks = mne.pick_types(raw.info, meg=True, selection=left_temporal_channels) picks = picks[::2] # decimate for speed # add a couple channels we will consider bad bad_picks = [100, 101] @@ -84,7 +103,7 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True, raw.pick_channels([raw.ch_names[ii] for ii in picks], ordered=True) del picks - raw.info['bads'] = bads # add more bads + raw.info["bads"] = bads # add more bads if proj: raw.info.normalize_proj() # avoid projection warnings else: @@ -93,8 +112,16 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True, if epochs: # Read epochs epochs = mne.Epochs( - raw, events, event_id, tmin, tmax, proj=True, - baseline=(None, 0), preload=epochs_preload, reject=reject) + raw, + events, + event_id, + tmin, + tmax, + proj=True, + baseline=(None, 0), + preload=epochs_preload, + reject=reject, + ) if epochs_preload: epochs.resample(200, npad=0) epochs.crop(0, None) @@ -106,17 +133,29 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True, info = raw.info noise_cov = mne.read_cov(fname_cov) - noise_cov['projs'] = [] # avoid warning - noise_cov = mne.cov.regularize(noise_cov, info, mag=0.05, grad=0.05, - eeg=0.1, proj=True, rank=None) + noise_cov["projs"] = [] # avoid warning + noise_cov = mne.cov.regularize( + noise_cov, info, mag=0.05, grad=0.05, eeg=0.1, proj=True, rank=None + ) if data_cov: data_cov = mne.compute_covariance( - epochs, tmin=0.04, tmax=0.145, verbose='error') # baseline warning + epochs, tmin=0.04, tmax=0.145, verbose="error" + ) # baseline warning else: data_cov = None - return raw, epochs, evoked, data_cov, noise_cov, label, forward,\ - forward_surf_ori, forward_fixed, forward_vol + return ( + raw, + epochs, + evoked, + data_cov, + noise_cov, + label, + forward, + forward_surf_ori, + forward_fixed, + forward_vol, + ) @pytest.mark.slowtest @@ -127,40 +166,43 @@ def test_lcmv_vector(): # For speed and for rank-deficiency calculation simplicity, # just use grads - info = mne.pick_info(info, mne.pick_types(info, meg='grad', exclude=())) + info = mne.pick_info(info, mne.pick_types(info, meg="grad", exclude=())) with info._unlock(): info.update(bads=[], projs=[]) forward = mne.read_forward_solution(fname_fwd) - forward = mne.pick_channels_forward(forward, info['ch_names']) - vertices = [s['vertno'][::200] for s in forward['src']] + forward = mne.pick_channels_forward(forward, info["ch_names"]) + vertices = [s["vertno"][::200] for s in forward["src"]] n_vertices = 
sum(len(v) for v in vertices) assert n_vertices == 4 amplitude = 100e-9 - stc = mne.SourceEstimate(amplitude * np.eye(n_vertices), vertices, - 0, 1. / info['sfreq']) - forward_sim = mne.convert_forward_solution(forward, force_fixed=True, - use_cps=True, copy=True) + stc = mne.SourceEstimate( + amplitude * np.eye(n_vertices), vertices, 0, 1.0 / info["sfreq"] + ) + forward_sim = mne.convert_forward_solution( + forward, force_fixed=True, use_cps=True, copy=True + ) forward_sim = mne.forward.restrict_forward_to_stc(forward_sim, stc) noise_cov = mne.make_ad_hoc_cov(info) - noise_cov.update(data=np.diag(noise_cov['data']), diag=False) + noise_cov.update(data=np.diag(noise_cov["data"]), diag=False) evoked = simulate_evoked(forward_sim, stc, info, noise_cov, nave=1) - source_nn = forward_sim['source_nn'] - source_rr = forward_sim['source_rr'] + source_nn = forward_sim["source_nn"] + source_rr = forward_sim["source_rr"] # Figure out our indices - mask = np.concatenate([np.in1d(s['vertno'], v) - for s, v in zip(forward['src'], vertices)]) + mask = np.concatenate( + [np.in1d(s["vertno"], v) for s, v in zip(forward["src"], vertices)] + ) mapping = np.where(mask)[0] - assert_array_equal(source_rr, forward['source_rr'][mapping]) + assert_array_equal(source_rr, forward["source_rr"][mapping]) # Don't check NN because we didn't rotate to surf ori del forward_sim # Let's do minimum norm as a sanity check (dipole_fit is slower) - inv = make_inverse_operator(info, forward, noise_cov, loose=1.) - stc_vector_mne = apply_inverse(evoked, inv, pick_ori='vector') + inv = make_inverse_operator(info, forward, noise_cov, loose=1.0) + stc_vector_mne = apply_inverse(evoked, inv, pick_ori="vector") mne_ori = stc_vector_mne.data[mapping, :, np.arange(n_vertices)] mne_ori /= np.linalg.norm(mne_ori, axis=-1)[:, np.newaxis] mne_angles = np.rad2deg(np.arccos(np.sum(mne_ori * source_nn, axis=-1))) @@ -169,28 +211,34 @@ def test_lcmv_vector(): # Now let's do LCMV data_cov = mne.make_ad_hoc_cov(info) # just a stub for later with pytest.raises(ValueError, match="pick_ori"): - make_lcmv(info, forward, data_cov, 0.05, noise_cov, pick_ori='bad') + make_lcmv(info, forward, data_cov, 0.05, noise_cov, pick_ori="bad") lcmv_ori = list() for ti in range(n_vertices): this_evoked = evoked.copy().crop(evoked.times[ti], evoked.times[ti]) - data_cov['diag'] = False - data_cov['data'] = (np.outer(this_evoked.data, this_evoked.data) + - noise_cov['data']) - vals = linalg.svdvals(data_cov['data']) + data_cov["diag"] = False + data_cov["data"] = ( + np.outer(this_evoked.data, this_evoked.data) + noise_cov["data"] + ) + vals = linalg.svdvals(data_cov["data"]) assert vals[0] / vals[-1] < 1e5 # not rank deficient with catch_logging() as log: - filters = make_lcmv(info, forward, data_cov, 0.05, noise_cov, - verbose=True) + filters = make_lcmv(info, forward, data_cov, 0.05, noise_cov, verbose=True) log = log.getvalue() - assert '498 sources' in log + assert "498 sources" in log with catch_logging() as log: - filters_vector = make_lcmv(info, forward, data_cov, 0.05, - noise_cov, pick_ori='vector', - verbose=True) + filters_vector = make_lcmv( + info, + forward, + data_cov, + 0.05, + noise_cov, + pick_ori="vector", + verbose=True, + ) log = log.getvalue() - assert '498 sources' in log + assert "498 sources" in log stc = apply_lcmv(this_evoked, filters) stc_vector = apply_lcmv(this_evoked, filters_vector) assert isinstance(stc, mne.SourceEstimate) @@ -199,7 +247,7 @@ def test_lcmv_vector(): # Check the orientation by pooling across some neighbors, as LCMV 
can # have some "holes" at the points of interest - idx = np.where(cdist(forward['source_rr'], source_rr[[ti]]) < 0.02)[0] + idx = np.where(cdist(forward["source_rr"], source_rr[[ti]]) < 0.02)[0] lcmv_ori.append(np.mean(stc_vector.data[idx, :, 0], axis=0)) lcmv_ori[-1] /= np.linalg.norm(lcmv_ori[-1]) @@ -208,27 +256,39 @@ def test_lcmv_vector(): @pytest.mark.slowtest -@requires_version('h5io') +@requires_version("h5io") @testing.requires_testing_data -@pytest.mark.parametrize('reg, proj, kind', [ - (0.01, True, 'volume'), - (0., False, 'volume'), - (0.01, False, 'surface'), - (0., True, 'surface'), -]) +@pytest.mark.parametrize( + "reg, proj, kind", + [ + (0.01, True, "volume"), + (0.0, False, "volume"), + (0.01, False, "surface"), + (0.0, True, "surface"), + ], +) def test_make_lcmv_bem(tmp_path, reg, proj, kind): """Test LCMV with evoked data and single trials.""" - raw, epochs, evoked, data_cov, noise_cov, label, forward,\ - forward_surf_ori, forward_fixed, forward_vol = _get_data(proj=proj) - - if kind == 'surface': + ( + raw, + epochs, + evoked, + data_cov, + noise_cov, + label, + forward, + forward_surf_ori, + forward_fixed, + forward_vol, + ) = _get_data(proj=proj) + + if kind == "surface": fwd = forward else: fwd = forward_vol - assert kind == 'volume' + assert kind == "volume" - filters = make_lcmv(evoked.info, fwd, data_cov, reg=reg, - noise_cov=noise_cov) + filters = make_lcmv(evoked.info, fwd, data_cov, reg=reg, noise_cov=noise_cov) stc = apply_lcmv(evoked, filters) stc.crop(0.02, None) @@ -240,11 +300,17 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): assert 0.08 < tmax < 0.15, tmax assert 0.9 < np.max(max_stc) < 3.5, np.max(max_stc) - if kind == 'surface': + if kind == "surface": # Test picking normal orientation (surface source space only). 
- filters = make_lcmv(evoked.info, forward_surf_ori, data_cov, - reg=reg, noise_cov=noise_cov, - pick_ori='normal', weight_norm=None) + filters = make_lcmv( + evoked.info, + forward_surf_ori, + data_cov, + reg=reg, + noise_cov=noise_cov, + pick_ori="normal", + weight_norm=None, + ) stc_normal = apply_lcmv(evoked, filters) stc_normal.crop(0.02, None) @@ -264,8 +330,9 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): assert (np.abs(stc_normal.data) <= stc.data).all() # Test picking source orientation maximizing output source power - filters = make_lcmv(evoked.info, fwd, data_cov, reg=reg, - noise_cov=noise_cov, pick_ori='max-power') + filters = make_lcmv( + evoked.info, fwd, data_cov, reg=reg, noise_cov=noise_cov, pick_ori="max-power" + ) stc_max_power = apply_lcmv(evoked, filters) stc_max_power.crop(0.02, None) stc_pow = np.sum(np.abs(stc_max_power.data), axis=1) @@ -275,85 +342,125 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): lower = 0.08 if proj else 0.04 assert lower < tmax < 0.15, tmax - assert 0.8 < np.max(max_stc) < 3., np.max(max_stc) + assert 0.8 < np.max(max_stc) < 3.0, np.max(max_stc) stc_max_power.data[:, :] = np.abs(stc_max_power.data) - if kind == 'surface': + if kind == "surface": # Maximum output source power orientation results should be # similar to free orientation results in areas with channel # coverage label = mne.read_label(fname_label) - mean_stc = stc.extract_label_time_course( - label, fwd['src'], mode='mean') - mean_stc_max_pow = \ - stc_max_power.extract_label_time_course( - label, fwd['src'], mode='mean') + mean_stc = stc.extract_label_time_course(label, fwd["src"], mode="mean") + mean_stc_max_pow = stc_max_power.extract_label_time_course( + label, fwd["src"], mode="mean" + ) assert_array_less(np.abs(mean_stc - mean_stc_max_pow), 1.0) # Test if spatial filter contains src_type - assert filters['src_type'] == kind + assert filters["src_type"] == kind # __repr__ assert len(evoked.ch_names) == 22 - assert len(evoked.info['projs']) == (3 if proj else 0) - assert len(evoked.info['bads']) == 2 + assert len(evoked.info["projs"]) == (3 if proj else 0) + assert len(evoked.info["bads"]) == 2 rank = 17 if proj else 20 - assert 'LCMV' in repr(filters) - assert 'unknown subject' not in repr(filters) + assert "LCMV" in repr(filters) + assert "unknown subject" not in repr(filters) assert f'{fwd["nsource"]} vert' in repr(filters) - assert '20 ch' in repr(filters) - assert 'rank %s' % rank in repr(filters) + assert "20 ch" in repr(filters) + assert "rank %s" % rank in repr(filters) # I/O fname = tmp_path / "filters.h5" - with pytest.warns(RuntimeWarning, match='-lcmv.h5'): + with pytest.warns(RuntimeWarning, match="-lcmv.h5"): filters.save(fname) filters_read = read_beamformer(fname) assert isinstance(filters, Beamformer) assert isinstance(filters_read, Beamformer) # deal with object_diff strictness - filters_read['rank'] = int(filters_read['rank']) - filters['rank'] = int(filters['rank']) - assert object_diff(filters, filters_read) == '' + filters_read["rank"] = int(filters_read["rank"]) + filters["rank"] = int(filters["rank"]) + assert object_diff(filters, filters_read) == "" - if kind != 'surface': + if kind != "surface": return # Test if fixed forward operator is detected when picking normal or # max-power orientation - pytest.raises(ValueError, make_lcmv, evoked.info, forward_fixed, data_cov, - reg=0.01, noise_cov=noise_cov, pick_ori='normal') - pytest.raises(ValueError, make_lcmv, evoked.info, forward_fixed, data_cov, - reg=0.01, noise_cov=noise_cov, 
pick_ori='max-power') + pytest.raises( + ValueError, + make_lcmv, + evoked.info, + forward_fixed, + data_cov, + reg=0.01, + noise_cov=noise_cov, + pick_ori="normal", + ) + pytest.raises( + ValueError, + make_lcmv, + evoked.info, + forward_fixed, + data_cov, + reg=0.01, + noise_cov=noise_cov, + pick_ori="max-power", + ) # Test if non-surface oriented forward operator is detected when picking # normal orientation - pytest.raises(ValueError, make_lcmv, evoked.info, forward, data_cov, - reg=0.01, noise_cov=noise_cov, pick_ori='normal') + pytest.raises( + ValueError, + make_lcmv, + evoked.info, + forward, + data_cov, + reg=0.01, + noise_cov=noise_cov, + pick_ori="normal", + ) # Test if volume forward operator is detected when picking normal # orientation - pytest.raises(ValueError, make_lcmv, evoked.info, forward_vol, data_cov, - reg=0.01, noise_cov=noise_cov, pick_ori='normal') + pytest.raises( + ValueError, + make_lcmv, + evoked.info, + forward_vol, + data_cov, + reg=0.01, + noise_cov=noise_cov, + pick_ori="normal", + ) # Test if missing of noise covariance matrix is detected when more than # one channel type is present in the data - pytest.raises(ValueError, make_lcmv, evoked.info, forward_vol, - data_cov=data_cov, reg=0.01, noise_cov=None, - pick_ori='max-power') + pytest.raises( + ValueError, + make_lcmv, + evoked.info, + forward_vol, + data_cov=data_cov, + reg=0.01, + noise_cov=None, + pick_ori="max-power", + ) # Test if wrong channel selection is detected in application of filter evoked_ch = deepcopy(evoked) evoked_ch.pick_channels(evoked_ch.ch_names[1:]) - filters = make_lcmv(evoked.info, forward_vol, data_cov, reg=0.01, - noise_cov=noise_cov) + filters = make_lcmv( + evoked.info, forward_vol, data_cov, reg=0.01, noise_cov=noise_cov + ) # Test if discrepancies in channel selection of data and fwd model are # handled correctly in apply_lcmv # make filter with data where first channel was removed - filters = make_lcmv(evoked_ch.info, forward_vol, data_cov, reg=0.01, - noise_cov=noise_cov) + filters = make_lcmv( + evoked_ch.info, forward_vol, data_cov, reg=0.01, noise_cov=noise_cov + ) # applying that filter to the full data set should automatically exclude # this channel from the data # also test here that no warnings are thrown - implemented to check whether @@ -368,34 +475,36 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): # Test if non-matching SSP projection is detected in application of filter if proj: raw_proj = raw.copy().del_proj() - with pytest.raises(ValueError, match='do not match the projections'): + with pytest.raises(ValueError, match="do not match the projections"): apply_lcmv_raw(raw_proj, filters) # Test apply_lcmv_raw use_raw = raw.copy().crop(0, 1) stc = apply_lcmv_raw(use_raw, filters) assert_allclose(stc.times, use_raw.times) - assert_array_equal(stc.vertices[0], forward_vol['src'][0]['vertno']) + assert_array_equal(stc.vertices[0], forward_vol["src"][0]["vertno"]) # Test if spatial filter contains src_type - assert 'src_type' in filters + assert "src_type" in filters # check whether a filters object without src_type throws expected warning - del filters['src_type'] # emulate 0.16 behaviour to cause warning - with pytest.warns(RuntimeWarning, match='spatial filter does not contain ' - 'src_type'): + del filters["src_type"] # emulate 0.16 behaviour to cause warning + with pytest.warns( + RuntimeWarning, match="spatial filter does not contain " "src_type" + ): apply_lcmv(evoked, filters) # Now test single trial using fixed orientation forward solution # so we can 
compare it to the evoked solution - filters = make_lcmv(epochs.info, forward_fixed, data_cov, reg=0.01, - noise_cov=noise_cov) + filters = make_lcmv( + epochs.info, forward_fixed, data_cov, reg=0.01, noise_cov=noise_cov + ) stcs = apply_lcmv_epochs(epochs, filters) stcs_ = apply_lcmv_epochs(epochs, filters, return_generator=True) assert_array_equal(stcs[0].data, next(stcs_).data) epochs.drop_bad() - assert (len(epochs.events) == len(stcs)) + assert len(epochs.events) == len(stcs) # average the single trial estimates stc_avg = np.zeros_like(stcs[0].data) @@ -404,15 +513,17 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): stc_avg /= len(stcs) # compare it to the solution using evoked with fixed orientation - filters = make_lcmv(evoked.info, forward_fixed, data_cov, reg=0.01, - noise_cov=noise_cov) + filters = make_lcmv( + evoked.info, forward_fixed, data_cov, reg=0.01, noise_cov=noise_cov + ) stc_fixed = apply_lcmv(evoked, filters) assert_array_almost_equal(stc_avg, stc_fixed.data) # use a label so we have few source vertices and delayed computation is # not used - filters = make_lcmv(epochs.info, forward_fixed, data_cov, reg=0.01, - noise_cov=noise_cov, label=label) + filters = make_lcmv( + epochs.info, forward_fixed, data_cov, reg=0.01, noise_cov=noise_cov, label=label + ) stcs_label = apply_lcmv_epochs(epochs, filters) assert_array_almost_equal(stcs_label[0].data, stcs[0].in_label(label).data) @@ -420,54 +531,78 @@ def test_make_lcmv_bem(tmp_path, reg, proj, kind): # Test condition where the filters weights are zero. There should not be # any divide-by-zero errors zero_cov = data_cov.copy() - zero_cov['data'][:] = 0 - filters = make_lcmv(epochs.info, forward_fixed, zero_cov, reg=0.01, - noise_cov=noise_cov) - assert_array_equal(filters['weights'], 0) + zero_cov["data"][:] = 0 + filters = make_lcmv( + epochs.info, forward_fixed, zero_cov, reg=0.01, noise_cov=noise_cov + ) + assert_array_equal(filters["weights"], 0) # Test condition where one channel type is picked # (avoid "grad data rank (13) did not match the noise rank (None)") data_cov_grad = pick_channels_cov( - data_cov, [ch_name for ch_name in epochs.info['ch_names'] - if ch_name.endswith(('2', '3'))], ordered=False) - assert len(data_cov_grad['names']) > 4 - make_lcmv(epochs.info, forward_fixed, data_cov_grad, reg=0.01, - noise_cov=noise_cov) + data_cov, + [ + ch_name + for ch_name in epochs.info["ch_names"] + if ch_name.endswith(("2", "3")) + ], + ordered=False, + ) + assert len(data_cov_grad["names"]) > 4 + make_lcmv(epochs.info, forward_fixed, data_cov_grad, reg=0.01, noise_cov=noise_cov) @testing.requires_testing_data @pytest.mark.slowtest -@pytest.mark.parametrize('weight_norm, pick_ori', [ - ('unit-noise-gain', 'max-power'), - ('unit-noise-gain', 'vector'), - ('unit-noise-gain', None), - ('nai', 'vector'), - (None, 'max-power'), -]) +@pytest.mark.parametrize( + "weight_norm, pick_ori", + [ + ("unit-noise-gain", "max-power"), + ("unit-noise-gain", "vector"), + ("unit-noise-gain", None), + ("nai", "vector"), + (None, "max-power"), + ], +) def test_make_lcmv_sphere(pick_ori, weight_norm): """Test LCMV with sphere head model.""" # unit-noise gain beamformer and orientation # selection and rank reduction of the leadfield _, _, evoked, data_cov, noise_cov, _, _, _, _, _ = _get_data(proj=True) - assert 'eeg' not in evoked - assert 'meg' in evoked - sphere = mne.make_sphere_model(r0=(0., 0., 0.), head_radius=0.080) + assert "eeg" not in evoked + assert "meg" in evoked + sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.0), 
head_radius=0.080) src = mne.setup_volume_source_space( - pos=25., sphere=sphere, mindist=5.0, exclude=2.0) + pos=25.0, sphere=sphere, mindist=5.0, exclude=2.0 + ) fwd_sphere = mne.make_forward_solution(evoked.info, None, src, sphere) # Test that we get an error if not reducing rank - with pytest.raises(ValueError, match='Singular matrix detected'): - with pytest.warns(RuntimeWarning, match='positive semidefinite'): + with pytest.raises(ValueError, match="Singular matrix detected"): + with pytest.warns(RuntimeWarning, match="positive semidefinite"): make_lcmv( - evoked.info, fwd_sphere, data_cov, reg=0.1, - noise_cov=noise_cov, weight_norm=weight_norm, - pick_ori=pick_ori, reduce_rank=False, rank='full') + evoked.info, + fwd_sphere, + data_cov, + reg=0.1, + noise_cov=noise_cov, + weight_norm=weight_norm, + pick_ori=pick_ori, + reduce_rank=False, + rank="full", + ) # Now let's reduce it - filters = make_lcmv(evoked.info, fwd_sphere, data_cov, reg=0.1, - noise_cov=noise_cov, weight_norm=weight_norm, - pick_ori=pick_ori, reduce_rank=True) + filters = make_lcmv( + evoked.info, + fwd_sphere, + data_cov, + reg=0.1, + noise_cov=noise_cov, + weight_norm=weight_norm, + pick_ori=pick_ori, + reduce_rank=True, + ) stc_sphere = apply_lcmv(evoked, filters) if isinstance(stc_sphere, VolVectorSourceEstimate): stc_sphere = stc_sphere.magnitude() @@ -489,21 +624,36 @@ def test_make_lcmv_sphere(pick_ori, weight_norm): @testing.requires_testing_data -@pytest.mark.parametrize('weight_norm', (None, 'unit-noise-gain')) -@pytest.mark.parametrize('pick_ori', ('max-power', 'normal')) +@pytest.mark.parametrize("weight_norm", (None, "unit-noise-gain")) +@pytest.mark.parametrize("pick_ori", ("max-power", "normal")) def test_lcmv_cov(weight_norm, pick_ori): """Test LCMV source power computation.""" - raw, epochs, evoked, data_cov, noise_cov, label, forward,\ - forward_surf_ori, forward_fixed, forward_vol = _get_data() + ( + raw, + epochs, + evoked, + data_cov, + noise_cov, + label, + forward, + forward_surf_ori, + forward_fixed, + forward_vol, + ) = _get_data() convert_forward_solution(forward, surf_ori=True, copy=False) - filters = make_lcmv(evoked.info, forward, data_cov, noise_cov=noise_cov, - weight_norm=weight_norm, pick_ori=pick_ori) + filters = make_lcmv( + evoked.info, + forward, + data_cov, + noise_cov=noise_cov, + weight_norm=weight_norm, + pick_ori=pick_ori, + ) for cov in (data_cov, noise_cov): this_cov = pick_channels_cov(cov, evoked.ch_names, ordered=False) - this_evoked = evoked.copy().pick_channels( - this_cov['names'], ordered=True) - this_cov['projs'] = this_evoked.info['projs'] - assert this_evoked.ch_names == this_cov['names'] + this_evoked = evoked.copy().pick_channels(this_cov["names"], ordered=True) + this_cov["projs"] = this_evoked.info["projs"] + assert this_evoked.ch_names == this_cov["names"] stc = apply_lcmv_cov(this_cov, filters) assert stc.data.min() > 0 assert stc.shape == (498, 1) @@ -530,29 +680,35 @@ def test_lcmv_ctf_comp(): evoked = epochs.average() data_cov = mne.compute_covariance(epochs) - fwd = mne.make_forward_solution(evoked.info, None, - mne.setup_volume_source_space(pos=30.0), - mne.make_sphere_model()) - with pytest.raises(ValueError, match='reduce_rank'): + fwd = mne.make_forward_solution( + evoked.info, + None, + mne.setup_volume_source_space(pos=30.0), + mne.make_sphere_model(), + ) + with pytest.raises(ValueError, match="reduce_rank"): make_lcmv(evoked.info, fwd, data_cov) filters = make_lcmv(evoked.info, fwd, data_cov, reduce_rank=True) - assert 'weights' in filters + 
assert "weights" in filters # test whether different compensations throw error info_comp = evoked.info.copy() set_current_comp(info_comp, 1) - with pytest.raises(RuntimeError, match='Compensation grade .* not match'): + with pytest.raises(RuntimeError, match="Compensation grade .* not match"): make_lcmv(info_comp, fwd, data_cov) @pytest.mark.slowtest @testing.requires_testing_data -@pytest.mark.parametrize('proj, weight_norm', [ - (True, 'unit-noise-gain'), - (False, 'unit-noise-gain'), - (True, None), - (True, 'nai'), -]) +@pytest.mark.parametrize( + "proj, weight_norm", + [ + (True, "unit-noise-gain"), + (False, "unit-noise-gain"), + (True, None), + (True, "nai"), + ], +) def test_lcmv_reg_proj(proj, weight_norm): """Test LCMV with and without proj.""" raw = mne.io.read_raw_fif(fname_raw, preload=True) @@ -560,58 +716,70 @@ def test_lcmv_reg_proj(proj, weight_norm): raw.pick_types(meg=True) assert len(raw.ch_names) == 305 epochs = mne.Epochs(raw, events, None, preload=True, proj=proj) - with pytest.warns(RuntimeWarning, match='Too few samples'): + with pytest.warns(RuntimeWarning, match="Too few samples"): noise_cov = mne.compute_covariance(epochs, tmax=0) data_cov = mne.compute_covariance(epochs, tmin=0.04, tmax=0.15) forward = mne.read_forward_solution(fname_fwd) - filters = make_lcmv(epochs.info, forward, data_cov, reg=0.05, - noise_cov=noise_cov, pick_ori='max-power', - weight_norm='nai', rank=None, verbose=True) + filters = make_lcmv( + epochs.info, + forward, + data_cov, + reg=0.05, + noise_cov=noise_cov, + pick_ori="max-power", + weight_norm="nai", + rank=None, + verbose=True, + ) want_rank = 302 # 305 good channels - 3 MEG projs - assert filters['rank'] == want_rank + assert filters["rank"] == want_rank # And also with and without noise_cov - with pytest.raises(ValueError, match='several sensor types'): - make_lcmv(epochs.info, forward, data_cov, reg=0.05, - noise_cov=None) - epochs.pick_types(meg='grad') + with pytest.raises(ValueError, match="several sensor types"): + make_lcmv(epochs.info, forward, data_cov, reg=0.05, noise_cov=None) + epochs.pick_types(meg="grad") kwargs = dict(reg=0.05, pick_ori=None, weight_norm=weight_norm) - filters_cov = make_lcmv(epochs.info, forward, data_cov, - noise_cov=noise_cov, **kwargs) - filters_nocov = make_lcmv(epochs.info, forward, data_cov, - noise_cov=None, **kwargs) + filters_cov = make_lcmv( + epochs.info, forward, data_cov, noise_cov=noise_cov, **kwargs + ) + filters_nocov = make_lcmv(epochs.info, forward, data_cov, noise_cov=None, **kwargs) ad_hoc = mne.make_ad_hoc_cov(epochs.info) - filters_adhoc = make_lcmv(epochs.info, forward, data_cov, - noise_cov=ad_hoc, **kwargs) + filters_adhoc = make_lcmv( + epochs.info, forward, data_cov, noise_cov=ad_hoc, **kwargs + ) evoked = epochs.average() stc_cov = apply_lcmv(evoked, filters_cov) stc_nocov = apply_lcmv(evoked, filters_nocov) stc_adhoc = apply_lcmv(evoked, filters_adhoc) # Compare adhoc and nocov: scale difference is necessitated by using std=1. - if weight_norm == 'unit-noise-gain': - scale = np.sqrt(ad_hoc['data'][0]) + if weight_norm == "unit-noise-gain": + scale = np.sqrt(ad_hoc["data"][0]) else: - scale = 1. 
+ scale = 1.0 assert_allclose(stc_nocov.data, stc_adhoc.data * scale) - a = np.dot(filters_nocov['weights'], filters_nocov['whitener']) - b = np.dot(filters_adhoc['weights'], filters_adhoc['whitener']) * scale + a = np.dot(filters_nocov["weights"], filters_nocov["whitener"]) + b = np.dot(filters_adhoc["weights"], filters_adhoc["whitener"]) * scale atol = np.mean(np.sqrt(a * a)) * 1e-7 assert_allclose(a, b, atol=atol, rtol=1e-7) # Compare adhoc and cov: locs might not be equivalent, but the same # general profile should persist, so look at the std and be lenient: - if weight_norm == 'unit-noise-gain': + if weight_norm == "unit-noise-gain": adhoc_scale = 0.12 else: - adhoc_scale = 1. + adhoc_scale = 1.0 assert_allclose( np.linalg.norm(stc_adhoc.data, axis=0) * adhoc_scale, - np.linalg.norm(stc_cov.data, axis=0), rtol=0.3) + np.linalg.norm(stc_cov.data, axis=0), + rtol=0.3, + ) assert_allclose( np.linalg.norm(stc_nocov.data, axis=0) / scale * adhoc_scale, - np.linalg.norm(stc_cov.data, axis=0), rtol=0.3) + np.linalg.norm(stc_cov.data, axis=0), + rtol=0.3, + ) - if weight_norm == 'nai': + if weight_norm == "nai": # NAI is always normalized by noise-level (based on eigenvalues) for stc in (stc_nocov, stc_cov): assert_allclose(stc.data.std(), 0.584, rtol=0.2) @@ -621,34 +789,47 @@ def test_lcmv_reg_proj(proj, weight_norm): for stc in (stc_nocov, stc_cov): assert_allclose(stc.data.std(), 2.8e-8, rtol=0.1) else: - assert weight_norm == 'unit-noise-gain' + assert weight_norm == "unit-noise-gain" # Channel scalings depend on presence of noise_cov assert_allclose(stc_nocov.data.std(), 7.8e-13, rtol=0.1) assert_allclose(stc_cov.data.std(), 0.187, rtol=0.2) -@pytest.mark.parametrize('reg, weight_norm, use_cov, depth, lower, upper', [ - (0.05, 'unit-noise-gain', True, None, 97, 98), - (0.05, 'nai', True, None, 96, 98), - (0.05, 'nai', True, 0.8, 96, 98), - (0.05, None, True, None, 74, 76), - (0.05, None, True, 0.8, 90, 93), # depth improves weight_norm=None - (0.05, 'unit-noise-gain', False, None, 83, 86), - (0.05, 'unit-noise-gain', False, 0.8, 83, 86), # depth same for wn != None - # no reg - (0.00, 'unit-noise-gain', True, None, 35, 99), # TODO: Still not stable -]) -def test_localization_bias_fixed(bias_params_fixed, reg, weight_norm, use_cov, - depth, lower, upper): +@pytest.mark.parametrize( + "reg, weight_norm, use_cov, depth, lower, upper", + [ + (0.05, "unit-noise-gain", True, None, 97, 98), + (0.05, "nai", True, None, 96, 98), + (0.05, "nai", True, 0.8, 96, 98), + (0.05, None, True, None, 74, 76), + (0.05, None, True, 0.8, 90, 93), # depth improves weight_norm=None + (0.05, "unit-noise-gain", False, None, 83, 86), + (0.05, "unit-noise-gain", False, 0.8, 83, 86), # depth same for wn != None + # no reg + (0.00, "unit-noise-gain", True, None, 35, 99), # TODO: Still not stable + ], +) +def test_localization_bias_fixed( + bias_params_fixed, reg, weight_norm, use_cov, depth, lower, upper +): """Test localization bias for fixed-orientation LCMV.""" evoked, fwd, noise_cov, data_cov, want = bias_params_fixed if not use_cov: - evoked.pick_types(meg='grad') + evoked.pick_types(meg="grad") noise_cov = None - assert data_cov['data'].shape[0] == len(data_cov['names']) - loc = apply_lcmv(evoked, make_lcmv(evoked.info, fwd, data_cov, reg, - noise_cov, depth=depth, - weight_norm=weight_norm)).data + assert data_cov["data"].shape[0] == len(data_cov["names"]) + loc = apply_lcmv( + evoked, + make_lcmv( + evoked.info, + fwd, + data_cov, + reg, + noise_cov, + depth=depth, + weight_norm=weight_norm, + ), + ).data 
loc = np.abs(loc) # Compute the percentage of sources for which there is no loc bias: perc = (want == np.argmax(loc, axis=0)).mean() * 100 @@ -657,51 +838,117 @@ def test_localization_bias_fixed(bias_params_fixed, reg, weight_norm, use_cov, # Changes here should be synced with test_dics.py @pytest.mark.parametrize( - 'reg, pick_ori, weight_norm, use_cov, depth, lower, upper, ' - 'lower_ori, upper_ori', [ - (0.05, 'vector', 'unit-noise-gain-invariant', False, None, 26, 28, 0.82, 0.84), # noqa: E501 - (0.05, 'vector', 'unit-noise-gain-invariant', True, None, 40, 42, 0.96, 0.98), # noqa: E501 - (0.05, 'vector', 'unit-noise-gain', False, None, 13, 14, 0.79, 0.81), - (0.05, 'vector', 'unit-noise-gain', True, None, 35, 37, 0.98, 0.99), - (0.05, 'vector', 'nai', True, None, 35, 37, 0.98, 0.99), - (0.05, 'vector', None, True, None, 12, 14, 0.97, 0.98), - (0.05, 'vector', None, True, 0.8, 39, 43, 0.97, 0.98), - (0.05, 'max-power', 'unit-noise-gain-invariant', False, None, 17, 20, 0, 0), # noqa: E501 - (0.05, 'max-power', 'unit-noise-gain', False, None, 17, 20, 0, 0), - (0.05, 'max-power', 'nai', True, None, 21, 24, 0, 0), - (0.05, 'max-power', None, True, None, 7, 10, 0, 0), - (0.05, 'max-power', None, True, 0.8, 15, 18, 0, 0), + "reg, pick_ori, weight_norm, use_cov, depth, lower, upper, " "lower_ori, upper_ori", + [ + ( + 0.05, + "vector", + "unit-noise-gain-invariant", + False, + None, + 26, + 28, + 0.82, + 0.84, + ), # noqa: E501 + ( + 0.05, + "vector", + "unit-noise-gain-invariant", + True, + None, + 40, + 42, + 0.96, + 0.98, + ), # noqa: E501 + (0.05, "vector", "unit-noise-gain", False, None, 13, 14, 0.79, 0.81), + (0.05, "vector", "unit-noise-gain", True, None, 35, 37, 0.98, 0.99), + (0.05, "vector", "nai", True, None, 35, 37, 0.98, 0.99), + (0.05, "vector", None, True, None, 12, 14, 0.97, 0.98), + (0.05, "vector", None, True, 0.8, 39, 43, 0.97, 0.98), + ( + 0.05, + "max-power", + "unit-noise-gain-invariant", + False, + None, + 17, + 20, + 0, + 0, + ), # noqa: E501 + (0.05, "max-power", "unit-noise-gain", False, None, 17, 20, 0, 0), + (0.05, "max-power", "nai", True, None, 21, 24, 0, 0), + (0.05, "max-power", None, True, None, 7, 10, 0, 0), + (0.05, "max-power", None, True, 0.8, 15, 18, 0, 0), (0.05, None, None, True, 0.8, 40, 42, 0, 0), # no reg - (0.00, 'vector', None, True, None, 23, 24, 0.96, 0.97), - (0.00, 'vector', 'unit-noise-gain-invariant', True, None, 52, 54, 0.95, 0.96), # noqa: E501 - (0.00, 'vector', 'unit-noise-gain', True, None, 44, 48, 0.97, 0.99), - (0.00, 'vector', 'nai', True, None, 44, 48, 0.97, 0.99), - (0.00, 'max-power', None, True, None, 14, 15, 0, 0), - (0.00, 'max-power', 'unit-noise-gain-invariant', True, None, 35, 37, 0, 0), # noqa: E501 - (0.00, 'max-power', 'unit-noise-gain', True, None, 35, 37, 0, 0), - (0.00, 'max-power', 'nai', True, None, 35, 37, 0, 0), - ]) -def test_localization_bias_free(bias_params_free, reg, pick_ori, weight_norm, - use_cov, depth, lower, upper, - lower_ori, upper_ori): + (0.00, "vector", None, True, None, 23, 24, 0.96, 0.97), + ( + 0.00, + "vector", + "unit-noise-gain-invariant", + True, + None, + 52, + 54, + 0.95, + 0.96, + ), # noqa: E501 + (0.00, "vector", "unit-noise-gain", True, None, 44, 48, 0.97, 0.99), + (0.00, "vector", "nai", True, None, 44, 48, 0.97, 0.99), + (0.00, "max-power", None, True, None, 14, 15, 0, 0), + ( + 0.00, + "max-power", + "unit-noise-gain-invariant", + True, + None, + 35, + 37, + 0, + 0, + ), # noqa: E501 + (0.00, "max-power", "unit-noise-gain", True, None, 35, 37, 0, 0), + (0.00, "max-power", "nai", True, 
None, 35, 37, 0, 0), + ], +) +def test_localization_bias_free( + bias_params_free, + reg, + pick_ori, + weight_norm, + use_cov, + depth, + lower, + upper, + lower_ori, + upper_ori, +): """Test localization bias for free-orientation LCMV.""" evoked, fwd, noise_cov, data_cov, want = bias_params_free if not use_cov: - evoked.pick_types(meg='grad') + evoked.pick_types(meg="grad") noise_cov = None with _record_warnings(): # rank deficiency of data_cov - filters = make_lcmv(evoked.info, fwd, data_cov, reg, - noise_cov, pick_ori=pick_ori, - weight_norm=weight_norm, - depth=depth) + filters = make_lcmv( + evoked.info, + fwd, + data_cov, + reg, + noise_cov, + pick_ori=pick_ori, + weight_norm=weight_norm, + depth=depth, + ) loc = apply_lcmv(evoked, filters).data - if pick_ori == 'vector': + if pick_ori == "vector": ori = loc.copy() / np.linalg.norm(loc, axis=1, keepdims=True) else: # doesn't make sense for pooled (None) or max-power (can't be all 3) ori = None - loc = np.linalg.norm(loc, axis=1) if pick_ori == 'vector' else np.abs(loc) + loc = np.linalg.norm(loc, axis=1) if pick_ori == "vector" else np.abs(loc) # Compute the percentage of sources for which there is no loc bias: max_idx = np.argmax(loc, axis=0) perc = (want == max_idx).mean() * 100 @@ -712,35 +959,52 @@ def test_localization_bias_free(bias_params_free, reg, pick_ori, weight_norm, # Changes here should be synced with the ones above, but these have meaningful # orientation values @pytest.mark.parametrize( - 'reg, weight_norm, use_cov, depth, lower, upper, lower_ori, upper_ori', [ - (0.05, 'unit-noise-gain-invariant', False, None, 38, 40, 0.54, 0.55), - (0.05, 'unit-noise-gain', False, None, 38, 40, 0.54, 0.55), - (0.05, 'nai', True, None, 56, 57, 0.59, 0.61), + "reg, weight_norm, use_cov, depth, lower, upper, lower_ori, upper_ori", + [ + (0.05, "unit-noise-gain-invariant", False, None, 38, 40, 0.54, 0.55), + (0.05, "unit-noise-gain", False, None, 38, 40, 0.54, 0.55), + (0.05, "nai", True, None, 56, 57, 0.59, 0.61), (0.05, None, True, None, 27, 28, 0.56, 0.57), (0.05, None, True, 0.8, 42, 43, 0.56, 0.57), # no reg (0.00, None, True, None, 50, 51, 0.58, 0.59), - (0.00, 'unit-noise-gain-invariant', True, None, 73, 75, 0.59, 0.61), - (0.00, 'unit-noise-gain', True, None, 73, 75, 0.59, 0.61), - (0.00, 'nai', True, None, 73, 75, 0.59, 0.61), - ]) -def test_orientation_max_power(bias_params_fixed, bias_params_free, - reg, weight_norm, use_cov, depth, lower, upper, - lower_ori, upper_ori): + (0.00, "unit-noise-gain-invariant", True, None, 73, 75, 0.59, 0.61), + (0.00, "unit-noise-gain", True, None, 73, 75, 0.59, 0.61), + (0.00, "nai", True, None, 73, 75, 0.59, 0.61), + ], +) +def test_orientation_max_power( + bias_params_fixed, + bias_params_free, + reg, + weight_norm, + use_cov, + depth, + lower, + upper, + lower_ori, + upper_ori, +): """Test orientation selection for bias for max-power LCMV.""" # we simulate data for the fixed orientation forward and beamform using # the free orientation forward, and check the orientation match at the end evoked, _, noise_cov, data_cov, want = bias_params_fixed fwd = bias_params_free[1] if not use_cov: - evoked.pick_types(meg='grad') + evoked.pick_types(meg="grad") noise_cov = None - filters = make_lcmv(evoked.info, fwd, data_cov, reg, - noise_cov, pick_ori='max-power', - weight_norm=weight_norm, - depth=depth) + filters = make_lcmv( + evoked.info, + fwd, + data_cov, + reg, + noise_cov, + pick_ori="max-power", + weight_norm=weight_norm, + depth=depth, + ) loc = apply_lcmv(evoked, filters).data - ori = 
filters['max_power_ori'] + ori = filters["max_power_ori"] assert ori.shape == (246, 3) loc = np.abs(loc) # Compute the percentage of sources for which there is no loc bias: @@ -749,11 +1013,10 @@ def test_orientation_max_power(bias_params_fixed, bias_params_free, perc = mask.mean() * 100 assert lower <= perc <= upper # Compute the dot products of our forward normals and - assert fwd['coord_frame'] == FIFF.FIFFV_COORD_HEAD - nn = np.concatenate( - [s['nn'][v] for s, v in zip(fwd['src'], filters['vertices'])]) + assert fwd["coord_frame"] == FIFF.FIFFV_COORD_HEAD + nn = np.concatenate([s["nn"][v] for s, v in zip(fwd["src"], filters["vertices"])]) nn = nn[want] - nn = apply_trans(invert_transform(fwd['mri_head_t']), nn, move=False) + nn = apply_trans(invert_transform(fwd["mri_head_t"]), nn, move=False) assert_allclose(np.linalg.norm(nn, axis=1), 1, atol=1e-6) assert_allclose(np.linalg.norm(ori, axis=1), 1, atol=1e-12) dots = np.abs((nn[mask] * ori[mask]).sum(-1)) @@ -763,21 +1026,44 @@ def test_orientation_max_power(bias_params_fixed, bias_params_free, assert lower_ori < got < upper_ori -@pytest.mark.parametrize('weight_norm, pick_ori', [ - pytest.param('nai', 'max-power', marks=pytest.mark.slowtest), - ('unit-noise-gain', 'vector'), - ('unit-noise-gain', 'max-power'), - pytest.param('unit-noise-gain', None, marks=pytest.mark.slowtest), -]) +@pytest.mark.parametrize( + "weight_norm, pick_ori", + [ + pytest.param("nai", "max-power", marks=pytest.mark.slowtest), + ("unit-noise-gain", "vector"), + ("unit-noise-gain", "max-power"), + pytest.param("unit-noise-gain", None, marks=pytest.mark.slowtest), + ], +) def test_depth_does_not_matter(bias_params_free, weight_norm, pick_ori): """Test that depth weighting does not matter for normalized filters.""" evoked, fwd, noise_cov, data_cov, _ = bias_params_free - data = apply_lcmv(evoked, make_lcmv( - evoked.info, fwd, data_cov, 0.05, noise_cov, pick_ori=pick_ori, - weight_norm=weight_norm, depth=0.)).data - data_depth = apply_lcmv(evoked, make_lcmv( - evoked.info, fwd, data_cov, 0.05, noise_cov, pick_ori=pick_ori, - weight_norm=weight_norm, depth=1.)).data + data = apply_lcmv( + evoked, + make_lcmv( + evoked.info, + fwd, + data_cov, + 0.05, + noise_cov, + pick_ori=pick_ori, + weight_norm=weight_norm, + depth=0.0, + ), + ).data + data_depth = apply_lcmv( + evoked, + make_lcmv( + evoked.info, + fwd, + data_cov, + 0.05, + noise_cov, + pick_ori=pick_ori, + weight_norm=weight_norm, + depth=1.0, + ), + ).data assert data.shape == data_depth.shape for d1, d2 in zip(data, data_depth): # Sign flips can change when nearly orthogonal to the normal direction @@ -793,59 +1079,78 @@ def test_lcmv_maxfiltered(): raw_sss = mne.preprocessing.maxwell_filter(raw) events = mne.find_events(raw_sss) del raw - raw_sss.pick_types(meg='mag') + raw_sss.pick_types(meg="mag") assert len(raw_sss.ch_names) == 102 epochs = mne.Epochs(raw_sss, events) data_cov = mne.compute_covariance(epochs, tmin=0) fwd = mne.read_forward_solution(fname_fwd) rank = compute_rank(data_cov, info=epochs.info) - assert rank == {'mag': 71} - for use_rank in ('info', rank, 'full', None): + assert rank == {"mag": 71} + for use_rank in ("info", rank, "full", None): make_lcmv(epochs.info, fwd, data_cov, rank=use_rank) # To reduce test time, only test combinations that should matter rather than # all of them @testing.requires_testing_data -@pytest.mark.parametrize('pick_ori, weight_norm, reg, inversion', [ - ('vector', 'unit-noise-gain-invariant', 0.05, 'matrix'), - ('vector', 'unit-noise-gain-invariant', 0.05, 
'single'), - ('vector', 'unit-noise-gain', 0.05, 'matrix'), - ('vector', 'unit-noise-gain', 0.05, 'single'), - ('vector', 'unit-noise-gain', 0.0, 'matrix'), - ('vector', 'unit-noise-gain', 0.0, 'single'), - ('vector', 'nai', 0.05, 'matrix'), - ('max-power', 'unit-noise-gain', 0.05, 'matrix'), - ('max-power', 'unit-noise-gain', 0.0, 'single'), - ('max-power', 'unit-noise-gain', 0.05, 'single'), - ('max-power', 'unit-noise-gain-invariant', 0.05, 'matrix'), - ('normal', 'unit-noise-gain', 0.05, 'matrix'), - ('normal', 'nai', 0.0, 'matrix'), -]) +@pytest.mark.parametrize( + "pick_ori, weight_norm, reg, inversion", + [ + ("vector", "unit-noise-gain-invariant", 0.05, "matrix"), + ("vector", "unit-noise-gain-invariant", 0.05, "single"), + ("vector", "unit-noise-gain", 0.05, "matrix"), + ("vector", "unit-noise-gain", 0.05, "single"), + ("vector", "unit-noise-gain", 0.0, "matrix"), + ("vector", "unit-noise-gain", 0.0, "single"), + ("vector", "nai", 0.05, "matrix"), + ("max-power", "unit-noise-gain", 0.05, "matrix"), + ("max-power", "unit-noise-gain", 0.0, "single"), + ("max-power", "unit-noise-gain", 0.05, "single"), + ("max-power", "unit-noise-gain-invariant", 0.05, "matrix"), + ("normal", "unit-noise-gain", 0.05, "matrix"), + ("normal", "nai", 0.0, "matrix"), + ], +) def test_unit_noise_gain_formula(pick_ori, weight_norm, reg, inversion): """Test unit-noise-gain filter against formula.""" raw = mne.io.read_raw_fif(fname_raw, preload=True) events = mne.find_events(raw) - raw.pick_types(meg='mag') + raw.pick_types(meg="mag") assert len(raw.ch_names) == 102 epochs = mne.Epochs(raw, events, None, preload=True) data_cov = mne.compute_covariance(epochs, tmin=0.04, tmax=0.15) # for now, avoid whitening to make life easier - noise_cov = mne.make_ad_hoc_cov(epochs.info, std=dict(grad=1., mag=1.)) + noise_cov = mne.make_ad_hoc_cov(epochs.info, std=dict(grad=1.0, mag=1.0)) forward = mne.read_forward_solution(fname_fwd) convert_forward_solution(forward, surf_ori=True, copy=False) rank = None - kwargs = dict(reg=reg, noise_cov=noise_cov, pick_ori=pick_ori, - weight_norm=weight_norm, rank=rank, inversion=inversion) - if inversion == 'single' and pick_ori == 'vector' and \ - weight_norm == 'unit-noise-gain-invariant': - with pytest.raises(ValueError, match='Cannot use'): + kwargs = dict( + reg=reg, + noise_cov=noise_cov, + pick_ori=pick_ori, + weight_norm=weight_norm, + rank=rank, + inversion=inversion, + ) + if ( + inversion == "single" + and pick_ori == "vector" + and weight_norm == "unit-noise-gain-invariant" + ): + with pytest.raises(ValueError, match="Cannot use"): make_lcmv(epochs.info, forward, data_cov, **kwargs) return filters = make_lcmv(epochs.info, forward, data_cov, **kwargs) _, _, _, _, G, _, _, _ = _prepare_beamformer_input( - epochs.info, forward, None, 'vector', noise_cov=noise_cov, rank=rank, - pca=False, exp=None) + epochs.info, + forward, + None, + "vector", + noise_cov=noise_cov, + rank=rank, + pca=False, + exp=None, + ) n_channels, n_sources = G.shape n_sources //= 3 G.shape = (n_channels, n_sources, 3) @@ -855,26 +1160,26 @@ def test_unit_noise_gain_formula(pick_ori, weight_norm, reg, inversion): def _assert_weight_norm(filters, G): """Check the result of the chosen weight normalization strategy.""" - weights, max_power_ori = filters['weights'], filters['max_power_ori'] + weights, max_power_ori = filters["weights"], filters["max_power_ori"] # Make the dimensions of the weight matrix equal for both DICS (which # defines weights for multiple frequencies) and LCMV (which does not). 
- if filters['kind'] == 'LCMV': + if filters["kind"] == "LCMV": weights = weights[np.newaxis] if max_power_ori is not None: max_power_ori = max_power_ori[np.newaxis] if max_power_ori is not None: max_power_ori = max_power_ori[..., np.newaxis] - weight_norm = filters['weight_norm'] - inversion = filters['inversion'] + weight_norm = filters["weight_norm"] + inversion = filters["inversion"] n_channels = weights.shape[2] - if inversion == 'matrix': + if inversion == "matrix": # Dipoles are grouped in groups with size n_orient - n_sources = filters['n_sources'] - n_orient = 3 if filters['is_free_ori'] else 1 - elif inversion == 'single': + n_sources = filters["n_sources"] + n_orient = 3 if filters["is_free_ori"] else 1 + elif inversion == "single": # Every dipole is treated as a unique source n_sources = weights.shape[1] n_orient = 1 @@ -884,13 +1189,13 @@ def _assert_weight_norm(filters, G): # Compute leadfield in the direction chosen during the computation of # the beamformer. - if filters['pick_ori'] == 'max-power': + if filters["pick_ori"] == "max-power": use_G = np.sum(G * max_power_ori[wi], axis=1, keepdims=True) - elif filters['pick_ori'] == 'normal': + elif filters["pick_ori"] == "normal": use_G = G[:, -1:] else: use_G = G - if inversion == 'single': + if inversion == "single": # Every dipole is treated as a unique source use_G = use_G.reshape(n_sources, 1, n_channels) assert w.shape == use_G.shape == (n_sources, n_orient, n_channels) @@ -898,32 +1203,32 @@ def _assert_weight_norm(filters, G): # Test weight normalization scheme got = np.matmul(w, w.conj().swapaxes(-2, -1)) desired = np.repeat(np.eye(n_orient)[np.newaxis], w.shape[0], axis=0) - if n_orient == 3 and weight_norm in ('unit-noise-gain', 'nai'): + if n_orient == 3 and weight_norm in ("unit-noise-gain", "nai"): # only the diagonal is correct! 
assert not np.allclose(got, desired, atol=1e-7) - got = got.reshape(n_sources, -1)[:, ::n_orient + 1] + got = got.reshape(n_sources, -1)[:, :: n_orient + 1] desired = np.ones_like(got) - if weight_norm == 'nai': # additional scale factor, should be fixed + if weight_norm == "nai": # additional scale factor, should be fixed atol = 1e-7 * got.flat[0] desired *= got.flat[0] else: atol = 1e-7 - assert_allclose(got, desired, atol=atol, err_msg='w @ w.conj().T = I') + assert_allclose(got, desired, atol=atol, err_msg="w @ w.conj().T = I") # Check that the result here is a diagonal matrix for Sekihara - if n_orient > 1 and weight_norm != 'unit-noise-gain-invariant': + if n_orient > 1 and weight_norm != "unit-noise-gain-invariant": got = w @ use_G.swapaxes(-2, -1) diags = np.diagonal(got, 0, -2, -1) want = np.apply_along_axis(np.diagflat, 1, diags) atol = np.mean(diags).real * 1e-12 - assert_allclose(got, want, atol=atol, err_msg='G.T @ w = θI') + assert_allclose(got, want, atol=atol, err_msg="G.T @ w = θI") def test_api(): """Test LCMV/DICS API equivalence.""" lcmv_names = list(signature(make_lcmv).parameters) dics_names = list(signature(make_dics).parameters) - dics_names[dics_names.index('csd')] = 'data_cov' - dics_names[dics_names.index('noise_csd')] = 'noise_cov' - dics_names.pop(dics_names.index('real_filter')) # not a thing for LCMV + dics_names[dics_names.index("csd")] = "data_cov" + dics_names[dics_names.index("noise_csd")] = "noise_cov" + dics_names.pop(dics_names.index("real_filter")) # not a thing for LCMV assert lcmv_names == dics_names diff --git a/mne/beamformer/tests/test_rap_music.py b/mne/beamformer/tests/test_rap_music.py index 6595b792dcb..68abae4d435 100644 --- a/mne/beamformer/tests/test_rap_music.py +++ b/mne/beamformer/tests/test_rap_music.py @@ -19,18 +19,16 @@ data_path = testing.data_path(download=False) fname_ave = data_path / "MEG" / "sample" / "sample_audvis-ave.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis_trunc-cov.fif" -fname_fwd = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) +fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" def _get_data(ch_decim=1): """Read in data used in tests.""" # Read evoked evoked = mne.read_evokeds(fname_ave, 0, baseline=(None, 0)) - evoked.info['bads'] = ['MEG 2443'] + evoked.info["bads"] = ["MEG 2443"] with evoked.info._unlock(): - evoked.info['lowpass'] = 16 # fake for decim + evoked.info["lowpass"] = 16 # fake for decim evoked.decimate(12) evoked.crop(0.0, 0.3) picks = mne.pick_types(evoked.info, meg=True, eeg=False) @@ -39,8 +37,8 @@ def _get_data(ch_decim=1): evoked.info.normalize_proj() noise_cov = mne.read_cov(fname_cov) - noise_cov['projs'] = [] - noise_cov = regularize(noise_cov, evoked.info, rank='full', proj=False) + noise_cov["projs"] = [] + noise_cov = regularize(noise_cov, evoked.info, rank="full", proj=False) return evoked, noise_cov @@ -51,66 +49,69 @@ def simu_data(evoked, forward, noise_cov, n_dipoles, times, nave=1): """ # Generate the two dipoles data mu, sigma = 0.1, 0.005 - s1 = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-(times - mu) ** 2 / - (2 * sigma ** 2)) + s1 = ( + 1 + / (sigma * np.sqrt(2 * np.pi)) + * np.exp(-((times - mu) ** 2) / (2 * sigma**2)) + ) mu, sigma = 0.075, 0.008 - s2 = -1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-(times - mu) ** 2 / - (2 * sigma ** 2)) + s2 = ( + -1 + / (sigma * np.sqrt(2 * np.pi)) + * np.exp(-((times - mu) ** 2) / (2 * sigma**2)) + ) data = np.array([s1, s2]) * 1e-9 - src = forward['src'] + 
src = forward["src"] rng = np.random.RandomState(42) - rndi = rng.randint(len(src[0]['vertno'])) - lh_vertno = src[0]['vertno'][[rndi]] + rndi = rng.randint(len(src[0]["vertno"])) + lh_vertno = src[0]["vertno"][[rndi]] - rndi = rng.randint(len(src[1]['vertno'])) - rh_vertno = src[1]['vertno'][[rndi]] + rndi = rng.randint(len(src[1]["vertno"])) + rh_vertno = src[1]["vertno"][[rndi]] vertices = [lh_vertno, rh_vertno] - tmin, tstep = times.min(), 1 / evoked.info['sfreq'] + tmin, tstep = times.min(), 1 / evoked.info["sfreq"] stc = mne.SourceEstimate(data, vertices=vertices, tmin=tmin, tstep=tstep) - sim_evoked = mne.simulation.simulate_evoked(forward, stc, evoked.info, - noise_cov, nave=nave, - random_state=rng) + sim_evoked = mne.simulation.simulate_evoked( + forward, stc, evoked.info, noise_cov, nave=nave, random_state=rng + ) return sim_evoked, stc def _check_dipoles(dipoles, fwd, stc, evoked, residual=None): - src = fwd['src'] - pos1 = fwd['source_rr'][np.where(src[0]['vertno'] == - stc.vertices[0])] - pos2 = fwd['source_rr'][np.where(src[1]['vertno'] == - stc.vertices[1])[0] + - len(src[0]['vertno'])] + src = fwd["src"] + pos1 = fwd["source_rr"][np.where(src[0]["vertno"] == stc.vertices[0])] + pos2 = fwd["source_rr"][ + np.where(src[1]["vertno"] == stc.vertices[1])[0] + len(src[0]["vertno"]) + ] # Check the position of the two dipoles - assert (dipoles[0].pos[0] in np.array([pos1, pos2])) - assert (dipoles[1].pos[0] in np.array([pos1, pos2])) + assert dipoles[0].pos[0] in np.array([pos1, pos2]) + assert dipoles[1].pos[0] in np.array([pos1, pos2]) - ori1 = fwd['source_nn'][np.where(src[0]['vertno'] == - stc.vertices[0])[0]][0] - ori2 = fwd['source_nn'][np.where(src[1]['vertno'] == - stc.vertices[1])[0] + - len(src[0]['vertno'])][0] + ori1 = fwd["source_nn"][np.where(src[0]["vertno"] == stc.vertices[0])[0]][0] + ori2 = fwd["source_nn"][ + np.where(src[1]["vertno"] == stc.vertices[1])[0] + len(src[0]["vertno"]) + ][0] # Check the orientation of the dipoles - assert (np.max(np.abs(np.dot(dipoles[0].ori[0], - np.array([ori1, ori2]).T))) > 0.99) + assert np.max(np.abs(np.dot(dipoles[0].ori[0], np.array([ori1, ori2]).T))) > 0.99 - assert (np.max(np.abs(np.dot(dipoles[1].ori[0], - np.array([ori1, ori2]).T))) > 0.99) + assert np.max(np.abs(np.dot(dipoles[1].ori[0], np.array([ori1, ori2]).T))) > 0.99 if residual is not None: - picks_grad = mne.pick_types(residual.info, meg='grad') - picks_mag = mne.pick_types(residual.info, meg='mag') + picks_grad = mne.pick_types(residual.info, meg="grad") + picks_mag = mne.pick_types(residual.info, meg="mag") rel_tol = 0.02 for picks in [picks_grad, picks_mag]: - assert (linalg.norm(residual.data[picks], ord='fro') < - rel_tol * linalg.norm(evoked.data[picks], ord='fro')) + assert linalg.norm(residual.data[picks], ord="fro") < rel_tol * linalg.norm( + evoked.data[picks], ord="fro" + ) @testing.requires_testing_data @@ -120,37 +121,48 @@ def test_rap_music_simulated(): forward = mne.read_forward_solution(fname_fwd) forward = mne.pick_channels_forward(forward, evoked.ch_names) forward_surf_ori = mne.convert_forward_solution(forward, surf_ori=True) - forward_fixed = mne.convert_forward_solution(forward, force_fixed=True, - surf_ori=True, use_cps=True) + forward_fixed = mne.convert_forward_solution( + forward, force_fixed=True, surf_ori=True, use_cps=True + ) n_dipoles = 2 - sim_evoked, stc = simu_data(evoked, forward_fixed, noise_cov, - n_dipoles, evoked.times, nave=evoked.nave) + sim_evoked, stc = simu_data( + evoked, forward_fixed, noise_cov, n_dipoles, 
evoked.times, nave=evoked.nave + ) # Check dipoles for fixed ori with catch_logging() as log: - dipoles = rap_music(sim_evoked, forward_fixed, noise_cov, - n_dipoles=n_dipoles, verbose=True) + dipoles = rap_music( + sim_evoked, forward_fixed, noise_cov, n_dipoles=n_dipoles, verbose=True + ) assert_var_exp_log(log.getvalue(), 89, 91) _check_dipoles(dipoles, forward_fixed, stc, sim_evoked) assert 97 < dipoles[0].gof.max() < 100 assert 91 < dipoles[1].gof.max() < 93 - assert dipoles[0].gof.min() >= 0. + assert dipoles[0].gof.min() >= 0.0 nave = 100000 # add a tiny amount of noise to the simulated evokeds - sim_evoked, stc = simu_data(evoked, forward_fixed, noise_cov, - n_dipoles, evoked.times, nave=nave) - dipoles, residual = rap_music(sim_evoked, forward_fixed, noise_cov, - n_dipoles=n_dipoles, return_residual=True) + sim_evoked, stc = simu_data( + evoked, forward_fixed, noise_cov, n_dipoles, evoked.times, nave=nave + ) + dipoles, residual = rap_music( + sim_evoked, forward_fixed, noise_cov, n_dipoles=n_dipoles, return_residual=True + ) _check_dipoles(dipoles, forward_fixed, stc, sim_evoked, residual) # Check dipoles for free ori - dipoles, residual = rap_music(sim_evoked, forward, noise_cov, - n_dipoles=n_dipoles, return_residual=True) + dipoles, residual = rap_music( + sim_evoked, forward, noise_cov, n_dipoles=n_dipoles, return_residual=True + ) _check_dipoles(dipoles, forward_fixed, stc, sim_evoked, residual) # Check dipoles for free surface ori - dipoles, residual = rap_music(sim_evoked, forward_surf_ori, noise_cov, - n_dipoles=n_dipoles, return_residual=True) + dipoles, residual = rap_music( + sim_evoked, + forward_surf_ori, + noise_cov, + n_dipoles=n_dipoles, + return_residual=True, + ) _check_dipoles(dipoles, forward_fixed, stc, sim_evoked, residual) @@ -159,17 +171,19 @@ def test_rap_music_simulated(): def test_rap_music_sphere(): """Test RAP-MUSIC with real data, sphere model, MEG only.""" evoked, noise_cov = _get_data(ch_decim=8) - sphere = mne.make_sphere_model(r0=(0., 0., 0.04)) - src = mne.setup_volume_source_space(subject=None, pos=10., - sphere=(0.0, 0.0, 40, 65.0), - mindist=5.0, exclude=0.0, - sphere_units='mm') - forward = mne.make_forward_solution(evoked.info, trans=None, src=src, - bem=sphere) + sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.04)) + src = mne.setup_volume_source_space( + subject=None, + pos=10.0, + sphere=(0.0, 0.0, 40, 65.0), + mindist=5.0, + exclude=0.0, + sphere_units="mm", + ) + forward = mne.make_forward_solution(evoked.info, trans=None, src=src, bem=sphere) with catch_logging() as log: - dipoles = rap_music(evoked, forward, noise_cov, n_dipoles=2, - verbose=True) + dipoles = rap_music(evoked, forward, noise_cov, n_dipoles=2, verbose=True) assert_var_exp_log(log.getvalue(), 47, 49) # Test that there is one dipole on each hemisphere pos = np.array([dip.pos[0] for dip in dipoles]) @@ -177,11 +191,11 @@ def test_rap_music_sphere(): assert (pos[:, 0] < 0).sum() == 1 assert (pos[:, 0] > 0).sum() == 1 # Check the amplitude scale - assert (1e-10 < dipoles[0].amplitude[0] < 1e-7) + assert 1e-10 < dipoles[0].amplitude[0] < 1e-7 # Check the orientation dip_fit = mne.fit_dipole(evoked, noise_cov, sphere)[0] - assert (np.max(np.abs(np.dot(dip_fit.ori, dipoles[0].ori[0]))) > 0.99) - assert (np.max(np.abs(np.dot(dip_fit.ori, dipoles[1].ori[0]))) > 0.99) + assert np.max(np.abs(np.dot(dip_fit.ori, dipoles[0].ori[0]))) > 0.99 + assert np.max(np.abs(np.dot(dip_fit.ori, dipoles[1].ori[0]))) > 0.99 idx = dip_fit.gof.argmax() dist = np.linalg.norm(dipoles[0].pos[idx] - 
dip_fit.pos[idx]) assert 0.004 <= dist < 0.007 @@ -191,8 +205,7 @@ def test_rap_music_sphere(): @testing.requires_testing_data def test_rap_music_picks(): """Test RAP-MUSIC with picking.""" - evoked = mne.read_evokeds(fname_ave, condition='Right Auditory', - baseline=(None, 0)) + evoked = mne.read_evokeds(fname_ave, condition="Right Auditory", baseline=(None, 0)) evoked.crop(tmin=0.05, tmax=0.15) # select N100 evoked.pick_types(meg=True, eeg=False) forward = mne.read_forward_solution(fname_fwd) diff --git a/mne/beamformer/tests/test_resolution_matrix.py b/mne/beamformer/tests/test_resolution_matrix.py index 6d6730e3b9e..6e574bf89f5 100755 --- a/mne/beamformer/tests/test_resolution_matrix.py +++ b/mne/beamformer/tests/test_resolution_matrix.py @@ -19,16 +19,11 @@ data_path = testing.data_path(download=False) subjects_dir = data_path / "subjects" fname_inv = ( - data_path - / "MEG" - / "sample" - / "sample_audvis_trunc-meg-eeg-oct-6-meg-inv.fif" + data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-6-meg-inv.fif" ) fname_evoked = data_path / "MEG" / "sample" / "sample_audvis_trunc-ave.fif" fname_raw = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif" -fname_fwd = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) +fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis_trunc-cov.fif" @@ -39,11 +34,10 @@ def test_resolution_matrix_lcmv(): forward = mne.read_forward_solution(fname_fwd) # remove bad channels - forward = mne.pick_channels_forward(forward, exclude='bads') + forward = mne.pick_channels_forward(forward, exclude="bads") # forward operator with fixed source orientations - forward_fxd = mne.convert_forward_solution(forward, surf_ori=True, - force_fixed=True) + forward_fxd = mne.convert_forward_solution(forward, surf_ori=True, force_fixed=True) # evoked info info = mne.io.read_info(fname_evoked) @@ -59,12 +53,18 @@ def test_resolution_matrix_lcmv(): # compute beamformer filters # reg=0. 
to make sure noise_cov and data_cov are as similar as possible - filters = make_lcmv(info, forward_fxd, data_cov, reg=0., - noise_cov=noise_cov, - pick_ori=None, rank=None, - weight_norm=None, - reduce_rank=False, - verbose=False) + filters = make_lcmv( + info, + forward_fxd, + data_cov, + reg=0.0, + noise_cov=noise_cov, + pick_ori=None, + rank=None, + weight_norm=None, + reduce_rank=False, + verbose=False, + ) # Compute resolution matrix for beamformer resmat_lcmv = make_lcmv_resolution_matrix(filters, forward_fxd, info) @@ -73,9 +73,9 @@ def test_resolution_matrix_lcmv(): # transpose of leadfield # create filters with transposed whitened leadfield as weights - forward_fxd = mne.pick_channels_forward(forward_fxd, info['ch_names']) + forward_fxd = mne.pick_channels_forward(forward_fxd, info["ch_names"]) filters_lfd = deepcopy(filters) - filters_lfd['weights'][:] = forward_fxd['sol']['data'].T + filters_lfd["weights"][:] = forward_fxd["sol"]["data"].T # compute resolution matrix for filters with transposed leadfield resmat_fwd = make_lcmv_resolution_matrix(filters_lfd, forward_fxd, info) @@ -85,12 +85,11 @@ def test_resolution_matrix_lcmv(): # Some rows are off by about 0.1 - not yet clear why corr = [] - for (f, lf) in zip(resmat_fwd, resmat_lcmv): - + for f, lf in zip(resmat_fwd, resmat_lcmv): corr.append(np.corrcoef(f, lf)[0, 1]) # all row correlations should at least be above ~0.8 - assert_allclose(corr, 1., atol=0.2) + assert_allclose(corr, 1.0, atol=0.2) # Maximum row correlation should at least be close to 1 - assert_allclose(np.max(corr), 1., atol=0.01) + assert_allclose(np.max(corr), 1.0, atol=0.01) diff --git a/mne/bem.py b/mne/bem.py index 505c0fca79d..b9a2bb2b96e 100644 --- a/mne/bem.py +++ b/mne/bem.py @@ -23,22 +23,54 @@ from .fixes import _compare_version from .io.constants import FIFF, FWD from .io._digitization import _dig_kind_dict, _dig_kind_rev, _dig_kind_ints -from .io.write import (start_and_end_file, start_block, write_float, write_int, - write_float_matrix, write_int_matrix, end_block, - write_string) +from .io.write import ( + start_and_end_file, + start_block, + write_float, + write_int, + write_float_matrix, + write_int_matrix, + end_block, + write_string, +) from .io.tag import find_tag from .io.tree import dir_tree_find from .io.open import fiff_open -from .surface import (read_surface, write_surface, complete_surface_info, - _compute_nearest, _get_ico_surface, read_tri, - _fast_cross_nd_sum, _get_solids, _complete_sphere_surf, - decimate_surface, transform_surface_to) +from .surface import ( + read_surface, + write_surface, + complete_surface_info, + _compute_nearest, + _get_ico_surface, + read_tri, + _fast_cross_nd_sum, + _get_solids, + _complete_sphere_surf, + decimate_surface, + transform_surface_to, +) from .transforms import _ensure_trans, apply_trans, Transform -from .utils import (verbose, logger, run_subprocess, get_subjects_dir, warn, - _pl, _validate_type, _TempDir, _check_freesurfer_home, - _check_fname, _check_option, path_like, _import_nibabel, - _on_missing, _import_h5io_funcs, _ensure_int, - _path_like, _verbose_safe_false, _check_head_radius) +from .utils import ( + verbose, + logger, + run_subprocess, + get_subjects_dir, + warn, + _pl, + _validate_type, + _TempDir, + _check_freesurfer_home, + _check_fname, + _check_option, + path_like, + _import_nibabel, + _on_missing, + _import_h5io_funcs, + _ensure_int, + _path_like, + _verbose_safe_false, + _check_head_radius, +) # ############################################################################ 
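A note on the row-correlation check in test_resolution_matrix.py above: it boils down to a Pearson correlation between corresponding rows of two matrices, computed row by row with np.corrcoef. A minimal self-contained sketch of that comparison under made-up data (the shapes, names, and noise level here are illustrative, not the MNE API):

    import numpy as np
    from numpy.testing import assert_allclose

    rng = np.random.default_rng(0)
    resmat_fwd = rng.standard_normal((68, 68))  # stand-in for the leadfield-based matrix
    resmat_lcmv = resmat_fwd + 0.1 * rng.standard_normal((68, 68))  # noisy counterpart

    # correlate corresponding rows, as the loop in the test does
    corr = [np.corrcoef(f, lf)[0, 1] for f, lf in zip(resmat_fwd, resmat_lcmv)]

    assert_allclose(corr, 1.0, atol=0.2)  # every row correlation above ~0.8
    assert_allclose(np.max(corr), 1.0, atol=0.01)  # best row essentially perfect
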
@@ -56,20 +88,22 @@ class ConductorModel(dict):
     """BEM or sphere model."""
 
     def __repr__(self):  # noqa: D105
-        if self['is_sphere']:
-            center = ', '.join('%0.1f' % (x * 1000.) for x in self['r0'])
+        if self["is_sphere"]:
+            center = ", ".join("%0.1f" % (x * 1000.0) for x in self["r0"])
             rad = self.radius
             if rad is None:  # no radius / MEG only
-                extra = 'Sphere (no layers): r0=[%s] mm' % center
+                extra = "Sphere (no layers): r0=[%s] mm" % center
             else:
-                extra = ('Sphere (%s layer%s): r0=[%s] R=%1.f mm'
-                         % (len(self['layers']) - 1, _pl(self['layers']),
-                            center, rad * 1000.))
+                extra = "Sphere (%s layer%s): r0=[%s] R=%1.f mm" % (
+                    len(self["layers"]) - 1,
+                    _pl(self["layers"]),
+                    center,
+                    rad * 1000.0,
+                )
         else:
-            extra = ('BEM (%s layer%s)' % (len(self['surfs']),
-                                           _pl(self['surfs'])))
-        extra += " solver=%s" % self['solver']
-        return '<ConductorModel | %s>' % extra
+            extra = "BEM (%s layer%s)" % (len(self["surfs"]), _pl(self["surfs"]))
+        extra += " solver=%s" % self["solver"]
+        return "<ConductorModel | %s>" % extra
 
     def copy(self):
         """Return copy of ConductorModel instance."""
@@ -78,9 +112,9 @@ def copy(self):
     @property
     def radius(self):
         """Sphere radius if an EEG sphere model."""
-        if not self['is_sphere']:
-            raise RuntimeError('radius undefined for BEM')
-        return None if len(self['layers']) == 0 else self['layers'][-1]['rad']
+        if not self["is_sphere"]:
+            raise RuntimeError("radius undefined for BEM")
+        return None if len(self["layers"]) == 0 else self["layers"][-1]["rad"]
 
 
 def _calc_beta(rk, rk_norm, rk1, rk1_norm):
@@ -108,9 +142,9 @@ def _lin_pot_coeff(fros, tri_rr, tri_nn, tri_area):
     l2 = np.linalg.norm(v2, axis=1)
     l3 = np.linalg.norm(v3, axis=1)
     ss = l1 * l2 * l3
-    ss += np.einsum('ij,ij,i->i', v1, v2, l3)
-    ss += np.einsum('ij,ij,i->i', v1, v3, l2)
-    ss += np.einsum('ij,ij,i->i', v2, v3, l1)
+    ss += np.einsum("ij,ij,i->i", v1, v2, l3)
+    ss += np.einsum("ij,ij,i->i", v1, v3, l2)
+    ss += np.einsum("ij,ij,i->i", v2, v3, l1)
     solids = np.arctan2(triples, ss)
 
     # We *could* subselect the good points from v1, v2, v3, triples, solids,
@@ -119,14 +153,16 @@ def _lin_pot_coeff(fros, tri_rr, tri_nn, tri_area):
     # solution. These three lines ensure we don't get invalid values in
     # _calc_beta.
     bad_mask = np.abs(solids) < np.pi / 1e6
-    l1[bad_mask] = 1.
-    l2[bad_mask] = 1.
-    l3[bad_mask] = 1.
+    l1[bad_mask] = 1.0
+    l2[bad_mask] = 1.0
+    l3[bad_mask] = 1.0
     # Calculate the magic vector vec_omega
-    beta = [_calc_beta(v1, l1, v2, l2)[:, np.newaxis],
-            _calc_beta(v2, l2, v3, l3)[:, np.newaxis],
-            _calc_beta(v3, l3, v1, l1)[:, np.newaxis]]
+    beta = [
+        _calc_beta(v1, l1, v2, l2)[:, np.newaxis],
+        _calc_beta(v2, l2, v3, l3)[:, np.newaxis],
+        _calc_beta(v3, l3, v1, l1)[:, np.newaxis],
+    ]
     vec_omega = (beta[2] - beta[0]) * v1
     vec_omega += (beta[0] - beta[1]) * v2
     vec_omega += (beta[1] - beta[2]) * v3
@@ -140,26 +176,27 @@ def _lin_pot_coeff(fros, tri_rr, tri_nn, tri_area):
     for k in range(3):
         diff = yys[idx[k - 1]] - yys[idx[k + 1]]
         zdots = _fast_cross_nd_sum(yys[idx[k + 1]], yys[idx[k - 1]], tri_nn)
-        omega[:, k] = -n2 * (area2 * zdots * 2. * solids -
-                             triples * (diff * vec_omega).sum(axis=-1))
+        omega[:, k] = -n2 * (
+            area2 * zdots * 2.0 * solids - triples * (diff * vec_omega).sum(axis=-1)
+        )
     # omit the bad points from the solution
-    omega[bad_mask] = 0.
+    omega[bad_mask] = 0.0
     return omega
 
 
 def _correct_auto_elements(surf, mat):
     """Improve auto-element approximation."""
     pi2 = 2.0 * np.pi
-    tris_flat = surf['tris'].ravel()
+    tris_flat = surf["tris"].ravel()
     misses = pi2 - mat.sum(axis=1)
     for j, miss in enumerate(misses):
         # How much is missing?
- n_memb = len(surf['neighbor_tri'][j]) + n_memb = len(surf["neighbor_tri"][j]) assert n_memb > 0 # should be guaranteed by our surface checks # The node itself receives one half mat[j, j] = miss / 2.0 # The rest is divided evenly among the member nodes... - miss /= (4.0 * n_memb) + miss /= 4.0 * n_memb members = np.where(j == tris_flat)[0] mods = members % 3 offsets = np.array([[1, 2], [-1, 1], [-1, -2]]) @@ -174,27 +211,34 @@ def _correct_auto_elements(surf, mat): def _fwd_bem_lin_pot_coeff(surfs): """Calculate the coefficients for linear collocation approach.""" # taken from fwd_bem_linear_collocation.c - nps = [surf['np'] for surf in surfs] + nps = [surf["np"] for surf in surfs] np_tot = sum(nps) coeff = np.zeros((np_tot, np_tot)) offsets = np.cumsum(np.concatenate(([0], nps))) for si_1, surf1 in enumerate(surfs): rr_ord = np.arange(nps[si_1]) for si_2, surf2 in enumerate(surfs): - logger.info(" %s (%d) -> %s (%d) ..." % - (_bem_surf_name[surf1['id']], nps[si_1], - _bem_surf_name[surf2['id']], nps[si_2])) - tri_rr = surf2['rr'][surf2['tris']] - tri_nn = surf2['tri_nn'] - tri_area = surf2['tri_area'] - submat = coeff[offsets[si_1]:offsets[si_1 + 1], - offsets[si_2]:offsets[si_2 + 1]] # view - for k in range(surf2['ntri']): - tri = surf2['tris'][k] + logger.info( + " %s (%d) -> %s (%d) ..." + % ( + _bem_surf_name[surf1["id"]], + nps[si_1], + _bem_surf_name[surf2["id"]], + nps[si_2], + ) + ) + tri_rr = surf2["rr"][surf2["tris"]] + tri_nn = surf2["tri_nn"] + tri_area = surf2["tri_area"] + submat = coeff[ + offsets[si_1] : offsets[si_1 + 1], offsets[si_2] : offsets[si_2 + 1] + ] # view + for k in range(surf2["ntri"]): + tri = surf2["tris"][k] if si_1 == si_2: - skip_idx = ((rr_ord == tri[0]) | - (rr_ord == tri[1]) | - (rr_ord == tri[2])) + skip_idx = ( + (rr_ord == tri[0]) | (rr_ord == tri[1]) | (rr_ord == tri[2]) + ) else: skip_idx = list() # No contribution from a triangle that @@ -202,9 +246,13 @@ def _fwd_bem_lin_pot_coeff(surfs): # if sidx1 == sidx2 and (tri == j).any(): # continue # Otherwise do the hard job - coeffs = _lin_pot_coeff(fros=surf1['rr'], tri_rr=tri_rr[k], - tri_nn=tri_nn[k], tri_area=tri_area[k]) - coeffs[skip_idx] = 0. 
+ coeffs = _lin_pot_coeff( + fros=surf1["rr"], + tri_rr=tri_rr[k], + tri_nn=tri_nn[k], + tri_area=tri_area[k], + ) + coeffs[skip_idx] = 0.0 submat[:, tri] -= coeffs if si_1 == si_2: _correct_auto_elements(surf1, submat) @@ -246,11 +294,11 @@ def _fwd_bem_ip_modify_solution(solution, ip_solution, ip_mult, n_tri): n_last = n_tri[-1] mult = (1.0 + ip_mult) / ip_mult - logger.info(' Combining...') + logger.info(" Combining...") offsets = np.cumsum(np.concatenate(([0], n_tri))) for si in range(len(n_tri)): # Pick the correct submatrix (right column) and multiply - sub = solution[offsets[si]:offsets[si + 1], np.sum(n_tri[:-1]):] + sub = solution[offsets[si] : offsets[si + 1], np.sum(n_tri[:-1]) :] # Multiply sub -= 2 * np.dot(sub, ip_solution) @@ -258,63 +306,64 @@ def _fwd_bem_ip_modify_solution(solution, ip_solution, ip_mult, n_tri): sub[-n_last:, -n_last:] += mult * ip_solution # Final scaling - logger.info(' Scaling...') + logger.info(" Scaling...") solution *= ip_mult return -def _check_complete_surface(surf, copy=False, incomplete='raise', extra=''): - surf = complete_surface_info( - surf, copy=copy, verbose=_verbose_safe_false()) - fewer = np.where([len(t) < 3 for t in surf['neighbor_tri']])[0] +def _check_complete_surface(surf, copy=False, incomplete="raise", extra=""): + surf = complete_surface_info(surf, copy=copy, verbose=_verbose_safe_false()) + fewer = np.where([len(t) < 3 for t in surf["neighbor_tri"]])[0] if len(fewer) > 0: fewer = list(fewer) - fewer = (fewer[:80] + ['...']) if len(fewer) > 80 else fewer - fewer = ', '.join(str(f) for f in fewer) - msg = ('Surface {} has topological defects: {:.0f} / {:.0f} vertices ' - 'have fewer than three neighboring triangles [{}]{}' - .format(_bem_surf_name[surf['id']], len(fewer), len(surf['rr']), - fewer, extra)) - _on_missing(on_missing=incomplete, msg=msg, name='on_defects') + fewer = (fewer[:80] + ["..."]) if len(fewer) > 80 else fewer + fewer = ", ".join(str(f) for f in fewer) + msg = ( + "Surface {} has topological defects: {:.0f} / {:.0f} vertices " + "have fewer than three neighboring triangles [{}]{}".format( + _bem_surf_name[surf["id"]], len(fewer), len(surf["rr"]), fewer, extra + ) + ) + _on_missing(on_missing=incomplete, msg=msg, name="on_defects") return surf def _fwd_bem_linear_collocation_solution(bem): """Compute the linear collocation potential solution.""" # first, add surface geometries - logger.info('Computing the linear collocation solution...') - logger.info(' Matrix coefficients...') - coeff = _fwd_bem_lin_pot_coeff(bem['surfs']) - bem['nsol'] = len(coeff) + logger.info("Computing the linear collocation solution...") + logger.info(" Matrix coefficients...") + coeff = _fwd_bem_lin_pot_coeff(bem["surfs"]) + bem["nsol"] = len(coeff) logger.info(" Inverting the coefficient matrix...") - nps = [surf['np'] for surf in bem['surfs']] - bem['solution'] = _fwd_bem_multi_solution(coeff, bem['gamma'], nps) - if len(bem['surfs']) == 3: - ip_mult = bem['sigma'][1] / bem['sigma'][2] + nps = [surf["np"] for surf in bem["surfs"]] + bem["solution"] = _fwd_bem_multi_solution(coeff, bem["gamma"], nps) + if len(bem["surfs"]) == 3: + ip_mult = bem["sigma"][1] / bem["sigma"][2] if ip_mult <= FWD.BEM_IP_APPROACH_LIMIT: - logger.info('IP approach required...') - logger.info(' Matrix coefficients (homog)...') - coeff = _fwd_bem_lin_pot_coeff([bem['surfs'][-1]]) - logger.info(' Inverting the coefficient matrix (homog)...') - ip_solution = _fwd_bem_homog_solution(coeff, - [bem['surfs'][-1]['np']]) - logger.info(' Modify the original solution to 
incorporate ' - 'IP approach...') - _fwd_bem_ip_modify_solution(bem['solution'], ip_solution, ip_mult, - nps) - bem['bem_method'] = FIFF.FIFFV_BEM_APPROX_LINEAR - bem['solver'] = 'mne' - - -def _import_openmeeg(what='compute a BEM solution using OpenMEEG'): + logger.info("IP approach required...") + logger.info(" Matrix coefficients (homog)...") + coeff = _fwd_bem_lin_pot_coeff([bem["surfs"][-1]]) + logger.info(" Inverting the coefficient matrix (homog)...") + ip_solution = _fwd_bem_homog_solution(coeff, [bem["surfs"][-1]["np"]]) + logger.info( + " Modify the original solution to incorporate " "IP approach..." + ) + _fwd_bem_ip_modify_solution(bem["solution"], ip_solution, ip_mult, nps) + bem["bem_method"] = FIFF.FIFFV_BEM_APPROX_LINEAR + bem["solver"] = "mne" + + +def _import_openmeeg(what="compute a BEM solution using OpenMEEG"): try: import openmeeg as om except Exception as exc: raise ImportError( - f'The OpenMEEG module must be installed to {what}, but ' - f'"import openmeeg" resulted in: {exc}') from None - if not _compare_version(om.__version__, '>=', '2.5.6'): - raise ImportError(f'OpenMEEG 2.5.6+ is required, got {om.__version__}') + f"The OpenMEEG module must be installed to {what}, but " + f'"import openmeeg" resulted in: {exc}' + ) from None + if not _compare_version(om.__version__, ">=", "2.5.6"): + raise ImportError(f"OpenMEEG 2.5.6+ is required, got {om.__version__}") return om @@ -322,37 +371,37 @@ def _make_openmeeg_geometry(bem, mri_head_t=None): # OpenMEEG om = _import_openmeeg() meshes = [] - for surf in bem['surfs'][::-1]: + for surf in bem["surfs"][::-1]: if mri_head_t is not None: surf = transform_surface_to(surf, "head", mri_head_t, copy=True) - points, faces = surf['rr'], surf['tris'] + points, faces = surf["rr"], surf["tris"] faces = faces[:, [1, 0, 2]] # swap faces meshes.append((points, faces)) - conductivity = bem['sigma'][::-1] + conductivity = bem["sigma"][::-1] return om.make_nested_geometry(meshes, conductivity) def _fwd_bem_openmeeg_solution(bem): om = _import_openmeeg() - logger.info('Creating BEM solution using OpenMEEG') - logger.info('Computing the openmeeg head matrix solution...') - logger.info(' Matrix coefficients...') + logger.info("Creating BEM solution using OpenMEEG") + logger.info("Computing the openmeeg head matrix solution...") + logger.info(" Matrix coefficients...") geom = _make_openmeeg_geometry(bem) hm = om.HeadMat(geom) - bem['nsol'] = hm.nlin() + bem["nsol"] = hm.nlin() logger.info(" Inverting the coefficient matrix...") hm.invert() # invert inplace - bem['solution'] = hm.array_flat() - bem['bem_method'] = FIFF.FIFFV_BEM_APPROX_LINEAR - bem['solver'] = 'openmeeg' + bem["solution"] = hm.array_flat() + bem["bem_method"] = FIFF.FIFFV_BEM_APPROX_LINEAR + bem["solver"] = "openmeeg" @verbose -def make_bem_solution(surfs, *, solver='mne', verbose=None): +def make_bem_solution(surfs, *, solver="mne", verbose=None): """Create a BEM solution using the linear collocation approach. Parameters @@ -383,76 +432,83 @@ def make_bem_solution(surfs, *, solver='mne', verbose=None): ----- .. 
versionadded:: 0.10.0 """ - _validate_type(solver, str, 'solver') - _check_option('method', solver.lower(), ('mne', 'openmeeg')) + _validate_type(solver, str, "solver") + _check_option("method", solver.lower(), ("mne", "openmeeg")) bem = _ensure_bem_surfaces(surfs) _add_gamma_multipliers(bem) - if len(bem['surfs']) == 3: - logger.info('Three-layer model surfaces loaded.') - elif len(bem['surfs']) == 1: - logger.info('Homogeneous model surface loaded.') + if len(bem["surfs"]) == 3: + logger.info("Three-layer model surfaces loaded.") + elif len(bem["surfs"]) == 1: + logger.info("Homogeneous model surface loaded.") else: - raise RuntimeError('Only 1- or 3-layer BEM computations supported') - _check_bem_size(bem['surfs']) - for surf in bem['surfs']: + raise RuntimeError("Only 1- or 3-layer BEM computations supported") + _check_bem_size(bem["surfs"]) + for surf in bem["surfs"]: _check_complete_surface(surf) - if solver.lower() == 'openmeeg': + if solver.lower() == "openmeeg": _fwd_bem_openmeeg_solution(bem) else: - assert solver.lower() == 'mne' + assert solver.lower() == "mne" _fwd_bem_linear_collocation_solution(bem) logger.info("Solution ready.") - logger.info('BEM geometry computations complete.') + logger.info("BEM geometry computations complete.") return bem # ############################################################################ # Make BEM model + def _ico_downsample(surf, dest_grade): """Downsample the surface if isomorphic to a subdivided icosahedron.""" - n_tri = len(surf['tris']) - bad_msg = ("Cannot decimate to requested ico grade %d. The provided " - "BEM surface has %d triangles, which cannot be isomorphic with " - "a subdivided icosahedron. Consider manually decimating the " - "surface to a suitable density and then use ico=None in " - "make_bem_model." % (dest_grade, n_tri)) + n_tri = len(surf["tris"]) + bad_msg = ( + "Cannot decimate to requested ico grade %d. The provided " + "BEM surface has %d triangles, which cannot be isomorphic with " + "a subdivided icosahedron. Consider manually decimating the " + "surface to a suitable density and then use ico=None in " + "make_bem_model." % (dest_grade, n_tri) + ) if n_tri % 20 != 0: raise RuntimeError(bad_msg) n_tri = n_tri // 20 found = int(round(np.log(n_tri) / np.log(4))) - if n_tri != 4 ** found: + if n_tri != 4**found: raise RuntimeError(bad_msg) del n_tri if dest_grade > found: - raise RuntimeError('For this surface, decimation grade should be %d ' - 'or less, not %s.' % (found, dest_grade)) + raise RuntimeError( + "For this surface, decimation grade should be %d " + "or less, not %s." 
% (found, dest_grade) + ) source = _get_ico_surface(found) dest = _get_ico_surface(dest_grade, patch_stats=True) - del dest['tri_cent'] - del dest['tri_nn'] - del dest['neighbor_tri'] - del dest['tri_area'] - if not np.array_equal(source['tris'], surf['tris']): - raise RuntimeError('The source surface has a matching number of ' - 'triangles but ordering is wrong') - logger.info('Going from %dth to %dth subdivision of an icosahedron ' - '(n_tri: %d -> %d)' % (found, dest_grade, len(surf['tris']), - len(dest['tris']))) + del dest["tri_cent"] + del dest["tri_nn"] + del dest["neighbor_tri"] + del dest["tri_area"] + if not np.array_equal(source["tris"], surf["tris"]): + raise RuntimeError( + "The source surface has a matching number of " + "triangles but ordering is wrong" + ) + logger.info( + "Going from %dth to %dth subdivision of an icosahedron " + "(n_tri: %d -> %d)" % (found, dest_grade, len(surf["tris"]), len(dest["tris"])) + ) # Find the mapping - dest['rr'] = surf['rr'][_get_ico_map(source, dest)] + dest["rr"] = surf["rr"][_get_ico_map(source, dest)] return dest def _get_ico_map(fro, to): """Get a mapping between ico surfaces.""" - nearest, dists = _compute_nearest(fro['rr'], to['rr'], return_dists=True) + nearest, dists = _compute_nearest(fro["rr"], to["rr"], return_dists=True) n_bads = (dists > 5e-3).sum() if n_bads > 0: - raise RuntimeError('No matching vertex for %d destination vertices' - % (n_bads)) + raise RuntimeError("No matching vertex for %d destination vertices" % (n_bads)) return nearest @@ -461,32 +517,36 @@ def _order_surfaces(surfs): if len(surfs) != 3: return surfs # we have three surfaces - surf_order = [FIFF.FIFFV_BEM_SURF_ID_HEAD, - FIFF.FIFFV_BEM_SURF_ID_SKULL, - FIFF.FIFFV_BEM_SURF_ID_BRAIN] - ids = np.array([surf['id'] for surf in surfs]) + surf_order = [ + FIFF.FIFFV_BEM_SURF_ID_HEAD, + FIFF.FIFFV_BEM_SURF_ID_SKULL, + FIFF.FIFFV_BEM_SURF_ID_BRAIN, + ] + ids = np.array([surf["id"] for surf in surfs]) if set(ids) != set(surf_order): - raise RuntimeError('bad surface ids: %s' % ids) + raise RuntimeError("bad surface ids: %s" % ids) order = [np.where(ids == id_)[0][0] for id_ in surf_order] surfs = [surfs[idx] for idx in order] return surfs -def _assert_complete_surface(surf, incomplete='raise'): +def _assert_complete_surface(surf, incomplete="raise"): """Check the sum of solid angles as seen from inside.""" # from surface_checks.c # Center of mass.... 
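For context, the completeness test below rests on a classical identity: the signed solid angles of all triangles of a closed, outward-oriented surface sum to 4*pi when seen from an interior point (and to 0 from outside). A minimal per-triangle sketch using the Van Oosterom & Strackee (1983) formula, assuming NumPy; this is illustrative only, and MNE's vectorized `_get_solids` scales its result differently (the check below divides the total by 2*pi):

    import numpy as np

    def triangle_solid_angle(tri_rr, pt):
        """Signed solid angle of one triangle (rows of tri_rr) seen from pt."""
        r = tri_rr - pt                  # shape (3, 3), one row per vertex
        n = np.linalg.norm(r, axis=1)
        num = np.linalg.det(r)           # triple product r1 . (r2 x r3)
        den = (n.prod()
               + n[0] * r[1] @ r[2]
               + n[1] * r[0] @ r[2]
               + n[2] * r[0] @ r[1])
        return 2.0 * np.arctan2(num, den)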
- cm = surf['rr'].mean(axis=0) - logger.info('%s CM is %6.2f %6.2f %6.2f mm' % - (_bem_surf_name[surf['id']], - 1000 * cm[0], 1000 * cm[1], 1000 * cm[2])) - tot_angle = _get_solids(surf['rr'][surf['tris']], cm[np.newaxis, :])[0] + cm = surf["rr"].mean(axis=0) + logger.info( + "%s CM is %6.2f %6.2f %6.2f mm" + % (_bem_surf_name[surf["id"]], 1000 * cm[0], 1000 * cm[1], 1000 * cm[2]) + ) + tot_angle = _get_solids(surf["rr"][surf["tris"]], cm[np.newaxis, :])[0] prop = tot_angle / (2 * np.pi) if np.abs(prop - 1.0) > 1e-5: - msg = (f'Surface {_bem_surf_name[surf["id"]]} is not complete (sum of ' - f'solid angles yielded {prop}, should be 1.)') - _on_missing( - incomplete, msg, name='incomplete', error_klass=RuntimeError) + msg = ( + f'Surface {_bem_surf_name[surf["id"]]} is not complete (sum of ' + f"solid angles yielded {prop}, should be 1.)" + ) + _on_missing(incomplete, msg, name="incomplete", error_klass=RuntimeError) def _assert_inside(fro, to): @@ -494,15 +554,15 @@ def _assert_inside(fro, to): # this is "is_inside" in surface_checks.c fro_name = _bem_surf_name[fro["id"]] to_name = _bem_surf_name[to["id"]] - logger.info( - f'Checking that surface {fro_name} is inside surface {to_name} ...') - tot_angle = _get_solids(to['rr'][to['tris']], fro['rr']) + logger.info(f"Checking that surface {fro_name} is inside surface {to_name} ...") + tot_angle = _get_solids(to["rr"][to["tris"]], fro["rr"]) if (np.abs(tot_angle / (2 * np.pi) - 1.0) > 1e-5).any(): raise RuntimeError( - f'Surface {fro_name} is not completely inside surface {to_name}') + f"Surface {fro_name} is not completely inside surface {to_name}" + ) -def _check_surfaces(surfs, incomplete='raise'): +def _check_surfaces(surfs, incomplete="raise"): """Check that the surfaces are complete and non-intersecting.""" for surf in surfs: _assert_complete_surface(surf, incomplete=incomplete) @@ -513,36 +573,40 @@ def _check_surfaces(surfs, incomplete='raise'): def _check_surface_size(surf): """Check that the coordinate limits are reasonable.""" - sizes = surf['rr'].max(axis=0) - surf['rr'].min(axis=0) + sizes = surf["rr"].max(axis=0) - surf["rr"].min(axis=0) if (sizes < 0.05).any(): raise RuntimeError( f'Dimensions of the surface {_bem_surf_name[surf["id"]]} seem too ' - f'small ({1000 * sizes.min():9.5f}). Maybe the unit of measure' - ' is meters instead of mm') + f"small ({1000 * sizes.min():9.5f}). 
Maybe the unit of measure" + " is meters instead of mm" + ) def _check_thicknesses(surfs): """Compute how close we are.""" for surf_1, surf_2 in zip(surfs[:-1], surfs[1:]): - min_dist = _compute_nearest(surf_1['rr'], surf_2['rr'], - return_dists=True)[1] + min_dist = _compute_nearest(surf_1["rr"], surf_2["rr"], return_dists=True)[1] min_dist = min_dist.min() - fro = _bem_surf_name[surf_1['id']] - to = _bem_surf_name[surf_2['id']] - logger.info(f'Checking distance between {fro} and {to} surfaces...') - logger.info(f'Minimum distance between the {fro} and {to} surfaces is ' - f'approximately {1000 * min_dist:6.1f} mm') - - -def _surfaces_to_bem(surfs, ids, sigmas, ico=None, rescale=True, - incomplete='raise', extra=''): + fro = _bem_surf_name[surf_1["id"]] + to = _bem_surf_name[surf_2["id"]] + logger.info(f"Checking distance between {fro} and {to} surfaces...") + logger.info( + f"Minimum distance between the {fro} and {to} surfaces is " + f"approximately {1000 * min_dist:6.1f} mm" + ) + + +def _surfaces_to_bem( + surfs, ids, sigmas, ico=None, rescale=True, incomplete="raise", extra="" +): """Convert surfaces to a BEM.""" # equivalent of mne_surf2bem # surfs can be strings (filenames) or surface dicts - if len(surfs) not in (1, 3) or not (len(surfs) == len(ids) == - len(sigmas)): - raise ValueError('surfs, ids, and sigmas must all have the same ' - 'number of elements (1 or 3)') + if len(surfs) not in (1, 3) or not (len(surfs) == len(ids) == len(sigmas)): + raise ValueError( + "surfs, ids, and sigmas must all have the same " + "number of elements (1 or 3)" + ) for si, surf in enumerate(surfs): if isinstance(surf, (str, Path, os.PathLike)): surfs[si] = surf = read_surface(surf, return_dict=True)[-1] @@ -552,19 +616,18 @@ def _surfaces_to_bem(surfs, ids, sigmas, ico=None, rescale=True, surfs[si] = _ico_downsample(surf, ico) for surf, id_ in zip(surfs, ids): # Do topology checks (but don't save data) to fail early - surf['id'] = id_ - _check_complete_surface(surf, copy=True, incomplete=incomplete, - extra=extra) - surf['coord_frame'] = surf.get('coord_frame', FIFF.FIFFV_COORD_MRI) - surf.update(np=len(surf['rr']), ntri=len(surf['tris'])) + surf["id"] = id_ + _check_complete_surface(surf, copy=True, incomplete=incomplete, extra=extra) + surf["coord_frame"] = surf.get("coord_frame", FIFF.FIFFV_COORD_MRI) + surf.update(np=len(surf["rr"]), ntri=len(surf["tris"])) if rescale: - surf['rr'] /= 1000. # convert to meters + surf["rr"] /= 1000.0 # convert to meters # Shifting surfaces is not implemented here... # Order the surfaces for the benefit of the topology checks for surf, sigma in zip(surfs, sigmas): - surf['sigma'] = sigma + surf["sigma"] = sigma surfs = _order_surfaces(surfs) # Check topology as best we can @@ -572,13 +635,14 @@ def _surfaces_to_bem(surfs, ids, sigmas, ico=None, rescale=True, for surf in surfs: _check_surface_size(surf) _check_thicknesses(surfs) - logger.info('Surfaces passed the basic topology checks.') + logger.info("Surfaces passed the basic topology checks.") return surfs @verbose -def make_bem_model(subject, ico=4, conductivity=(0.3, 0.006, 0.3), - subjects_dir=None, verbose=None): +def make_bem_model( + subject, ico=4, conductivity=(0.3, 0.006, 0.3), subjects_dir=None, verbose=None +): """Create a BEM model for a subject. .. 
note:: To get a single layer bem corresponding to the --homog flag in @@ -619,8 +683,7 @@ def make_bem_model(subject, ico=4, conductivity=(0.3, 0.006, 0.3), """ conductivity = np.array(conductivity, float) if conductivity.ndim != 1 or conductivity.size not in (1, 3): - raise ValueError('conductivity must be 1D array-like with 1 or 3 ' - 'elements') + raise ValueError("conductivity must be 1D array-like with 1 or 3 " "elements") subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) subject_dir = subjects_dir / subject bem_dir = subject_dir / "bem" @@ -628,27 +691,30 @@ def make_bem_model(subject, ico=4, conductivity=(0.3, 0.006, 0.3), outer_skull = bem_dir / "outer_skull.surf" outer_skin = bem_dir / "outer_skin.surf" surfaces = [inner_skull, outer_skull, outer_skin] - ids = [FIFF.FIFFV_BEM_SURF_ID_BRAIN, - FIFF.FIFFV_BEM_SURF_ID_SKULL, - FIFF.FIFFV_BEM_SURF_ID_HEAD] - logger.info('Creating the BEM geometry...') + ids = [ + FIFF.FIFFV_BEM_SURF_ID_BRAIN, + FIFF.FIFFV_BEM_SURF_ID_SKULL, + FIFF.FIFFV_BEM_SURF_ID_HEAD, + ] + logger.info("Creating the BEM geometry...") if len(conductivity) == 1: surfaces = surfaces[:1] ids = ids[:1] surfaces = _surfaces_to_bem(surfaces, ids, conductivity, ico) _check_bem_size(surfaces) - logger.info('Complete.\n') + logger.info("Complete.\n") return surfaces # ############################################################################ # Compute EEG sphere model + def _fwd_eeg_get_multi_sphere_model_coeffs(m, n_terms): """Get the model depended weighting factor for n.""" - nlayer = len(m['layers']) + nlayer = len(m["layers"]) if nlayer in (0, 1): - return 1. + return 1.0 # Initialize the arrays c1 = np.zeros(nlayer - 1) @@ -656,9 +722,9 @@ def _fwd_eeg_get_multi_sphere_model_coeffs(m, n_terms): cr = np.zeros(nlayer - 1) cr_mult = np.zeros(nlayer - 1) for k in range(nlayer - 1): - c1[k] = m['layers'][k]['sigma'] / m['layers'][k + 1]['sigma'] + c1[k] = m["layers"][k]["sigma"] / m["layers"][k + 1]["sigma"] c2[k] = c1[k] - 1.0 - cr_mult[k] = m['layers'][k]['rel_rad'] + cr_mult[k] = m["layers"][k]["rel_rad"] cr[k] = cr_mult[k] cr_mult[k] *= cr_mult[k] @@ -672,8 +738,13 @@ def _fwd_eeg_get_multi_sphere_model_coeffs(m, n_terms): M = np.eye(2) n1 = n + 1.0 for k in range(nlayer - 2, -1, -1): - M = np.dot([[n + n1 * c1[k], n1 * c2[k] / cr[k]], - [n * c2[k] * cr[k], n1 + n * c1[k]]], M) + M = np.dot( + [ + [n + n1 * c1[k], n1 * c2[k] / cr[k]], + [n * c2[k] * cr[k], n1 + n * c1[k]], + ], + M, + ) num = n * (2.0 * n + 1.0) ** (nlayer - 1) coeffs[n - 1] = num / (n * M[1, 1] + n1 * M[1, 0]) return coeffs @@ -682,15 +753,15 @@ def _fwd_eeg_get_multi_sphere_model_coeffs(m, n_terms): def _compose_linear_fitting_data(mu, u): """Get the linear fitting data.""" from scipy import linalg - k1 = np.arange(1, u['nterms']) + + k1 = np.arange(1, u["nterms"]) mu1ns = mu[0] ** k1 # data to be fitted - y = u['w'][:-1] * (u['fn'][1:] - mu1ns * u['fn'][0]) + y = u["w"][:-1] * (u["fn"][1:] - mu1ns * u["fn"][0]) # model matrix - M = u['w'][:-1, np.newaxis] * (mu[1:] ** k1[:, np.newaxis] - - mu1ns[:, np.newaxis]) + M = u["w"][:-1, np.newaxis] * (mu[1:] ** k1[:, np.newaxis] - mu1ns[:, np.newaxis]) uu, sing, vv = linalg.svd(M, full_matrices=False) - ncomp = u['nfit'] - 1 + ncomp = u["nfit"] - 1 uu, sing, vv = uu[:, :ncomp], sing[:ncomp], vv[:ncomp] return y, uu, sing, vv @@ -704,9 +775,9 @@ def _compute_linear_parameters(mu, u): resi = y - np.dot(uu, vec) vec /= sing - lambda_ = np.zeros(u['nfit']) + lambda_ = np.zeros(u["nfit"]) lambda_[1:] = np.dot(vec, vv) - lambda_[0] = u['fn'][0] - 
np.sum(lambda_[1:]) + lambda_[0] = u["fn"][0] - np.sum(lambda_[1:]) rv = np.dot(resi, resi) / np.dot(y, y) return rv, lambda_ @@ -725,27 +796,28 @@ def _one_step(mu, u): def _fwd_eeg_fit_berg_scherg(m, nterms, nfit): """Fit the Berg-Scherg equivalent spherical model dipole parameters.""" from scipy.optimize import fmin_cobyla + assert nfit >= 2 u = dict(nfit=nfit, nterms=nterms) # (1) Calculate the coefficients of the true expansion - u['fn'] = _fwd_eeg_get_multi_sphere_model_coeffs(m, nterms + 1) + u["fn"] = _fwd_eeg_get_multi_sphere_model_coeffs(m, nterms + 1) # (2) Calculate the weighting - f = (min([layer['rad'] for layer in m['layers']]) / - max([layer['rad'] for layer in m['layers']])) + f = min([layer["rad"] for layer in m["layers"]]) / max( + [layer["rad"] for layer in m["layers"]] + ) # correct weighting k = np.arange(1, nterms + 1) - u['w'] = np.sqrt((2.0 * k + 1) * (3.0 * k + 1.0) / - k) * np.power(f, (k - 1.0)) - u['w'][-1] = 0 + u["w"] = np.sqrt((2.0 * k + 1) * (3.0 * k + 1.0) / k) * np.power(f, (k - 1.0)) + u["w"][-1] = 0 # Do the nonlinear minimization, constraining mu to the interval [-1, +1] mu_0 = np.zeros(3) fun = partial(_one_step, u=u) catol = 1e-6 - max_ = 1. - 2 * catol + max_ = 1.0 - 2 * catol def cons(x): return max_ - np.abs(x) @@ -757,17 +829,22 @@ def cons(x): order = np.argsort(mu)[::-1] mu, lambda_ = mu[order], lambda_[order] # sort: largest mu first - m['mu'] = mu + m["mu"] = mu # This division takes into account the actual conductivities - m['lambda'] = lambda_ / m['layers'][-1]['sigma'] - m['nfit'] = nfit + m["lambda"] = lambda_ / m["layers"][-1]["sigma"] + m["nfit"] = nfit return rv @verbose -def make_sphere_model(r0=(0., 0., 0.04), head_radius=0.09, info=None, - relative_radii=(0.90, 0.92, 0.97, 1.0), - sigmas=(0.33, 1.0, 0.004, 0.33), verbose=None): +def make_sphere_model( + r0=(0.0, 0.0, 0.04), + head_radius=0.09, + info=None, + relative_radii=(0.90, 0.92, 0.97, 1.0), + sigmas=(0.33, 1.0, 0.004, 0.33), + verbose=None, +): """Create a spherical model for forward solution calculation. Parameters @@ -809,33 +886,37 @@ def make_sphere_model(r0=(0., 0., 0.04), head_radius=0.09, info=None, .. 
versionadded:: 0.9.0 """ - for name in ('r0', 'head_radius'): + for name in ("r0", "head_radius"): param = locals()[name] if isinstance(param, str): - if param != 'auto': - raise ValueError('%s, if str, must be "auto" not "%s"' - % (name, param)) + if param != "auto": + raise ValueError('%s, if str, must be "auto" not "%s"' % (name, param)) relative_radii = np.array(relative_radii, float).ravel() sigmas = np.array(sigmas, float).ravel() if len(relative_radii) != len(sigmas): - raise ValueError('relative_radii length (%s) must match that of ' - 'sigmas (%s)' % (len(relative_radii), - len(sigmas))) + raise ValueError( + "relative_radii length (%s) must match that of " + "sigmas (%s)" % (len(relative_radii), len(sigmas)) + ) if len(sigmas) <= 1 and head_radius is not None: - raise ValueError('at least 2 sigmas must be supplied if ' - 'head_radius is not None, got %s' % (len(sigmas),)) - if (isinstance(r0, str) and r0 == 'auto') or \ - (isinstance(head_radius, str) and head_radius == 'auto'): + raise ValueError( + "at least 2 sigmas must be supplied if " + "head_radius is not None, got %s" % (len(sigmas),) + ) + if (isinstance(r0, str) and r0 == "auto") or ( + isinstance(head_radius, str) and head_radius == "auto" + ): if info is None: - raise ValueError('Info must not be None for auto mode') - head_radius_fit, r0_fit = fit_sphere_to_headshape(info, units='m')[:2] + raise ValueError("Info must not be None for auto mode") + head_radius_fit, r0_fit = fit_sphere_to_headshape(info, units="m")[:2] if isinstance(r0, str): r0 = r0_fit if isinstance(head_radius, str): head_radius = head_radius_fit - sphere = ConductorModel(is_sphere=True, r0=np.array(r0), - coord_frame=FIFF.FIFFV_COORD_HEAD) - sphere['layers'] = list() + sphere = ConductorModel( + is_sphere=True, r0=np.array(r0), coord_frame=FIFF.FIFFV_COORD_HEAD + ) + sphere["layers"] = list() if head_radius is not None: # Eventually these could be configurable... relative_radii = np.array(relative_radii, float) @@ -846,15 +927,15 @@ def make_sphere_model(r0=(0., 0., 0.04), head_radius=0.09, info=None, for rel_rad, sig in zip(relative_radii, sigmas): # sort layers by (relative) radius, and scale radii layer = dict(rad=rel_rad, sigma=sig) - layer['rel_rad'] = layer['rad'] = rel_rad - sphere['layers'].append(layer) + layer["rel_rad"] = layer["rad"] = rel_rad + sphere["layers"].append(layer) # scale the radii - R = sphere['layers'][-1]['rad'] - rR = sphere['layers'][-1]['rel_rad'] - for layer in sphere['layers']: - layer['rad'] /= R - layer['rel_rad'] /= rR + R = sphere["layers"][-1]["rad"] + rR = sphere["layers"][-1]["rel_rad"] + for layer in sphere["layers"]: + layer["rad"] /= R + layer["rel_rad"] /= rR # # Setup the EEG sphere model calculations @@ -862,25 +943,32 @@ def make_sphere_model(r0=(0., 0., 0.04), head_radius=0.09, info=None, # Scale the relative radii for k in range(len(relative_radii)): - sphere['layers'][k]['rad'] = (head_radius * - sphere['layers'][k]['rel_rad']) + sphere["layers"][k]["rad"] = head_radius * sphere["layers"][k]["rel_rad"] rv = _fwd_eeg_fit_berg_scherg(sphere, 200, 3) - logger.info('\nEquiv. model fitting -> RV = %g %%' % (100 * rv)) + logger.info("\nEquiv. 
model fitting -> RV = %g %%" % (100 * rv)) for k in range(3): - logger.info('mu%d = %g lambda%d = %g' - % (k + 1, sphere['mu'][k], k + 1, - sphere['layers'][-1]['sigma'] * - sphere['lambda'][k])) - logger.info('Set up EEG sphere model with scalp radius %7.1f mm\n' - % (1000 * head_radius,)) + logger.info( + "mu%d = %g lambda%d = %g" + % ( + k + 1, + sphere["mu"][k], + k + 1, + sphere["layers"][-1]["sigma"] * sphere["lambda"][k], + ) + ) + logger.info( + "Set up EEG sphere model with scalp radius %7.1f mm\n" + % (1000 * head_radius,) + ) return sphere # ############################################################################# # Sphere fitting + @verbose -def fit_sphere_to_headshape(info, dig_kinds='auto', units='m', verbose=None): +def fit_sphere_to_headshape(info, dig_kinds="auto", units="m", verbose=None): """Fit a sphere to the headshape points to determine head center. Parameters @@ -907,11 +995,10 @@ def fit_sphere_to_headshape(info, dig_kinds='auto', units='m', verbose=None): This function excludes any points that are low and frontal (``z < 0 and y > 0``) to improve the fit. """ - if not isinstance(units, str) or units not in ('m', 'mm'): + if not isinstance(units, str) or units not in ("m", "mm"): raise ValueError('units must be a "m" or "mm"') - radius, origin_head, origin_device = _fit_sphere_to_headshape( - info, dig_kinds) - if units == 'mm': + radius, origin_head, origin_device = _fit_sphere_to_headshape(info, dig_kinds) + if units == "mm": radius *= 1e3 origin_head *= 1e3 origin_device *= 1e3 @@ -919,8 +1006,7 @@ def fit_sphere_to_headshape(info, dig_kinds='auto', units='m', verbose=None): @verbose -def get_fitting_dig(info, dig_kinds='auto', exclude_frontal=True, - verbose=None): +def get_fitting_dig(info, dig_kinds="auto", exclude_frontal=True, verbose=None): """Get digitization points suitable for sphere fitting. Parameters @@ -946,17 +1032,18 @@ def get_fitting_dig(info, dig_kinds='auto', exclude_frontal=True, .. 
versionadded:: 0.14 """ _validate_type(info, "info") - if info['dig'] is None: - raise RuntimeError('Cannot fit headshape without digitization ' - ', info["dig"] is None') + if info["dig"] is None: + raise RuntimeError( + "Cannot fit headshape without digitization " ', info["dig"] is None' + ) if isinstance(dig_kinds, str): - if dig_kinds == 'auto': + if dig_kinds == "auto": # try "extra" first try: - return get_fitting_dig(info, 'extra') + return get_fitting_dig(info, "extra") except ValueError: pass - return get_fitting_dig(info, ('extra', 'eeg')) + return get_fitting_dig(info, ("extra", "eeg")) else: dig_kinds = (dig_kinds,) # convert string args to ints (first make dig_kinds mutable in case tuple) @@ -964,19 +1051,21 @@ def get_fitting_dig(info, dig_kinds='auto', exclude_frontal=True, for di, d in enumerate(dig_kinds): dig_kinds[di] = _dig_kind_dict.get(d, d) if dig_kinds[di] not in _dig_kind_ints: - raise ValueError('dig_kinds[#%d] (%s) must be one of %s' - % (di, d, sorted(list(_dig_kind_dict.keys())))) + raise ValueError( + "dig_kinds[#%d] (%s) must be one of %s" + % (di, d, sorted(list(_dig_kind_dict.keys()))) + ) # get head digization points of the specified kind(s) - dig = [p for p in info['dig'] if p['kind'] in dig_kinds] + dig = [p for p in info["dig"] if p["kind"] in dig_kinds] if len(dig) == 0: - raise ValueError( - f'No digitization points found for dig_kinds={dig_kinds}') - if any(p['coord_frame'] != FIFF.FIFFV_COORD_HEAD for p in dig): + raise ValueError(f"No digitization points found for dig_kinds={dig_kinds}") + if any(p["coord_frame"] != FIFF.FIFFV_COORD_HEAD for p in dig): raise RuntimeError( - f'Digitization points dig_kinds={dig_kinds} not in head ' - 'coordinates, contact mne-python developers') - hsp = [p['r'] for p in dig] + f"Digitization points dig_kinds={dig_kinds} not in head " + "coordinates, contact mne-python developers" + ) + hsp = [p["r"] for p in dig] del dig # exclude some frontal points (nose etc.) 
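As a usage aside, the sphere-fitting helpers above can be driven directly from a measurement info; a minimal sketch (the file name is illustrative, and any file whose info carries digitized head points works):

    import mne

    info = mne.io.read_info("sample_audvis_raw.fif")  # illustrative path
    # dig_kinds="auto" tries the "extra" (headshape) points first,
    # then falls back to ("extra", "eeg"); units="m" keeps SI units
    radius, origin_head, origin_device = mne.bem.fit_sphere_to_headshape(
        info, dig_kinds="auto", units="m"
    )
    print(f"fitted radius: {radius * 1e3:.1f} mm")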
@@ -985,14 +1074,16 @@ def get_fitting_dig(info, dig_kinds='auto', exclude_frontal=True, hsp = np.array(hsp) if len(hsp) <= 10: - kinds_str = ', '.join(['"%s"' % _dig_kind_rev[d] - for d in sorted(dig_kinds)]) - msg = ('Only %s head digitization points of the specified kind%s (%s,)' - % (len(hsp), _pl(dig_kinds), kinds_str)) + kinds_str = ", ".join(['"%s"' % _dig_kind_rev[d] for d in sorted(dig_kinds)]) + msg = "Only %s head digitization points of the specified kind%s (%s,)" % ( + len(hsp), + _pl(dig_kinds), + kinds_str, + ) if len(hsp) < 4: - raise ValueError(msg + ', at least 4 required') + raise ValueError(msg + ", at least 4 required") else: - warn(msg + ', fitting may be inaccurate') + warn(msg + ", fitting may be inaccurate") return hsp @@ -1002,33 +1093,39 @@ def _fit_sphere_to_headshape(info, dig_kinds, verbose=None): hsp = get_fitting_dig(info, dig_kinds) radius, origin_head = _fit_sphere(np.array(hsp), disp=False) # compute origin in device coordinates - dev_head_t = info['dev_head_t'] + dev_head_t = info["dev_head_t"] if dev_head_t is None: - dev_head_t = Transform('meg', 'head') - head_to_dev = _ensure_trans(dev_head_t, 'head', 'meg') + dev_head_t = Transform("meg", "head") + head_to_dev = _ensure_trans(dev_head_t, "head", "meg") origin_device = apply_trans(head_to_dev, origin_head) - logger.info('Fitted sphere radius:'.ljust(30) + '%0.1f mm' - % (radius * 1e3,)) + logger.info("Fitted sphere radius:".ljust(30) + "%0.1f mm" % (radius * 1e3,)) _check_head_radius(radius) # > 2 cm away from head center in X or Y is strange if np.linalg.norm(origin_head[:2]) > 0.02: - warn('(X, Y) fit (%0.1f, %0.1f) more than 20 mm from ' - 'head frame origin' % tuple(1e3 * origin_head[:2])) - logger.info('Origin head coordinates:'.ljust(30) + - '%0.1f %0.1f %0.1f mm' % tuple(1e3 * origin_head)) - logger.info('Origin device coordinates:'.ljust(30) + - '%0.1f %0.1f %0.1f mm' % tuple(1e3 * origin_device)) + warn( + "(X, Y) fit (%0.1f, %0.1f) more than 20 mm from " + "head frame origin" % tuple(1e3 * origin_head[:2]) + ) + logger.info( + "Origin head coordinates:".ljust(30) + + "%0.1f %0.1f %0.1f mm" % tuple(1e3 * origin_head) + ) + logger.info( + "Origin device coordinates:".ljust(30) + + "%0.1f %0.1f %0.1f mm" % tuple(1e3 * origin_device) + ) return radius, origin_head, origin_device -def _fit_sphere(points, disp='auto'): +def _fit_sphere(points, disp="auto"): """Fit a sphere to an arbitrary set of points.""" from scipy.optimize import fmin_cobyla - if isinstance(disp, str) and disp == 'auto': + + if isinstance(disp, str) and disp == "auto": disp = True if logger.level <= 20 else False # initial guess for center and radius - radii = (np.max(points, axis=1) - np.min(points, axis=1)) / 2. 
+ radii = (np.max(points, axis=1) - np.min(points, axis=1)) / 2.0 radius_init = radii.mean() center_init = np.median(points, axis=0) @@ -1043,38 +1140,46 @@ def cost_fun(center_rad): def constraint(center_rad): return center_rad[3] # radius must be >= 0 - x_opt = fmin_cobyla(cost_fun, x0, constraint, rhobeg=radius_init, - rhoend=radius_init * 1e-6, disp=disp) + x_opt = fmin_cobyla( + cost_fun, + x0, + constraint, + rhobeg=radius_init, + rhoend=radius_init * 1e-6, + disp=disp, + ) origin, radius = x_opt[:3], x_opt[3] return radius, origin -def _check_origin(origin, info, coord_frame='head', disp=False): +def _check_origin(origin, info, coord_frame="head", disp=False): """Check or auto-determine the origin.""" if isinstance(origin, str): - if origin != 'auto': - raise ValueError('origin must be a numerical array, or "auto", ' - 'not %s' % (origin,)) - if coord_frame == 'head': + if origin != "auto": + raise ValueError( + 'origin must be a numerical array, or "auto", ' "not %s" % (origin,) + ) + if coord_frame == "head": R, origin = fit_sphere_to_headshape( - info, verbose=_verbose_safe_false(), units='m')[:2] - logger.info(' Automatic origin fit: head of radius %0.1f mm' - % (R * 1000.,)) + info, verbose=_verbose_safe_false(), units="m" + )[:2] + logger.info( + " Automatic origin fit: head of radius %0.1f mm" % (R * 1000.0,) + ) del R else: - origin = (0., 0., 0.) + origin = (0.0, 0.0, 0.0) origin = np.array(origin, float) if origin.shape != (3,): - raise ValueError('origin must be a 3-element array') + raise ValueError("origin must be a 3-element array") if disp: - origin_str = ', '.join(['%0.1f' % (o * 1000) for o in origin]) - msg = (' Using origin %s mm in the %s frame' - % (origin_str, coord_frame)) - if coord_frame == 'meg' and info['dev_head_t'] is not None: - o_dev = apply_trans(info['dev_head_t'], origin) - origin_str = ', '.join('%0.1f' % (o * 1000,) for o in o_dev) - msg += ' (%s mm in the head frame)' % (origin_str,) + origin_str = ", ".join(["%0.1f" % (o * 1000) for o in origin]) + msg = " Using origin %s mm in the %s frame" % (origin_str, coord_frame) + if coord_frame == "meg" and info["dev_head_t"] is not None: + o_dev = apply_trans(info["dev_head_t"], origin) + origin_str = ", ".join("%0.1f" % (o * 1000,) for o in o_dev) + msg += " (%s mm in the head frame)" % (origin_str,) logger.info(msg) return origin @@ -1082,11 +1187,22 @@ def _check_origin(origin, info, coord_frame='head', disp=False): # ############################################################################ # Create BEM surfaces + @verbose -def make_watershed_bem(subject, subjects_dir=None, overwrite=False, - volume='T1', atlas=False, gcaatlas=False, preflood=None, - show=False, copy=True, T1=None, brainmask='ws.mgz', - verbose=None): +def make_watershed_bem( + subject, + subjects_dir=None, + overwrite=False, + volume="T1", + atlas=False, + gcaatlas=False, + preflood=None, + show=False, + copy=True, + T1=None, + brainmask="ws.mgz", + verbose=None, +): """Create BEM surfaces using the FreeSurfer watershed algorithm. Parameters @@ -1141,78 +1257,97 @@ def make_watershed_bem(subject, subjects_dir=None, overwrite=False, .. versionadded:: 0.10 """ from .viz.misc import plot_bem + env, mri_dir, bem_dir = _prepare_env(subject, subjects_dir) tempdir = _TempDir() # fsl and Freesurfer create some random junk in CWD - run_subprocess_env = partial(run_subprocess, env=env, - cwd=tempdir) + run_subprocess_env = partial(run_subprocess, env=env, cwd=tempdir) - subjects_dir = env['SUBJECTS_DIR'] # Set by _prepare_env() above. 
+ subjects_dir = env["SUBJECTS_DIR"] # Set by _prepare_env() above. subject_dir = op.join(subjects_dir, subject) - ws_dir = op.join(bem_dir, 'watershed') + ws_dir = op.join(bem_dir, "watershed") T1_dir = op.join(mri_dir, volume) T1_mgz = T1_dir - if not T1_dir.endswith('.mgz'): - T1_mgz += '.mgz' + if not T1_dir.endswith(".mgz"): + T1_mgz += ".mgz" if not op.isdir(bem_dir): os.makedirs(bem_dir) - _check_fname(T1_mgz, overwrite='read', must_exist=True, name='MRI data') + _check_fname(T1_mgz, overwrite="read", must_exist=True, name="MRI data") if op.isdir(ws_dir): if not overwrite: - raise RuntimeError('%s already exists. Use the --overwrite option' - ' to recreate it.' % ws_dir) + raise RuntimeError( + "%s already exists. Use the --overwrite option" + " to recreate it." % ws_dir + ) else: shutil.rmtree(ws_dir) # put together the command - cmd = ['mri_watershed'] + cmd = ["mri_watershed"] if preflood: cmd += ["-h", "%s" % int(preflood)] if T1 is None: T1 = gcaatlas if T1: - cmd += ['-T1'] + cmd += ["-T1"] if gcaatlas: - fname = op.join(env['FREESURFER_HOME'], 'average', - 'RB_all_withskull_*.gca') + fname = op.join(env["FREESURFER_HOME"], "average", "RB_all_withskull_*.gca") fname = sorted(glob.glob(fname))[::-1][0] - logger.info('Using GCA atlas: %s' % (fname,)) - cmd += ['-atlas', '-brain_atlas', fname, - subject_dir + '/mri/transforms/talairach_with_skull.lta'] + logger.info("Using GCA atlas: %s" % (fname,)) + cmd += [ + "-atlas", + "-brain_atlas", + fname, + subject_dir + "/mri/transforms/talairach_with_skull.lta", + ] elif atlas: - cmd += ['-atlas'] + cmd += ["-atlas"] if op.exists(T1_mgz): - cmd += ['-useSRAS', '-surf', op.join(ws_dir, subject), T1_mgz, - op.join(ws_dir, brainmask)] + cmd += [ + "-useSRAS", + "-surf", + op.join(ws_dir, subject), + T1_mgz, + op.join(ws_dir, brainmask), + ] else: - cmd += ['-useSRAS', '-surf', op.join(ws_dir, subject), T1_dir, - op.join(ws_dir, brainmask)] + cmd += [ + "-useSRAS", + "-surf", + op.join(ws_dir, subject), + T1_dir, + op.join(ws_dir, brainmask), + ] # report and run - logger.info('\nRunning mri_watershed for BEM segmentation with the ' - 'following parameters:\n\nResults dir = %s\nCommand = %s\n' - % (ws_dir, ' '.join(cmd))) + logger.info( + "\nRunning mri_watershed for BEM segmentation with the " + "following parameters:\n\nResults dir = %s\nCommand = %s\n" + % (ws_dir, " ".join(cmd)) + ) os.makedirs(op.join(ws_dir)) run_subprocess_env(cmd) del tempdir # clean up directory if op.isfile(T1_mgz): new_info = _extract_volume_info(T1_mgz) if not new_info: - warn('nibabel is not available or the volume info is invalid.' - 'Volume info not updated in the written surface.') - surfs = ['brain', 'inner_skull', 'outer_skull', 'outer_skin'] + warn( + "nibabel is not available or the volume info is invalid." + "Volume info not updated in the written surface." 
+ ) + surfs = ["brain", "inner_skull", "outer_skull", "outer_skin"] for s in surfs: - surf_ws_out = op.join(ws_dir, '%s_%s_surface' % (subject, s)) + surf_ws_out = op.join(ws_dir, "%s_%s_surface" % (subject, s)) - rr, tris, volume_info = read_surface(surf_ws_out, - read_metadata=True) + rr, tris, volume_info = read_surface(surf_ws_out, read_metadata=True) # replace volume info, 'head' stays volume_info.update(new_info) - write_surface(surf_ws_out, rr, tris, volume_info=volume_info, - overwrite=True) + write_surface( + surf_ws_out, rr, tris, volume_info=volume_info, overwrite=True + ) # Create symbolic links - surf_out = op.join(bem_dir, '%s.surf' % s) + surf_out = op.join(bem_dir, "%s.surf" % s) if not overwrite and op.exists(surf_out): skip_symlink = True else: @@ -1222,48 +1357,60 @@ def make_watershed_bem(subject, subjects_dir=None, overwrite=False, skip_symlink = False if skip_symlink: - logger.info("Unable to create all symbolic links to .surf files " - "in bem folder. Use --overwrite option to recreate " - "them.") - dest = op.join(bem_dir, 'watershed') + logger.info( + "Unable to create all symbolic links to .surf files " + "in bem folder. Use --overwrite option to recreate " + "them." + ) + dest = op.join(bem_dir, "watershed") else: logger.info("Symbolic links to .surf files created in bem folder") dest = bem_dir - logger.info("\nThank you for waiting.\nThe BEM triangulations for this " - "subject are now available at:\n%s." % dest) + logger.info( + "\nThank you for waiting.\nThe BEM triangulations for this " + "subject are now available at:\n%s." % dest + ) # Write a head file for coregistration - fname_head = op.join(bem_dir, subject + '-head.fif') + fname_head = op.join(bem_dir, subject + "-head.fif") if op.isfile(fname_head): os.remove(fname_head) - surf = _surfaces_to_bem([op.join(ws_dir, subject + '_outer_skin_surface')], - [FIFF.FIFFV_BEM_SURF_ID_HEAD], sigmas=[1]) + surf = _surfaces_to_bem( + [op.join(ws_dir, subject + "_outer_skin_surface")], + [FIFF.FIFFV_BEM_SURF_ID_HEAD], + sigmas=[1], + ) write_bem_surfaces(fname_head, surf) # Show computed BEM surfaces if show: - plot_bem(subject=subject, subjects_dir=subjects_dir, - orientation='coronal', slices=None, show=True) + plot_bem( + subject=subject, + subjects_dir=subjects_dir, + orientation="coronal", + slices=None, + show=True, + ) - logger.info('Created %s\n\nComplete.' % (fname_head,)) + logger.info("Created %s\n\nComplete." 
% (fname_head,)) def _extract_volume_info(mgz): """Extract volume info from a mgz file.""" nib = _import_nibabel() header = nib.load(mgz).header - version = header['version'] + version = header["version"] vol_info = dict() if version == 1: - version = '%s # volume info valid' % version - vol_info['valid'] = version - vol_info['filename'] = mgz - vol_info['volume'] = header['dims'][:3] - vol_info['voxelsize'] = header['delta'] - vol_info['xras'], vol_info['yras'], vol_info['zras'] = header['Mdc'] - vol_info['cras'] = header['Pxyz_c'] + version = "%s # volume info valid" % version + vol_info["valid"] = version + vol_info["filename"] = mgz + vol_info["volume"] = header["dims"][:3] + vol_info["voxelsize"] = header["delta"] + vol_info["xras"], vol_info["yras"], vol_info["zras"] = header["Mdc"] + vol_info["cras"] = header["Pxyz_c"] return vol_info @@ -1271,9 +1418,11 @@ def _extract_volume_info(mgz): # ############################################################################ # Read + @verbose -def read_bem_surfaces(fname, patch_stats=False, s_id=None, on_defects='raise', - verbose=None): +def read_bem_surfaces( + fname, patch_stats=False, s_id=None, on_defects="raise", verbose=None +): """Read the BEM surfaces from a FIF file. Parameters @@ -1302,16 +1451,16 @@ def read_bem_surfaces(fname, patch_stats=False, s_id=None, on_defects='raise', write_bem_surfaces, write_bem_solution, make_bem_model """ # Open the file, create directory - _validate_type(s_id, ('int-like', None), 's_id') - fname = _check_fname(fname, 'read', True, 'fname') + _validate_type(s_id, ("int-like", None), "s_id") + fname = _check_fname(fname, "read", True, "fname") if fname.suffix == ".h5": surf = _read_bem_surfaces_h5(fname, s_id) else: surf = _read_bem_surfaces_fif(fname, s_id) if s_id is not None and len(surf) != 1: - raise ValueError('surface with id %d not found' % s_id) + raise ValueError("surface with id %d not found" % s_id) for this in surf: - if patch_stats or this['nn'] is None: + if patch_stats or this["nn"] is None: _check_complete_surface(this, incomplete=on_defects) return surf[0] if s_id is not None else surf @@ -1320,12 +1469,12 @@ def _read_bem_surfaces_h5(fname, s_id): read_hdf5, _ = _import_h5io_funcs() bem = read_hdf5(fname) try: - [s['id'] for s in bem['surfs']] + [s["id"] for s in bem["surfs"]] except Exception: # not our format - raise ValueError('BEM data not found') - surf = bem['surfs'] + raise ValueError("BEM data not found") + surf = bem["surfs"] if s_id is not None: - surf = [s for s in surf if s['id'] == s_id] + surf = [s for s in surf if s["id"] == s_id] return surf @@ -1337,32 +1486,33 @@ def _read_bem_surfaces_fif(fname, s_id): # Find BEM bem = dir_tree_find(tree, FIFF.FIFFB_BEM) if bem is None or len(bem) == 0: - raise ValueError('BEM data not found') + raise ValueError("BEM data not found") bem = bem[0] # Locate all surfaces bemsurf = dir_tree_find(bem, FIFF.FIFFB_BEM_SURF) if bemsurf is None: - raise ValueError('BEM surface data not found') + raise ValueError("BEM surface data not found") - logger.info(' %d BEM surfaces found' % len(bemsurf)) + logger.info(" %d BEM surfaces found" % len(bemsurf)) # Coordinate frame possibly at the top level tag = find_tag(fid, bem, FIFF.FIFF_BEM_COORD_FRAME) if tag is not None: coord_frame = tag.data # Read all surfaces if s_id is not None: - surf = [_read_bem_surface(fid, bsurf, coord_frame, s_id) - for bsurf in bemsurf] + surf = [ + _read_bem_surface(fid, bsurf, coord_frame, s_id) for bsurf in bemsurf + ] surf = [s for s in surf if s is not None] else: 
surf = list() for bsurf in bemsurf: - logger.info(' Reading a surface...') + logger.info(" Reading a surface...") this = _read_bem_surface(fid, bsurf, coord_frame) surf.append(this) - logger.info('[done]') - logger.info(' %d BEM surfaces read' % len(surf)) + logger.info("[done]") + logger.info(" %d BEM surfaces read" % len(surf)) return surf @@ -1374,63 +1524,63 @@ def _read_bem_surface(fid, this, def_coord_frame, s_id=None): tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_ID) if tag is None: - res['id'] = FIFF.FIFFV_BEM_SURF_ID_UNKNOWN + res["id"] = FIFF.FIFFV_BEM_SURF_ID_UNKNOWN else: - res['id'] = int(tag.data.item()) + res["id"] = int(tag.data.item()) - if s_id is not None and res['id'] != s_id: + if s_id is not None and res["id"] != s_id: return None tag = find_tag(fid, this, FIFF.FIFF_BEM_SIGMA) - res['sigma'] = 1.0 if tag is None else float(tag.data.item()) + res["sigma"] = 1.0 if tag is None else float(tag.data.item()) tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NNODE) if tag is None: - raise ValueError('Number of vertices not found') + raise ValueError("Number of vertices not found") - res['np'] = int(tag.data.item()) + res["np"] = int(tag.data.item()) tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NTRI) if tag is None: - raise ValueError('Number of triangles not found') - res['ntri'] = int(tag.data.item()) + raise ValueError("Number of triangles not found") + res["ntri"] = int(tag.data.item()) tag = find_tag(fid, this, FIFF.FIFF_MNE_COORD_FRAME) if tag is None: tag = find_tag(fid, this, FIFF.FIFF_BEM_COORD_FRAME) if tag is None: - res['coord_frame'] = def_coord_frame + res["coord_frame"] = def_coord_frame else: - res['coord_frame'] = int(tag.data.item()) + res["coord_frame"] = int(tag.data.item()) else: - res['coord_frame'] = int(tag.data.item()) + res["coord_frame"] = int(tag.data.item()) # Vertices, normals, and triangles tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NODES) if tag is None: - raise ValueError('Vertex data not found') + raise ValueError("Vertex data not found") - res['rr'] = tag.data.astype(np.float64) - if res['rr'].shape[0] != res['np']: - raise ValueError('Vertex information is incorrect') + res["rr"] = tag.data.astype(np.float64) + if res["rr"].shape[0] != res["np"]: + raise ValueError("Vertex information is incorrect") tag = find_tag(fid, this, FIFF.FIFF_MNE_SOURCE_SPACE_NORMALS) if tag is None: tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_NORMALS) if tag is None: - res['nn'] = None + res["nn"] = None else: - res['nn'] = tag.data.astype(np.float64) - if res['nn'].shape[0] != res['np']: - raise ValueError('Vertex normal information is incorrect') + res["nn"] = tag.data.astype(np.float64) + if res["nn"].shape[0] != res["np"]: + raise ValueError("Vertex normal information is incorrect") tag = find_tag(fid, this, FIFF.FIFF_BEM_SURF_TRIANGLES) if tag is None: - raise ValueError('Triangulation not found') + raise ValueError("Triangulation not found") - res['tris'] = tag.data - 1 # index start at 0 in Python - if res['tris'].shape[0] != res['ntri']: - raise ValueError('Triangulation information is incorrect') + res["tris"] = tag.data - 1 # index start at 0 in Python + if res["tris"].shape[0] != res["ntri"]: + raise ValueError("Triangulation information is incorrect") return res @@ -1457,97 +1607,104 @@ def read_bem_solution(fname, *, verbose=None): make_bem_solution write_bem_solution """ - fname = _check_fname(fname, 'read', True, 'fname') + fname = _check_fname(fname, "read", True, "fname") # mirrors fwd_bem_load_surfaces from fwd_bem_model.c if fname.suffix == ".h5": 
read_hdf5, _ = _import_h5io_funcs() - logger.info('Loading surfaces and solution...') + logger.info("Loading surfaces and solution...") bem = read_hdf5(fname) - if 'solver' not in bem: - bem['solver'] = 'mne' + if "solver" not in bem: + bem["solver"] = "mne" else: bem = _read_bem_solution_fif(fname) - if len(bem['surfs']) == 3: - logger.info('Three-layer model surfaces loaded.') - needed = np.array([FIFF.FIFFV_BEM_SURF_ID_HEAD, - FIFF.FIFFV_BEM_SURF_ID_SKULL, - FIFF.FIFFV_BEM_SURF_ID_BRAIN]) - if not all(x['id'] in needed for x in bem['surfs']): - raise RuntimeError('Could not find necessary BEM surfaces') + if len(bem["surfs"]) == 3: + logger.info("Three-layer model surfaces loaded.") + needed = np.array( + [ + FIFF.FIFFV_BEM_SURF_ID_HEAD, + FIFF.FIFFV_BEM_SURF_ID_SKULL, + FIFF.FIFFV_BEM_SURF_ID_BRAIN, + ] + ) + if not all(x["id"] in needed for x in bem["surfs"]): + raise RuntimeError("Could not find necessary BEM surfaces") # reorder surfaces as necessary (shouldn't need to?) reorder = [None] * 3 - for x in bem['surfs']: - reorder[np.where(x['id'] == needed)[0][0]] = x - bem['surfs'] = reorder - elif len(bem['surfs']) == 1: - if not bem['surfs'][0]['id'] == FIFF.FIFFV_BEM_SURF_ID_BRAIN: - raise RuntimeError('BEM Surfaces not found') - logger.info('Homogeneous model surface loaded.') - - assert set(bem.keys()) == set( - ('surfs', 'solution', 'bem_method', 'solver')) + for x in bem["surfs"]: + reorder[np.where(x["id"] == needed)[0][0]] = x + bem["surfs"] = reorder + elif len(bem["surfs"]) == 1: + if not bem["surfs"][0]["id"] == FIFF.FIFFV_BEM_SURF_ID_BRAIN: + raise RuntimeError("BEM Surfaces not found") + logger.info("Homogeneous model surface loaded.") + + assert set(bem.keys()) == set(("surfs", "solution", "bem_method", "solver")) bem = ConductorModel(bem) - bem['is_sphere'] = False + bem["is_sphere"] = False # sanity checks and conversions _check_option( - 'BEM approximation method', bem['bem_method'], - (FIFF.FIFFV_BEM_APPROX_LINEAR,)) # CONSTANT not supported + "BEM approximation method", bem["bem_method"], (FIFF.FIFFV_BEM_APPROX_LINEAR,) + ) # CONSTANT not supported dim = 0 - solver = bem.get('solver', 'mne') - _check_option('BEM solver', solver, ('mne', 'openmeeg')) - for si, surf in enumerate(bem['surfs']): - assert bem['bem_method'] == FIFF.FIFFV_BEM_APPROX_LINEAR - dim += surf['np'] - if solver == 'openmeeg' and si != 0: - dim += surf['ntri'] - dims = bem['solution'].shape + solver = bem.get("solver", "mne") + _check_option("BEM solver", solver, ("mne", "openmeeg")) + for si, surf in enumerate(bem["surfs"]): + assert bem["bem_method"] == FIFF.FIFFV_BEM_APPROX_LINEAR + dim += surf["np"] + if solver == "openmeeg" and si != 0: + dim += surf["ntri"] + dims = bem["solution"].shape if solver == "openmeeg": sz = (dim * (dim + 1)) // 2 if len(dims) != 1 or dims[0] != sz: raise RuntimeError( - 'For the given BEM surfaces, OpenMEEG should produce a ' - f'solution matrix of shape ({sz},) but got {dims}') - bem['nsol'] = dim + "For the given BEM surfaces, OpenMEEG should produce a " + f"solution matrix of shape ({sz},) but got {dims}" + ) + bem["nsol"] = dim else: if len(dims) != 2 and solver != "openmeeg": - raise RuntimeError('Expected a two-dimensional solution matrix ' - 'instead of a %d dimensional one' % dims[0]) + raise RuntimeError( + "Expected a two-dimensional solution matrix " + "instead of a %d dimensional one" % dims[0] + ) if dims[0] != dim or dims[1] != dim: - raise RuntimeError('Expected a %d x %d solution matrix instead of ' - 'a %d x %d one' % (dim, dim, dims[1], dims[0])) 
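To make the dimension bookkeeping above concrete: for the usual three-layer ico-4 model each surface has 2562 vertices and 5120 triangles; the MNE solver expects a dense (dim, dim) matrix over vertices only, while OpenMEEG also counts the triangles of every surface except the outermost and stores the symmetric head matrix packed. A small arithmetic sketch (the counts are the standard ico-4 values, not read from a file):

    n_vert, n_tri = 2562, 5120           # per-surface counts at ico grade 4

    dim_mne = 3 * n_vert                 # 7686: vertices of all 3 surfaces
    dim_om = 3 * n_vert + 2 * n_tri      # adds ntri of the two inner surfaces
    packed = dim_om * (dim_om + 1) // 2  # the (sz,) shape checked above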
- bem['nsol'] = bem['solution'].shape[0] + raise RuntimeError( + "Expected a %d x %d solution matrix instead of " + "a %d x %d one" % (dim, dim, dims[1], dims[0]) + ) + bem["nsol"] = bem["solution"].shape[0] # Gamma factors and multipliers _add_gamma_multipliers(bem) - extra = f'made by {solver}' if solver != 'mne' else '' - logger.info(f'Loaded linear collocation BEM solution{extra} from {fname}') + extra = f"made by {solver}" if solver != "mne" else "" + logger.info(f"Loaded linear collocation BEM solution{extra} from {fname}") return bem def _read_bem_solution_fif(fname): - logger.info('Loading surfaces...') - surfs = read_bem_surfaces( - fname, patch_stats=True, verbose=_verbose_safe_false()) + logger.info("Loading surfaces...") + surfs = read_bem_surfaces(fname, patch_stats=True, verbose=_verbose_safe_false()) # convert from surfaces to solution - logger.info('\nLoading the solution matrix...\n') - solver = 'mne' + logger.info("\nLoading the solution matrix...\n") + solver = "mne" f, tree, _ = fiff_open(fname) with f as fid: # Find the BEM data nodes = dir_tree_find(tree, FIFF.FIFFB_BEM) if len(nodes) == 0: - raise RuntimeError('No BEM data in %s' % fname) + raise RuntimeError("No BEM data in %s" % fname) bem_node = nodes[0] # Approximation method tag = find_tag(f, bem_node, FIFF.FIFF_DESCRIPTION) if tag is not None: tag = json.loads(tag.data) - solver = tag['solver'] + solver = tag["solver"] tag = find_tag(f, bem_node, FIFF.FIFF_BEM_APPROX) if tag is None: - raise RuntimeError('No BEM solution found in %s' % fname) + raise RuntimeError("No BEM solution found in %s" % fname) method = tag.data[0] tag = find_tag(fid, bem_node, FIFF.FIFF_BEM_POT_SOLUTION) sol = tag.data @@ -1557,73 +1714,77 @@ def _read_bem_solution_fif(fname): def _add_gamma_multipliers(bem): """Add gamma and multipliers in-place.""" - bem['sigma'] = np.array([surf['sigma'] for surf in bem['surfs']]) + bem["sigma"] = np.array([surf["sigma"] for surf in bem["surfs"]]) # Dirty trick for the zero conductivity outside - sigma = np.r_[0.0, bem['sigma']] - bem['source_mult'] = 2.0 / (sigma[1:] + sigma[:-1]) - bem['field_mult'] = sigma[1:] - sigma[:-1] + sigma = np.r_[0.0, bem["sigma"]] + bem["source_mult"] = 2.0 / (sigma[1:] + sigma[:-1]) + bem["field_mult"] = sigma[1:] - sigma[:-1] # make sure subsequent "zip"s work correctly - assert len(bem['surfs']) == len(bem['field_mult']) - bem['gamma'] = ((sigma[1:] - sigma[:-1])[np.newaxis, :] / - (sigma[1:] + sigma[:-1])[:, np.newaxis]) + assert len(bem["surfs"]) == len(bem["field_mult"]) + bem["gamma"] = (sigma[1:] - sigma[:-1])[np.newaxis, :] / (sigma[1:] + sigma[:-1])[ + :, np.newaxis + ] # In our BEM code we do not model the CSF so we assign the innermost surface # the id BRAIN. Our 4-layer sphere we model CSF (at least by default), so when # searching for and referring to surfaces we need to keep track of this. 
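As commentary on `_add_gamma_multipliers` just above: with the default three-layer conductivities and a zero-conductivity exterior prepended, the multipliers work out as below, a direct NumPy mirror of the in-place code rather than anything new:

    import numpy as np

    sigma_surf = np.array([0.3, 0.006, 0.3])  # scalp, skull, brain (S/m)
    sigma = np.r_[0.0, sigma_surf]            # sigma = 0 outside the head

    source_mult = 2.0 / (sigma[1:] + sigma[:-1])
    field_mult = sigma[1:] - sigma[:-1]       # conductivity jumps
    gamma = (sigma[1:] - sigma[:-1])[np.newaxis, :] / (
        sigma[1:] + sigma[:-1]
    )[:, np.newaxis]                          # (3, 3) weighting matrix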
-_sm_surf_dict = OrderedDict([ - ('brain', FIFF.FIFFV_BEM_SURF_ID_BRAIN), - ('inner_skull', FIFF.FIFFV_BEM_SURF_ID_CSF), - ('outer_skull', FIFF.FIFFV_BEM_SURF_ID_SKULL), - ('head', FIFF.FIFFV_BEM_SURF_ID_HEAD), -]) +_sm_surf_dict = OrderedDict( + [ + ("brain", FIFF.FIFFV_BEM_SURF_ID_BRAIN), + ("inner_skull", FIFF.FIFFV_BEM_SURF_ID_CSF), + ("outer_skull", FIFF.FIFFV_BEM_SURF_ID_SKULL), + ("head", FIFF.FIFFV_BEM_SURF_ID_HEAD), + ] +) _bem_surf_dict = { - 'inner_skull': FIFF.FIFFV_BEM_SURF_ID_BRAIN, - 'outer_skull': FIFF.FIFFV_BEM_SURF_ID_SKULL, - 'head': FIFF.FIFFV_BEM_SURF_ID_HEAD, + "inner_skull": FIFF.FIFFV_BEM_SURF_ID_BRAIN, + "outer_skull": FIFF.FIFFV_BEM_SURF_ID_SKULL, + "head": FIFF.FIFFV_BEM_SURF_ID_HEAD, } _bem_surf_name = { - FIFF.FIFFV_BEM_SURF_ID_BRAIN: 'inner skull', - FIFF.FIFFV_BEM_SURF_ID_SKULL: 'outer skull', - FIFF.FIFFV_BEM_SURF_ID_HEAD: 'outer skin ', - FIFF.FIFFV_BEM_SURF_ID_UNKNOWN: 'unknown ', + FIFF.FIFFV_BEM_SURF_ID_BRAIN: "inner skull", + FIFF.FIFFV_BEM_SURF_ID_SKULL: "outer skull", + FIFF.FIFFV_BEM_SURF_ID_HEAD: "outer skin ", + FIFF.FIFFV_BEM_SURF_ID_UNKNOWN: "unknown ", } _sm_surf_name = { - FIFF.FIFFV_BEM_SURF_ID_BRAIN: 'brain', - FIFF.FIFFV_BEM_SURF_ID_CSF: 'csf', - FIFF.FIFFV_BEM_SURF_ID_SKULL: 'outer skull', - FIFF.FIFFV_BEM_SURF_ID_HEAD: 'outer skin ', - FIFF.FIFFV_BEM_SURF_ID_UNKNOWN: 'unknown ', + FIFF.FIFFV_BEM_SURF_ID_BRAIN: "brain", + FIFF.FIFFV_BEM_SURF_ID_CSF: "csf", + FIFF.FIFFV_BEM_SURF_ID_SKULL: "outer skull", + FIFF.FIFFV_BEM_SURF_ID_HEAD: "outer skin ", + FIFF.FIFFV_BEM_SURF_ID_UNKNOWN: "unknown ", } def _bem_find_surface(bem, id_): """Find surface from already-loaded conductor model.""" - if bem['is_sphere']: + if bem["is_sphere"]: _surf_dict = _sm_surf_dict _name_dict = _sm_surf_name - kind = 'Sphere model' - tri = 'boundary' + kind = "Sphere model" + tri = "boundary" else: _surf_dict = _bem_surf_dict _name_dict = _bem_surf_name - kind = 'BEM' - tri = 'triangulation' + kind = "BEM" + tri = "triangulation" if isinstance(id_, str): name = id_ id_ = _surf_dict[id_] else: name = _name_dict[id_] - kind = 'Sphere model' if bem['is_sphere'] else 'BEM' - idx = np.where(np.array([s['id'] for s in bem['surfs']]) == id_)[0] + kind = "Sphere model" if bem["is_sphere"] else "BEM" + idx = np.where(np.array([s["id"] for s in bem["surfs"]]) == id_)[0] if len(idx) != 1: - raise RuntimeError(f'{kind} does not have the {name} {tri}') - return bem['surfs'][idx[0]] + raise RuntimeError(f"{kind} does not have the {name} {tri}") + return bem["surfs"][idx[0]] # ############################################################################ # Write + @verbose def write_bem_surfaces(fname, surfs, overwrite=False, *, verbose=None): """Write BEM surfaces to a FIF file. 
@@ -1639,7 +1800,7 @@ def write_bem_surfaces(fname, surfs, overwrite=False, *, verbose=None): """ if isinstance(surfs, dict): surfs = [surfs] - fname = _check_fname(fname, overwrite=overwrite, name='fname') + fname = _check_fname(fname, overwrite=overwrite, name="fname") if fname.suffix == ".h5": _, write_hdf5 = _import_h5io_funcs() @@ -1647,14 +1808,15 @@ def write_bem_surfaces(fname, surfs, overwrite=False, *, verbose=None): else: with start_and_end_file(fname) as fid: start_block(fid, FIFF.FIFFB_BEM) - write_int(fid, FIFF.FIFF_BEM_COORD_FRAME, surfs[0]['coord_frame']) + write_int(fid, FIFF.FIFF_BEM_COORD_FRAME, surfs[0]["coord_frame"]) _write_bem_surfaces_block(fid, surfs) end_block(fid, FIFF.FIFFB_BEM) @verbose -def write_head_bem(fname, rr, tris, on_defects='raise', overwrite=False, - *, verbose=None): +def write_head_bem( + fname, rr, tris, on_defects="raise", overwrite=False, *, verbose=None +): """Write a head surface to a FIF file. Parameters @@ -1670,9 +1832,13 @@ def write_head_bem(fname, rr, tris, on_defects='raise', overwrite=False, %(overwrite)s %(verbose)s """ - surf = _surfaces_to_bem([dict(rr=rr, tris=tris)], - [FIFF.FIFFV_BEM_SURF_ID_HEAD], [1], rescale=False, - incomplete=on_defects) + surf = _surfaces_to_bem( + [dict(rr=rr, tris=tris)], + [FIFF.FIFFV_BEM_SURF_ID_HEAD], + [1], + rescale=False, + incomplete=on_defects, + ) write_bem_surfaces(fname, surf, overwrite=overwrite) @@ -1680,17 +1846,16 @@ def _write_bem_surfaces_block(fid, surfs): """Write bem surfaces to open file handle.""" for surf in surfs: start_block(fid, FIFF.FIFFB_BEM_SURF) - write_float(fid, FIFF.FIFF_BEM_SIGMA, surf['sigma']) - write_int(fid, FIFF.FIFF_BEM_SURF_ID, surf['id']) - write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, surf['coord_frame']) - write_int(fid, FIFF.FIFF_BEM_SURF_NNODE, surf['np']) - write_int(fid, FIFF.FIFF_BEM_SURF_NTRI, surf['ntri']) - write_float_matrix(fid, FIFF.FIFF_BEM_SURF_NODES, surf['rr']) + write_float(fid, FIFF.FIFF_BEM_SIGMA, surf["sigma"]) + write_int(fid, FIFF.FIFF_BEM_SURF_ID, surf["id"]) + write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, surf["coord_frame"]) + write_int(fid, FIFF.FIFF_BEM_SURF_NNODE, surf["np"]) + write_int(fid, FIFF.FIFF_BEM_SURF_NTRI, surf["ntri"]) + write_float_matrix(fid, FIFF.FIFF_BEM_SURF_NODES, surf["rr"]) # index start at 0 in Python - write_int_matrix(fid, FIFF.FIFF_BEM_SURF_TRIANGLES, - surf['tris'] + 1) - if 'nn' in surf and surf['nn'] is not None and len(surf['nn']) > 0: - write_float_matrix(fid, FIFF.FIFF_BEM_SURF_NORMALS, surf['nn']) + write_int_matrix(fid, FIFF.FIFF_BEM_SURF_TRIANGLES, surf["tris"] + 1) + if "nn" in surf and surf["nn"] is not None and len(surf["nn"]) > 0: + write_float_matrix(fid, FIFF.FIFF_BEM_SURF_NORMALS, surf["nn"]) end_block(fid, FIFF.FIFFB_BEM_SURF) @@ -1711,42 +1876,40 @@ def write_bem_solution(fname, bem, overwrite=False, *, verbose=None): -------- read_bem_solution """ - fname = _check_fname(fname, overwrite=overwrite, name='fname') + fname = _check_fname(fname, overwrite=overwrite, name="fname") if fname.suffix == ".h5": _, write_hdf5 = _import_h5io_funcs() - bem = {k: bem[k] for k in ('surfs', 'solution', 'bem_method')} + bem = {k: bem[k] for k in ("surfs", "solution", "bem_method")} write_hdf5(fname, bem, overwrite=True) else: _write_bem_solution_fif(fname, bem) def _write_bem_solution_fif(fname, bem): - _check_bem_size(bem['surfs']) + _check_bem_size(bem["surfs"]) with start_and_end_file(fname) as fid: start_block(fid, FIFF.FIFFB_BEM) # Coordinate frame (mainly for backward compatibility) - write_int(fid, 
FIFF.FIFF_BEM_COORD_FRAME, - bem['surfs'][0]['coord_frame']) - solver = bem.get('solver', 'mne') - if solver != 'mne': - write_string( - fid, FIFF.FIFF_DESCRIPTION, json.dumps(dict(solver=solver))) + write_int(fid, FIFF.FIFF_BEM_COORD_FRAME, bem["surfs"][0]["coord_frame"]) + solver = bem.get("solver", "mne") + if solver != "mne": + write_string(fid, FIFF.FIFF_DESCRIPTION, json.dumps(dict(solver=solver))) # Surfaces - _write_bem_surfaces_block(fid, bem['surfs']) + _write_bem_surfaces_block(fid, bem["surfs"]) # The potential solution - if 'solution' in bem: + if "solution" in bem: _check_option( - 'bem_method', bem['bem_method'], - (FIFF.FIFFV_BEM_APPROX_LINEAR,)) + "bem_method", bem["bem_method"], (FIFF.FIFFV_BEM_APPROX_LINEAR,) + ) write_int(fid, FIFF.FIFF_BEM_APPROX, FIFF.FIFFV_BEM_APPROX_LINEAR) - write_float_matrix(fid, FIFF.FIFF_BEM_POT_SOLUTION, - bem['solution']) + write_float_matrix(fid, FIFF.FIFF_BEM_POT_SOLUTION, bem["solution"]) end_block(fid, FIFF.FIFFB_BEM) # ############################################################################# # Create 3-Layers BEM model from Flash MRI images + def _prepare_env(subject, subjects_dir): """Prepare an env object for subprocess calls.""" env = os.environ.copy() @@ -1758,18 +1921,19 @@ def _prepare_env(subject, subjects_dir): subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) subject_dir = subjects_dir / subject if not subject_dir.is_dir(): - raise RuntimeError('Could not find the subject data directory "%s"' - % (subject_dir,)) - env.update(SUBJECT=subject, SUBJECTS_DIR=str(subjects_dir), - FREESURFER_HOME=fs_home) + raise RuntimeError( + 'Could not find the subject data directory "%s"' % (subject_dir,) + ) + env.update(SUBJECT=subject, SUBJECTS_DIR=str(subjects_dir), FREESURFER_HOME=fs_home) mri_dir = subject_dir / "mri" bem_dir = subject_dir / "bem" return env, mri_dir, bem_dir def _write_echos(mri_dir, flash_echos, angle): - nib = _import_nibabel('write echoes') + nib = _import_nibabel("write echoes") from nibabel.spatialimages import SpatialImage + if _path_like(flash_echos): flash_echos = nib.load(flash_echos) if isinstance(flash_echos, SpatialImage): @@ -1780,8 +1944,7 @@ def _write_echos(mri_dir, flash_echos, angle): data = data[..., np.newaxis] for echo_idx in range(data.shape[3]): this_echo_img = flash_echos.__class__( - data[..., echo_idx], affine=affine, - header=deepcopy(flash_echos.header) + data[..., echo_idx], affine=affine, header=deepcopy(flash_echos.header) ) flash_echo_imgs.append(this_echo_img) flash_echos = flash_echo_imgs @@ -1789,13 +1952,13 @@ def _write_echos(mri_dir, flash_echos, angle): for idx, flash_echo in enumerate(flash_echos, 1): if _path_like(flash_echo): flash_echo = nib.load(flash_echo) - nib.save(flash_echo, - op.join(mri_dir, 'flash', f'mef{angle}_{idx:03d}.mgz')) + nib.save(flash_echo, op.join(mri_dir, "flash", f"mef{angle}_{idx:03d}.mgz")) @verbose -def convert_flash_mris(subject, flash30=True, unwarp=False, - subjects_dir=None, flash5=True, verbose=None): +def convert_flash_mris( + subject, flash30=True, unwarp=False, subjects_dir=None, flash5=True, verbose=None +): """Synthesize the flash 5 files for use with make_flash_bem. 
This function aims to produce a synthesized flash 5 MRI from @@ -1843,32 +2006,30 @@ def convert_flash_mris(subject, flash30=True, unwarp=False, """ # noqa: E501 env, mri_dir = _prepare_env(subject, subjects_dir)[:2] tempdir = _TempDir() # fsl and Freesurfer create some random junk in CWD - run_subprocess_env = partial(run_subprocess, env=env, - cwd=tempdir) + run_subprocess_env = partial(run_subprocess, env=env, cwd=tempdir) mri_dir = Path(mri_dir) # Step 1a : Data conversion to mgz format flash_dir = mri_dir / "flash" - pm_dir = flash_dir / 'parameter_maps' + pm_dir = flash_dir / "parameter_maps" pm_dir.mkdir(parents=True, exist_ok=True) echos_done = 0 if not isinstance(flash5, bool): - _write_echos(mri_dir, flash5, angle='05') + _write_echos(mri_dir, flash5, angle="05") if not isinstance(flash30, bool): - _write_echos(mri_dir, flash30, angle='30') + _write_echos(mri_dir, flash30, angle="30") # Step 1b : Run grad_unwarp on converted files template = op.join(flash_dir, "mef*_*.mgz") files = sorted(glob.glob(template)) if len(files) == 0: - raise ValueError('No suitable source files found (%s)' % template) + raise ValueError("No suitable source files found (%s)" % template) if unwarp: logger.info("\n---- Unwarp mgz data sets ----") for infile in files: outfile = infile.replace(".mgz", "u.mgz") - cmd = ['grad_unwarp', '-i', infile, '-o', outfile, '-unwarp', - 'true'] + cmd = ["grad_unwarp", "-i", infile, "-o", outfile, "-unwarp", "true"] run_subprocess_env(cmd) # Clear parameter maps if some of the data were reconverted if echos_done > 0 and pm_dir.exists(): @@ -1882,20 +2043,24 @@ def convert_flash_mris(subject, flash30=True, unwarp=False, if unwarp: files = sorted(glob.glob(op.join(flash_dir, "mef05_*u.mgz"))) if len(os.listdir(pm_dir)) == 0: - cmd = (['mri_ms_fitparms'] + files + [str(pm_dir)]) + cmd = ["mri_ms_fitparms"] + files + [str(pm_dir)] run_subprocess_env(cmd) else: logger.info("Parameter maps were already computed") # Step 3 : Synthesize the flash 5 images logger.info("\n---- Synthesizing flash 5 images ----") - if not (pm_dir / 'flash5.mgz').exists(): - cmd = ['mri_synthesize', '20', '5', '5', - (pm_dir / 'T1.mgz'), - (pm_dir / 'PD.mgz'), - (pm_dir / 'flash5.mgz') - ] + if not (pm_dir / "flash5.mgz").exists(): + cmd = [ + "mri_synthesize", + "20", + "5", + "5", + (pm_dir / "T1.mgz"), + (pm_dir / "PD.mgz"), + (pm_dir / "flash5.mgz"), + ] run_subprocess_env(cmd) - (pm_dir / 'flash5_reg.mgz').unlink() + (pm_dir / "flash5_reg.mgz").unlink() else: logger.info("Synthesized flash 5 volume is already there") else: @@ -1903,18 +2068,27 @@ def convert_flash_mris(subject, flash30=True, unwarp=False, template = "mef05_*u.mgz" if unwarp else "mef05_*.mgz" files = sorted(flash_dir.glob(template)) if len(files) == 0: - raise ValueError('No suitable source files found (%s)' % template) - cmd = (['mri_average', '-noconform'] + files + [pm_dir / 'flash5.mgz']) + raise ValueError("No suitable source files found (%s)" % template) + cmd = ["mri_average", "-noconform"] + files + [pm_dir / "flash5.mgz"] run_subprocess_env(cmd) - (pm_dir / 'flash5_reg.mgz').unlink(missing_ok=True) + (pm_dir / "flash5_reg.mgz").unlink(missing_ok=True) del tempdir # finally done running subprocesses - assert (pm_dir / 'flash5.mgz').exists() - return pm_dir / 'flash5.mgz' + assert (pm_dir / "flash5.mgz").exists() + return pm_dir / "flash5.mgz" @verbose -def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None, - copy=True, *, flash5_img=None, register=True, verbose=None): +def make_flash_bem( + subject, 
+ overwrite=False, + show=True, + subjects_dir=None, + copy=True, + *, + flash5_img=None, + register=True, + verbose=None, +): """Create 3-Layer BEM model from prepared flash MRI images. Parameters @@ -1963,46 +2137,53 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None, env, mri_dir, bem_dir = _prepare_env(subject, subjects_dir) tempdir = _TempDir() # fsl and Freesurfer create some random junk in CWD - run_subprocess_env = partial(run_subprocess, env=env, - cwd=tempdir) + run_subprocess_env = partial(run_subprocess, env=env, cwd=tempdir) mri_dir = Path(mri_dir) bem_dir = Path(bem_dir) - subjects_dir = env['SUBJECTS_DIR'] - flash_path = (mri_dir / 'flash' / 'parameter_maps').resolve() + subjects_dir = env["SUBJECTS_DIR"] + flash_path = (mri_dir / "flash" / "parameter_maps").resolve() flash_path.mkdir(exist_ok=True, parents=True) - logger.info('\nProcessing the flash MRI data to produce BEM meshes with ' - 'the following parameters:\n' - 'SUBJECTS_DIR = %s\n' - 'SUBJECT = %s\n' - 'Result dir = %s\n' % (subjects_dir, subject, - bem_dir / 'flash')) + logger.info( + "\nProcessing the flash MRI data to produce BEM meshes with " + "the following parameters:\n" + "SUBJECTS_DIR = %s\n" + "SUBJECT = %s\n" + "Result dir = %s\n" % (subjects_dir, subject, bem_dir / "flash") + ) # Step 4 : Register with MPRAGE - flash5 = flash_path / 'flash5.mgz' + flash5 = flash_path / "flash5.mgz" if _path_like(flash5_img): logger.info(f"Copying flash 5 image {flash5_img} to {flash5}") - cmd = ['mri_convert', Path(flash5_img).resolve(), flash5] + cmd = ["mri_convert", Path(flash5_img).resolve(), flash5] run_subprocess_env(cmd) elif flash5_img is None: if not flash5.exists(): - raise ValueError(f'Flash 5 image cannot be found at {flash5}.') + raise ValueError(f"Flash 5 image cannot be found at {flash5}.") else: logger.info(f"Writing flash 5 image at {flash5}") - nib = _import_nibabel('write an MRI image') + nib = _import_nibabel("write an MRI image") nib.save(flash5_img, flash5) if register: logger.info("\n---- Registering flash 5 with T1 MPRAGE ----") - flash5_reg = flash_path / 'flash5_reg.mgz' + flash5_reg = flash_path / "flash5_reg.mgz" if not flash5_reg.exists(): - if (mri_dir / 'T1.mgz').exists(): - ref_volume = mri_dir / 'T1.mgz' + if (mri_dir / "T1.mgz").exists(): + ref_volume = mri_dir / "T1.mgz" else: - ref_volume = mri_dir / 'T1' - cmd = ['fsl_rigid_register', '-r', str(ref_volume), '-i', - str(flash5), '-o', str(flash5_reg)] + ref_volume = mri_dir / "T1" + cmd = [ + "fsl_rigid_register", + "-r", + str(ref_volume), + "-i", + str(flash5), + "-o", + str(flash5_reg), + ] run_subprocess_env(cmd) else: logger.info("Registered flash 5 image is already there") @@ -2011,62 +2192,61 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None, # Step 5a : Convert flash5 into COR logger.info("\n---- Converting flash5 volume into COR format ----") - flash5_dir = mri_dir / 'flash5' + flash5_dir = mri_dir / "flash5" shutil.rmtree(flash5_dir, ignore_errors=True) flash5_dir.mkdir(exist_ok=True, parents=True) - cmd = ['mri_convert', flash5_reg, flash5_dir] + cmd = ["mri_convert", flash5_reg, flash5_dir] run_subprocess_env(cmd) # Step 5b and c : Convert the mgz volumes into COR convert_T1 = False - T1_dir = mri_dir / 'T1' - if not T1_dir.is_dir() or next(T1_dir.glob('COR*')) is None: + T1_dir = mri_dir / "T1" + if not T1_dir.is_dir() or next(T1_dir.glob("COR*")) is None: convert_T1 = True convert_brain = False - brain_dir = mri_dir / 'brain' - if not brain_dir.is_dir() or 
next(brain_dir.glob('COR*')) is None: + brain_dir = mri_dir / "brain" + if not brain_dir.is_dir() or next(brain_dir.glob("COR*")) is None: convert_brain = True logger.info("\n---- Converting T1 volume into COR format ----") if convert_T1: - T1_fname = mri_dir / 'T1.mgz' + T1_fname = mri_dir / "T1.mgz" if not T1_fname.is_file(): raise RuntimeError("Both T1 mgz and T1 COR volumes missing.") T1_dir.mkdir(exist_ok=True, parents=True) - cmd = ['mri_convert', T1_fname, T1_dir] + cmd = ["mri_convert", T1_fname, T1_dir] run_subprocess_env(cmd) else: logger.info("T1 volume is already in COR format") logger.info("\n---- Converting brain volume into COR format ----") if convert_brain: - brain_fname = mri_dir / 'brain.mgz' + brain_fname = mri_dir / "brain.mgz" if not brain_fname.is_file(): raise RuntimeError("Both brain mgz and brain COR volumes missing.") brain_dir.mkdir(exist_ok=True, parents=True) - cmd = ['mri_convert', brain_fname, brain_dir] + cmd = ["mri_convert", brain_fname, brain_dir] run_subprocess_env(cmd) else: logger.info("Brain volume is already in COR format") # Finally ready to go logger.info("\n---- Creating the BEM surfaces ----") - cmd = ['mri_make_bem_surfaces', subject] + cmd = ["mri_make_bem_surfaces", subject] run_subprocess_env(cmd) del tempdir # ran our last subprocess; clean up directory logger.info("\n---- Converting the tri files into surf files ----") - flash_bem_dir = bem_dir / 'flash' + flash_bem_dir = bem_dir / "flash" flash_bem_dir.mkdir(exist_ok=True, parents=True) - surfs = ['inner_skull', 'outer_skull', 'outer_skin'] + surfs = ["inner_skull", "outer_skull", "outer_skin"] for surf in surfs: - out_fname = flash_bem_dir / (surf + '.tri') - shutil.move(bem_dir / (surf + '.tri'), out_fname) + out_fname = flash_bem_dir / (surf + ".tri") + shutil.move(bem_dir / (surf + ".tri"), out_fname) nodes, tris = read_tri(out_fname, swap=True) # Do not write volume info here because the tris are already in # standard Freesurfer coords - write_surface(op.splitext(out_fname)[0] + '.surf', nodes, tris, - overwrite=True) + write_surface(op.splitext(out_fname)[0] + ".surf", nodes, tris, overwrite=True) # Cleanup section logger.info("\n---- Cleaning up ----") - (bem_dir / 'inner_skull_tmp.tri').unlink() + (bem_dir / "inner_skull_tmp.tri").unlink() if convert_T1: shutil.rmtree(T1_dir) logger.info("Deleted the T1 COR volume") @@ -2079,7 +2259,7 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None, logger.info("\n---- Creating symbolic links ----") # os.chdir(bem_dir) for surf in surfs: - surf = bem_dir / (surf + '.surf') + surf = bem_dir / (surf + ".surf") if not overwrite and surf.exists(): skip_symlink = True else: @@ -2088,28 +2268,38 @@ def make_flash_bem(subject, overwrite=False, show=True, subjects_dir=None, _symlink(flash_bem_dir / surf.name, surf, copy) skip_symlink = False if skip_symlink: - logger.info("Unable to create all symbolic links to .surf files " - "in bem folder. Use --overwrite option to recreate them.") - dest = bem_dir / 'flash' + logger.info( + "Unable to create all symbolic links to .surf files " + "in bem folder. Use --overwrite option to recreate them." + ) + dest = bem_dir / "flash" else: logger.info("Symbolic links to .surf files created in bem folder") dest = bem_dir - logger.info("\nThank you for waiting.\nThe BEM triangulations for this " - "subject are now available at:\n%s.\nWe hope the BEM meshes " - "created will facilitate your MEG and EEG data analyses." 
- % dest) + logger.info( + "\nThank you for waiting.\nThe BEM triangulations for this " + "subject are now available at:\n%s.\nWe hope the BEM meshes " + "created will facilitate your MEG and EEG data analyses." % dest + ) # Show computed BEM surfaces if show: - plot_bem(subject=subject, subjects_dir=subjects_dir, - orientation='coronal', slices=None, show=True) + plot_bem( + subject=subject, + subjects_dir=subjects_dir, + orientation="coronal", + slices=None, + show=True, + ) def _check_bem_size(surfs): """Check bem surface sizes.""" - if len(surfs) > 1 and surfs[0]['np'] > 10000: - warn('The bem surfaces have %s data points. 5120 (ico grade=4) ' - 'should be enough. Dense 3-layer bems may not save properly.' % - surfs[0]['np']) + if len(surfs) > 1 and surfs[0]["np"] > 10000: + warn( + "The bem surfaces have %s data points. 5120 (ico grade=4) " + "should be enough. Dense 3-layer bems may not save properly." + % surfs[0]["np"] + ) def _symlink(src, dest, copy=False): @@ -2119,40 +2309,41 @@ def _symlink(src, dest, copy=False): try: os.symlink(src_link, dest) except OSError: - warn('Could not create symbolic link %s. Check that your ' - 'partition handles symbolic links. The file will be copied ' - 'instead.' % dest) + warn( + "Could not create symbolic link %s. Check that your " + "partition handles symbolic links. The file will be copied " + "instead." % dest + ) copy = True if copy: shutil.copy(src, dest) -def _ensure_bem_surfaces(bem, extra_allow=(), name='bem'): +def _ensure_bem_surfaces(bem, extra_allow=(), name="bem"): # by default only allow path-like and list, but handle None and # ConductorModel properly if need be. Always return a ConductorModel # even though it's incomplete (and might have is_sphere=True). assert all(extra in (None, ConductorModel) for extra in extra_allow) - allowed = ('path-like', list) + extra_allow + allowed = ("path-like", list) + extra_allow _validate_type(bem, allowed, name) if isinstance(bem, path_like): # Load the surfaces - logger.info(f'Loading BEM surfaces from {str(bem)}...') + logger.info(f"Loading BEM surfaces from {str(bem)}...") bem = read_bem_surfaces(bem) bem = ConductorModel(is_sphere=False, surfs=bem) elif isinstance(bem, list): for ii, this_surf in enumerate(bem): - _validate_type(this_surf, dict, f'{name}[{ii}]') + _validate_type(this_surf, dict, f"{name}[{ii}]") if isinstance(bem, list): bem = ConductorModel(is_sphere=False, surfs=bem) # add surfaces in the spherical case - if isinstance(bem, ConductorModel) and bem['is_sphere']: + if isinstance(bem, ConductorModel) and bem["is_sphere"]: bem = bem.copy() - bem['surfs'] = [] - if len(bem['layers']) == 4: + bem["surfs"] = [] + if len(bem["layers"]) == 4: for idx, id_ in enumerate(_sm_surf_dict.values()): - bem['surfs'].append(_complete_sphere_surf( - bem, idx, 4, complete=False)) - bem['surfs'][-1]['id'] = id_ + bem["surfs"].append(_complete_sphere_surf(bem, idx, 4, complete=False)) + bem["surfs"][-1]["id"] = id_ return bem @@ -2160,7 +2351,7 @@ def _ensure_bem_surfaces(bem, extra_allow=(), name='bem'): def _check_file(fname, overwrite): """Prevent overwrites.""" if op.isfile(fname) and not overwrite: - raise OSError(f'File {fname} exists, use --overwrite to overwrite it') + raise OSError(f"File {fname} exists, use --overwrite to overwrite it") _tri_levels = dict( @@ -2170,9 +2361,17 @@ def _check_file(fname, overwrite): @verbose -def make_scalp_surfaces(subject, subjects_dir=None, force=True, - overwrite=False, no_decimate=False, *, - threshold=20, mri='T1.mgz', verbose=None): +def 
make_scalp_surfaces( + subject, + subjects_dir=None, + force=True, + overwrite=False, + no_decimate=False, + *, + threshold=20, + mri="T1.mgz", + verbose=None, +): """Create surfaces of the scalp and neck. The scalp surfaces are required for using the MNE coregistration GUI, and @@ -2204,22 +2403,24 @@ def make_scalp_surfaces(subject, subjects_dir=None, force=True, %(verbose)s """ subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) - incomplete = 'warn' if force else 'raise' + incomplete = "warn" if force else "raise" subj_path = subjects_dir / subject if not subj_path.exists(): - raise RuntimeError('%s does not exist. Please check your subject ' - 'directory path.' % subj_path) + raise RuntimeError( + "%s does not exist. Please check your subject " + "directory path." % subj_path + ) # Backward compat for old FreeSurfer (?) - _validate_type(mri, str, 'mri') - if mri == 'T1.mgz': + _validate_type(mri, str, "mri") + if mri == "T1.mgz": mri = mri if (subj_path / "mri" / mri).exists() else "T1" - logger.info('1. Creating a dense scalp tessellation with mkheadsurf...') + logger.info("1. Creating a dense scalp tessellation with mkheadsurf...") def check_seghead(surf_path=subj_path / "surf"): surf = None - for k in ['lh.seghead', 'lh.smseghead']: + for k in ["lh.seghead", "lh.smseghead"]: this_surf = surf_path / k if this_surf.exists(): surf = this_surf @@ -2227,61 +2428,79 @@ def check_seghead(surf_path=subj_path / "surf"): return surf my_seghead = check_seghead() - threshold = _ensure_int(threshold, 'threshold') + threshold = _ensure_int(threshold, "threshold") if my_seghead is None: this_env = deepcopy(os.environ) - this_env['SUBJECTS_DIR'] = str(subjects_dir) - this_env['SUBJECT'] = subject - this_env['subjdir'] = str(subj_path) - if 'FREESURFER_HOME' not in this_env: + this_env["SUBJECTS_DIR"] = str(subjects_dir) + this_env["SUBJECT"] = subject + this_env["subjdir"] = str(subj_path) + if "FREESURFER_HOME" not in this_env: raise RuntimeError( - 'The FreeSurfer environment needs to be set up to use ' - 'make_scalp_surfaces to create the outer skin surface ' - 'lh.seghead') - run_subprocess([ - 'mkheadsurf', '-subjid', subject, '-srcvol', mri, - '-thresh1', str(threshold), - '-thresh2', str(threshold)], env=this_env) + "The FreeSurfer environment needs to be set up to use " + "make_scalp_surfaces to create the outer skin surface " + "lh.seghead" + ) + run_subprocess( + [ + "mkheadsurf", + "-subjid", + subject, + "-srcvol", + mri, + "-thresh1", + str(threshold), + "-thresh2", + str(threshold), + ], + env=this_env, + ) surf = check_seghead() if surf is None: - raise RuntimeError('mkheadsurf did not produce the standard output ' - 'file.') + raise RuntimeError("mkheadsurf did not produce the standard output " "file.") bem_dir = subjects_dir / subject / "bem" if not bem_dir.is_dir(): os.mkdir(bem_dir) fname_template = bem_dir / ("%s-head-{}.fif" % subject) - dense_fname = str(fname_template).format('dense') - logger.info('2. Creating %s ...' % dense_fname) + dense_fname = str(fname_template).format("dense") + logger.info("2. Creating %s ..." % dense_fname) _check_file(dense_fname, overwrite) # Helpful message if we get a topology error - msg = ('\n\nConsider using pymeshfix directly to fix the mesh, or --force ' - 'to ignore the problem.') + msg = ( + "\n\nConsider using pymeshfix directly to fix the mesh, or --force " + "to ignore the problem." 
+ ) surf = _surfaces_to_bem( - [surf], [FIFF.FIFFV_BEM_SURF_ID_HEAD], [1], - incomplete=incomplete, extra=msg)[0] + [surf], [FIFF.FIFFV_BEM_SURF_ID_HEAD], [1], incomplete=incomplete, extra=msg + )[0] write_bem_surfaces(dense_fname, surf, overwrite=overwrite) - if os.getenv('_MNE_TESTING_SCALP', 'false') == 'true': - tris = [len(surf['tris'])] # don't actually decimate + if os.getenv("_MNE_TESTING_SCALP", "false") == "true": + tris = [len(surf["tris"])] # don't actually decimate for ii, (level, n_tri) in enumerate(_tri_levels.items(), 3): if no_decimate: break - logger.info(f'{ii}. Creating {level} tessellation...') - logger.info(f'{ii}.1 Decimating the dense tessellation ' - f'({len(surf["tris"])} -> {n_tri} triangles)...') - points, tris = decimate_surface(points=surf['rr'], - triangles=surf['tris'], - n_triangles=n_tri) + logger.info(f"{ii}. Creating {level} tessellation...") + logger.info( + f"{ii}.1 Decimating the dense tessellation " + f'({len(surf["tris"])} -> {n_tri} triangles)...' + ) + points, tris = decimate_surface( + points=surf["rr"], triangles=surf["tris"], n_triangles=n_tri + ) dec_fname = str(fname_template).format(level) - logger.info('%i.2 Creating %s' % (ii, dec_fname)) + logger.info("%i.2 Creating %s" % (ii, dec_fname)) _check_file(dec_fname, overwrite) dec_surf = _surfaces_to_bem( [dict(rr=points, tris=tris)], - [FIFF.FIFFV_BEM_SURF_ID_HEAD], [1], rescale=False, - incomplete=incomplete, extra=msg) + [FIFF.FIFFV_BEM_SURF_ID_HEAD], + [1], + rescale=False, + incomplete=incomplete, + extra=msg, + ) write_bem_surfaces(dec_fname, dec_surf, overwrite=overwrite) - logger.info('[done]') + logger.info("[done]") @verbose @@ -2318,28 +2537,23 @@ def distance_to_bem(pos, bem, trans=None, verbose=None): distance = np.zeros((n,)) logger.info( - 'Computing distance to inner skull surface for ' + - f'{n} position{_pl(n)}...' + "Computing distance to inner skull surface for " + f"{n} position{_pl(n)}..." 
) - if bem['is_sphere']: - center = bem['r0'] + if bem["is_sphere"]: + center = bem["r0"] if trans: center = apply_trans(trans, center, move=True) - radius = bem['layers'][0]['rad'] + radius = bem["layers"][0]["rad"] - distance = np.abs(radius - np.linalg.norm( - pos - center, axis=1 - )) + distance = np.abs(radius - np.linalg.norm(pos - center, axis=1)) else: # is BEM - surface_points = bem['surfs'][0]['rr'] + surface_points = bem["surfs"][0]["rr"] if trans: - surface_points = apply_trans( - trans, surface_points, move=True - ) + surface_points = apply_trans(trans, surface_points, move=True) _, distance = _compute_nearest(surface_points, pos, return_dists=True) diff --git a/mne/channels/__init__.py b/mne/channels/__init__.py index c5701c7b2b1..cfd48fc2449 100644 --- a/mne/channels/__init__.py +++ b/mne/channels/__init__.py @@ -4,41 +4,81 @@ """ from ..defaults import HEAD_SIZE_DEFAULT -from .layout import (Layout, make_eeg_layout, make_grid_layout, read_layout, - find_layout, generate_2d_layout) -from .montage import (DigMontage, - get_builtin_montages, make_dig_montage, read_dig_dat, - read_dig_egi, read_dig_captrak, read_dig_fif, - read_dig_polhemus_isotrak, read_polhemus_fastscan, - compute_dev_head_t, make_standard_montage, - read_custom_montage, read_dig_hpts, read_dig_localite, - compute_native_head_t) -from .channels import (equalize_channels, rename_channels, fix_mag_coil_types, - read_ch_adjacency, _get_ch_type, find_ch_adjacency, - make_1020_channel_selections, combine_channels, - read_vectorview_selection, _SELECTIONS, _EEG_SELECTIONS, - _divide_to_regions, get_builtin_ch_adjacencies) +from .layout import ( + Layout, + make_eeg_layout, + make_grid_layout, + read_layout, + find_layout, + generate_2d_layout, +) +from .montage import ( + DigMontage, + get_builtin_montages, + make_dig_montage, + read_dig_dat, + read_dig_egi, + read_dig_captrak, + read_dig_fif, + read_dig_polhemus_isotrak, + read_polhemus_fastscan, + compute_dev_head_t, + make_standard_montage, + read_custom_montage, + read_dig_hpts, + read_dig_localite, + compute_native_head_t, +) +from .channels import ( + equalize_channels, + rename_channels, + fix_mag_coil_types, + read_ch_adjacency, + _get_ch_type, + find_ch_adjacency, + make_1020_channel_selections, + combine_channels, + read_vectorview_selection, + _SELECTIONS, + _EEG_SELECTIONS, + _divide_to_regions, + get_builtin_ch_adjacencies, +) __all__ = [ # Data Structures - 'DigMontage', 'Layout', - + "DigMontage", + "Layout", # Factory Methods - 'make_dig_montage', 'make_eeg_layout', 'make_grid_layout', - 'make_standard_montage', - + "make_dig_montage", + "make_eeg_layout", + "make_grid_layout", + "make_standard_montage", # Readers - 'read_ch_adjacency', 'read_dig_captrak', 'read_dig_dat', - 'read_dig_egi', 'read_dig_fif', 'read_dig_localite', - 'read_dig_polhemus_isotrak', 'read_layout', - 'read_polhemus_fastscan', 'read_custom_montage', 'read_dig_hpts', - + "read_ch_adjacency", + "read_dig_captrak", + "read_dig_dat", + "read_dig_egi", + "read_dig_fif", + "read_dig_localite", + "read_dig_polhemus_isotrak", + "read_layout", + "read_polhemus_fastscan", + "read_custom_montage", + "read_dig_hpts", # Helpers - 'rename_channels', 'make_1020_channel_selections', - '_get_ch_type', 'equalize_channels', 'find_ch_adjacency', 'find_layout', - 'fix_mag_coil_types', 'generate_2d_layout', 'get_builtin_montages', - 'combine_channels', 'read_vectorview_selection', - + "rename_channels", + "make_1020_channel_selections", + "_get_ch_type", + "equalize_channels", + "find_ch_adjacency", 
+ "find_layout", + "fix_mag_coil_types", + "generate_2d_layout", + "get_builtin_montages", + "combine_channels", + "read_vectorview_selection", # Other - 'compute_dev_head_t', 'compute_native_head_t', + "compute_dev_head_t", + "compute_native_head_t", ] diff --git a/mne/channels/_dig_montage_utils.py b/mne/channels/_dig_montage_utils.py index a60418e84d4..d03cbc1fcbe 100644 --- a/mne/channels/_dig_montage_utils.py +++ b/mne/channels/_dig_montage_utils.py @@ -20,71 +20,76 @@ def _read_dig_montage_egi( - fname, - _scaling, - _all_data_kwargs_are_none, + fname, + _scaling, + _all_data_kwargs_are_none, ): - if not _all_data_kwargs_are_none: - raise ValueError('hsp, hpi, elp, point_names, fif must all be ' - 'None if egi is not None') - _check_fname(fname, overwrite='read', must_exist=True) + raise ValueError( + "hsp, hpi, elp, point_names, fif must all be " "None if egi is not None" + ) + _check_fname(fname, overwrite="read", must_exist=True) root = ElementTree.parse(fname).getroot() - ns = root.tag[root.tag.index('{'):root.tag.index('}') + 1] - sensors = root.find('%ssensorLayout/%ssensors' % (ns, ns)) + ns = root.tag[root.tag.index("{") : root.tag.index("}") + 1] + sensors = root.find("%ssensorLayout/%ssensors" % (ns, ns)) fids = dict() dig_ch_pos = dict() - fid_name_map = {'Nasion': 'nasion', - 'Right periauricular point': 'rpa', - 'Left periauricular point': 'lpa'} + fid_name_map = { + "Nasion": "nasion", + "Right periauricular point": "rpa", + "Left periauricular point": "lpa", + } for s in sensors: name, number, kind = s[0].text, int(s[1].text), int(s[2].text) - coordinates = np.array([float(s[3].text), float(s[4].text), - float(s[5].text)]) + coordinates = np.array([float(s[3].text), float(s[4].text), float(s[5].text)]) coordinates *= _scaling # EEG Channels if kind == 0: - dig_ch_pos['EEG %03d' % number] = coordinates + dig_ch_pos["EEG %03d" % number] = coordinates # Reference elif kind == 1: - dig_ch_pos['EEG %03d' % - (len(dig_ch_pos.keys()) + 1)] = coordinates + dig_ch_pos["EEG %03d" % (len(dig_ch_pos.keys()) + 1)] = coordinates # Fiducials elif kind == 2: fid_name = fid_name_map[name] fids[fid_name] = coordinates # Unknown else: - warn('Unknown sensor type %s detected. Skipping sensor...' - 'Proceed with caution!' % kind) + warn( + "Unknown sensor type %s detected. Skipping sensor..." + "Proceed with caution!" 
% kind + ) return Bunch( # EGI stuff - nasion=fids['nasion'], lpa=fids['lpa'], rpa=fids['rpa'], - ch_pos=dig_ch_pos, coord_frame='unknown', + nasion=fids["nasion"], + lpa=fids["lpa"], + rpa=fids["rpa"], + ch_pos=dig_ch_pos, + coord_frame="unknown", ) def _parse_brainvision_dig_montage(fname, scale): - FID_NAME_MAP = {'Nasion': 'nasion', 'RPA': 'rpa', 'LPA': 'lpa'} + FID_NAME_MAP = {"Nasion": "nasion", "RPA": "rpa", "LPA": "lpa"} root = ElementTree.parse(fname).getroot() - sensors = root.find('CapTrakElectrodeList') + sensors = root.find("CapTrakElectrodeList") fids, dig_ch_pos = dict(), dict() for s in sensors: - name = s.find('Name').text + name = s.find("Name").text is_fid = name in FID_NAME_MAP - coordinates = scale * np.array([float(s.find('X').text), - float(s.find('Y').text), - float(s.find('Z').text)]) + coordinates = scale * np.array( + [float(s.find("X").text), float(s.find("Y").text), float(s.find("Z").text)] + ) # Fiducials if is_fid: @@ -95,6 +100,9 @@ def _parse_brainvision_dig_montage(fname, scale): return dict( # BVCT stuff - nasion=fids['nasion'], lpa=fids['lpa'], rpa=fids['rpa'], - ch_pos=dig_ch_pos, coord_frame='unknown' + nasion=fids["nasion"], + lpa=fids["lpa"], + rpa=fids["rpa"], + ch_pos=dig_ch_pos, + coord_frame="unknown", ) diff --git a/mne/channels/_standard_montage_utils.py b/mne/channels/_standard_montage_utils.py index b83252c0dc1..c136b107924 100644 --- a/mne/channels/_standard_montage_utils.py +++ b/mne/channels/_standard_montage_utils.py @@ -17,28 +17,32 @@ from ..utils import warn, _pl from . import __file__ as _CHANNELS_INIT_FILE -MONTAGE_PATH = op.join(op.dirname(_CHANNELS_INIT_FILE), 'data', 'montages') +MONTAGE_PATH = op.join(op.dirname(_CHANNELS_INIT_FILE), "data", "montages") -_str = 'U100' +_str = "U100" # In standard_1020, T9=LPA, T10=RPA, Nasion is the same as Iz with a # sign-flipped Y value + def _egi_256(head_size): - fname = op.join(MONTAGE_PATH, 'EGI_256.csd') + fname = op.join(MONTAGE_PATH, "EGI_256.csd") montage = _read_csd(fname, head_size) ch_pos = montage._get_ch_pos() # For this cap, the Nasion is the frontmost electrode, # LPA/RPA we approximate by putting 75% of the way (toward the front) # between the two electrodes that are halfway down the ear holes - nasion = ch_pos['E31'] - lpa = 0.75 * ch_pos['E67'] + 0.25 * ch_pos['E94'] - rpa = 0.75 * ch_pos['E219'] + 0.25 * ch_pos['E190'] + nasion = ch_pos["E31"] + lpa = 0.75 * ch_pos["E67"] + 0.25 * ch_pos["E94"] + rpa = 0.75 * ch_pos["E219"] + 0.25 * ch_pos["E190"] fids_montage = make_dig_montage( - coord_frame='unknown', nasion=nasion, lpa=lpa, rpa=rpa, + coord_frame="unknown", + nasion=nasion, + lpa=lpa, + rpa=rpa, ) montage += fids_montage # add fiducials to montage @@ -63,119 +67,116 @@ def _str_names(ch_names): def _safe_np_loadtxt(fname, **kwargs): out = np.genfromtxt(fname, **kwargs) - ch_names = _str_names(out['f0']) - others = tuple(out['f%d' % ii] for ii in range(1, len(out.dtype.fields))) + ch_names = _str_names(out["f0"]) + others = tuple(out["f%d" % ii] for ii in range(1, len(out.dtype.fields))) return (ch_names,) + others def _biosemi(basename, head_size): fname = op.join(MONTAGE_PATH, basename) - fid_names = ('Nz', 'LPA', 'RPA') + fid_names = ("Nz", "LPA", "RPA") return _read_theta_phi_in_degrees(fname, head_size, fid_names) -def _mgh_or_standard(basename, head_size, coord_frame='unknown'): - fid_names = ('Nz', 'LPA', 'RPA') +def _mgh_or_standard(basename, head_size, coord_frame="unknown"): + fid_names = ("Nz", "LPA", "RPA") fname = op.join(MONTAGE_PATH, basename) ch_names_, 
pos = [], [] with open(fname) as fid: # Ignore units as we will scale later using the norms anyway for line in fid: - if 'Positions\n' in line: + if "Positions\n" in line: break pos = [] for line in fid: - if 'Labels\n' in line: + if "Labels\n" in line: break pos.append(list(map(float, line.split()))) for line in fid: - if not line or not set(line) - {' '}: + if not line or not set(line) - {" "}: break - ch_names_.append(line.strip(' ').strip('\n')) + ch_names_.append(line.strip(" ").strip("\n")) - pos = np.array(pos) / 1000. + pos = np.array(pos) / 1000.0 ch_pos = _check_dupes_odict(ch_names_, pos) nasion, lpa, rpa = [ch_pos.pop(n) for n in fid_names] if head_size is None: - scale = 1. + scale = 1.0 else: scale = head_size / np.median(np.linalg.norm(pos, axis=1)) for value in ch_pos.values(): value *= scale # if we are in MRI/MNI coordinates, we need to replace nasion, LPA, and RPA # with those of fsaverage for ``trans='fsaverage'`` to work - if coord_frame == 'mri': - lpa, nasion, rpa = [ - x['r'].copy() for x in get_mni_fiducials('fsaverage')] + if coord_frame == "mri": + lpa, nasion, rpa = [x["r"].copy() for x in get_mni_fiducials("fsaverage")] nasion *= scale lpa *= scale rpa *= scale - return make_dig_montage(ch_pos=ch_pos, coord_frame=coord_frame, - nasion=nasion, lpa=lpa, rpa=rpa) + return make_dig_montage( + ch_pos=ch_pos, coord_frame=coord_frame, nasion=nasion, lpa=lpa, rpa=rpa + ) standard_montage_look_up_table = { - 'EGI_256': _egi_256, - - 'easycap-M1': partial(_easycap, basename='easycap-M1.txt'), - 'easycap-M10': partial(_easycap, basename='easycap-M10.txt'), - - 'GSN-HydroCel-128': partial(_hydrocel, basename='GSN-HydroCel-128.sfp'), - 'GSN-HydroCel-129': partial(_hydrocel, basename='GSN-HydroCel-129.sfp'), - 'GSN-HydroCel-256': partial(_hydrocel, basename='GSN-HydroCel-256.sfp'), - 'GSN-HydroCel-257': partial(_hydrocel, basename='GSN-HydroCel-257.sfp'), - 'GSN-HydroCel-32': partial(_hydrocel, basename='GSN-HydroCel-32.sfp'), - 'GSN-HydroCel-64_1.0': partial(_hydrocel, - basename='GSN-HydroCel-64_1.0.sfp'), - 'GSN-HydroCel-65_1.0': partial(_hydrocel, - basename='GSN-HydroCel-65_1.0.sfp'), - - 'biosemi128': partial(_biosemi, basename='biosemi128.txt'), - 'biosemi16': partial(_biosemi, basename='biosemi16.txt'), - 'biosemi160': partial(_biosemi, basename='biosemi160.txt'), - 'biosemi256': partial(_biosemi, basename='biosemi256.txt'), - 'biosemi32': partial(_biosemi, basename='biosemi32.txt'), - 'biosemi64': partial(_biosemi, basename='biosemi64.txt'), - - 'mgh60': partial(_mgh_or_standard, basename='mgh60.elc', - coord_frame='mri'), - 'mgh70': partial(_mgh_or_standard, basename='mgh70.elc', - coord_frame='mri'), - 'standard_1005': partial(_mgh_or_standard, - basename='standard_1005.elc', coord_frame='mri'), - 'standard_1020': partial(_mgh_or_standard, - basename='standard_1020.elc', coord_frame='mri'), - 'standard_alphabetic': partial(_mgh_or_standard, - basename='standard_alphabetic.elc', - coord_frame='mri'), - 'standard_postfixed': partial(_mgh_or_standard, - basename='standard_postfixed.elc', - coord_frame='mri'), - 'standard_prefixed': partial(_mgh_or_standard, - basename='standard_prefixed.elc', - coord_frame='mri'), - 'standard_primed': partial(_mgh_or_standard, - basename='standard_primed.elc', - coord_frame='mri'), - 'artinis-octamon': partial(_mgh_or_standard, coord_frame='mri', - basename='artinis-octamon.elc'), - 'artinis-brite23': partial(_mgh_or_standard, coord_frame='mri', - basename='artinis-brite23.elc'), - 'brainproducts-RNP-BA-128': partial( - _easycap, 
basename='brainproducts-RNP-BA-128.txt') + "EGI_256": _egi_256, + "easycap-M1": partial(_easycap, basename="easycap-M1.txt"), + "easycap-M10": partial(_easycap, basename="easycap-M10.txt"), + "GSN-HydroCel-128": partial(_hydrocel, basename="GSN-HydroCel-128.sfp"), + "GSN-HydroCel-129": partial(_hydrocel, basename="GSN-HydroCel-129.sfp"), + "GSN-HydroCel-256": partial(_hydrocel, basename="GSN-HydroCel-256.sfp"), + "GSN-HydroCel-257": partial(_hydrocel, basename="GSN-HydroCel-257.sfp"), + "GSN-HydroCel-32": partial(_hydrocel, basename="GSN-HydroCel-32.sfp"), + "GSN-HydroCel-64_1.0": partial(_hydrocel, basename="GSN-HydroCel-64_1.0.sfp"), + "GSN-HydroCel-65_1.0": partial(_hydrocel, basename="GSN-HydroCel-65_1.0.sfp"), + "biosemi128": partial(_biosemi, basename="biosemi128.txt"), + "biosemi16": partial(_biosemi, basename="biosemi16.txt"), + "biosemi160": partial(_biosemi, basename="biosemi160.txt"), + "biosemi256": partial(_biosemi, basename="biosemi256.txt"), + "biosemi32": partial(_biosemi, basename="biosemi32.txt"), + "biosemi64": partial(_biosemi, basename="biosemi64.txt"), + "mgh60": partial(_mgh_or_standard, basename="mgh60.elc", coord_frame="mri"), + "mgh70": partial(_mgh_or_standard, basename="mgh70.elc", coord_frame="mri"), + "standard_1005": partial( + _mgh_or_standard, basename="standard_1005.elc", coord_frame="mri" + ), + "standard_1020": partial( + _mgh_or_standard, basename="standard_1020.elc", coord_frame="mri" + ), + "standard_alphabetic": partial( + _mgh_or_standard, basename="standard_alphabetic.elc", coord_frame="mri" + ), + "standard_postfixed": partial( + _mgh_or_standard, basename="standard_postfixed.elc", coord_frame="mri" + ), + "standard_prefixed": partial( + _mgh_or_standard, basename="standard_prefixed.elc", coord_frame="mri" + ), + "standard_primed": partial( + _mgh_or_standard, basename="standard_primed.elc", coord_frame="mri" + ), + "artinis-octamon": partial( + _mgh_or_standard, coord_frame="mri", basename="artinis-octamon.elc" + ), + "artinis-brite23": partial( + _mgh_or_standard, coord_frame="mri", basename="artinis-brite23.elc" + ), + "brainproducts-RNP-BA-128": partial( + _easycap, basename="brainproducts-RNP-BA-128.txt" + ), } def _read_sfp(fname, head_size): """Read .sfp BESA/EGI files.""" # fname has been already checked - fid_names = ('FidNz', 'FidT9', 'FidT10') - options = dict(dtype=(_str, 'f4', 'f4', 'f4')) + fid_names = ("FidNz", "FidT9", "FidT10") + options = dict(dtype=(_str, "f4", "f4", "f4")) ch_names, xs, ys, zs = _safe_np_loadtxt(fname, **options) # deal with "headshape" - mask = np.array([ch_name == 'headshape' for ch_name in ch_names], bool) + mask = np.array([ch_name == "headshape" for ch_name in ch_names], bool) hsp = np.stack([xs[mask], ys[mask], zs[mask]], axis=-1) mask = ~mask pos = np.stack([xs[mask], ys[mask], zs[mask]], axis=-1) @@ -193,14 +194,16 @@ def _read_sfp(fname, head_size): lpa = lpa * scale if lpa is not None else None rpa = rpa * scale if rpa is not None else None - return make_dig_montage(ch_pos=ch_pos, coord_frame='unknown', - nasion=nasion, rpa=rpa, lpa=lpa, hsp=hsp) + return make_dig_montage( + ch_pos=ch_pos, coord_frame="unknown", nasion=nasion, rpa=rpa, lpa=lpa, hsp=hsp + ) def _read_csd(fname, head_size): # Label, Theta, Phi, Radius, X, Y, Z, off sphere surface - options = dict(comments='//', - dtype=(_str, 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4')) + options = dict( + comments="//", dtype=(_str, "f4", "f4", "f4", "f4", "f4", "f4", "f4") + ) ch_names, _, _, _, xs, ys, zs, _ = _safe_np_loadtxt(fname, **options) pos = 
np.stack([xs, ys, zs], axis=-1) @@ -213,16 +216,15 @@ def _read_csd(fname, head_size): def _check_dupes_odict(ch_names, pos): """Warn if there are duplicates, then turn to ordered dict.""" ch_names = list(ch_names) - dups = OrderedDict((ch_name, ch_names.count(ch_name)) - for ch_name in ch_names) - dups = OrderedDict((ch_name, count) for ch_name, count in dups.items() - if count > 1) + dups = OrderedDict((ch_name, ch_names.count(ch_name)) for ch_name in ch_names) + dups = OrderedDict((ch_name, count) for ch_name, count in dups.items() if count > 1) n = len(dups) if n: - dups = ', '.join( - f'{ch_name} ({count})' for ch_name, count in dups.items()) - warn(f'Duplicate channel position{_pl(n)} found, the last will be ' - f'used for {dups}') + dups = ", ".join(f"{ch_name} ({count})" for ch_name, count in dups.items()) + warn( + f"Duplicate channel position{_pl(n)} found, the last will be " + f"used for {dups}" + ) return OrderedDict(zip(ch_names, pos)) @@ -242,30 +244,30 @@ def _read_elc(fname, head_size): montage : instance of DigMontage The montage in [m]. """ - fid_names = ('Nz', 'LPA', 'RPA') + fid_names = ("Nz", "LPA", "RPA") ch_names_, pos = [], [] with open(fname) as fid: # _read_elc does require to detect the units. (see _mgh_or_standard) for line in fid: - if 'UnitPosition' in line: + if "UnitPosition" in line: units = line.split()[1] - scale = dict(m=1., mm=1e-3)[units] + scale = dict(m=1.0, mm=1e-3)[units] break else: - raise RuntimeError('Could not detect units in file %s' % fname) + raise RuntimeError("Could not detect units in file %s" % fname) for line in fid: - if 'Positions\n' in line: + if "Positions\n" in line: break pos = [] for line in fid: - if 'Labels\n' in line: + if "Labels\n" in line: break pos.append(list(map(float, line.split()))) for line in fid: - if not line or not set(line) - {' '}: + if not line or not set(line) - {" "}: break - ch_names_.append(line.strip(' ').strip('\n')) + ch_names_.append(line.strip(" ").strip("\n")) pos = np.array(pos) * scale if head_size is not None: @@ -274,14 +276,15 @@ def _read_elc(fname, head_size): ch_pos = _check_dupes_odict(ch_names_, pos) nasion, lpa, rpa = [ch_pos.pop(n, None) for n in fid_names] - return make_dig_montage(ch_pos=ch_pos, coord_frame='unknown', - nasion=nasion, lpa=lpa, rpa=rpa) + return make_dig_montage( + ch_pos=ch_pos, coord_frame="unknown", nasion=nasion, lpa=lpa, rpa=rpa + ) -def _read_theta_phi_in_degrees(fname, head_size, fid_names=None, - add_fiducials=False): - ch_names, theta, phi = _safe_np_loadtxt(fname, skip_header=1, - dtype=(_str, 'i4', 'i4')) +def _read_theta_phi_in_degrees(fname, head_size, fid_names=None, add_fiducials=False): + ch_names, theta, phi = _safe_np_loadtxt( + fname, skip_header=1, dtype=(_str, "i4", "i4") + ) if add_fiducials: # Add fiducials based on 10/20 spherical coordinate definitions # http://chgd.umich.edu/wp-content/uploads/2014/06/ @@ -290,7 +293,7 @@ def _read_theta_phi_in_degrees(fname, head_size, fid_names=None, # https://www.easycap.de/wp-content/uploads/2018/02/ # Easycap-Equidistant-Layouts.pdf assert fid_names is None - fid_names = ['Nasion', 'LPA', 'RPA'] + fid_names = ["Nasion", "LPA", "RPA"] ch_names.extend(fid_names) theta = np.append(theta, [115, -115, 115]) phi = np.append(phi, [90, 0, 0]) @@ -303,23 +306,23 @@ def _read_theta_phi_in_degrees(fname, head_size, fid_names=None, if fid_names is not None: nasion, lpa, rpa = [ch_pos.pop(n, None) for n in fid_names] - return make_dig_montage(ch_pos=ch_pos, coord_frame='unknown', - nasion=nasion, lpa=lpa, rpa=rpa) + 
return make_dig_montage( + ch_pos=ch_pos, coord_frame="unknown", nasion=nasion, lpa=lpa, rpa=rpa + ) def _read_elp_besa(fname, head_size): # This .elp is not the same as polhemus elp. see _read_isotrak_elp_points - dtype = np.dtype('S8, S8, f8, f8, f8') + dtype = np.dtype("S8, S8, f8, f8, f8") data = np.loadtxt(fname, dtype=dtype) - ch_names = data['f1'].astype(str).tolist() - az = data['f2'] - horiz = data['f3'] - radius = np.abs(az / 180.) - az = np.deg2rad(np.array([h if a >= 0. else 180 + h - for h, a in zip(horiz, az)])) + ch_names = data["f1"].astype(str).tolist() + az = data["f2"] + horiz = data["f3"] + radius = np.abs(az / 180.0) + az = np.deg2rad(np.array([h if a >= 0.0 else 180 + h for h, a in zip(horiz, az)])) pol = radius * np.pi - rad = data['f4'] / 100 + rad = data["f4"] / 100 pos = _sph_to_cart(np.array([rad, az, pol]).T) if head_size is not None: @@ -327,7 +330,7 @@ def _read_elp_besa(fname, head_size): ch_pos = _check_dupes_odict(ch_names, pos) - fid_names = ('Nz', 'LPA', 'RPA') + fid_names = ("Nz", "LPA", "RPA") # No one grants that the fid names actually exist. nasion, lpa, rpa = [ch_pos.pop(n, None) for n in fid_names] diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 211b0275441..f599313304c 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -22,18 +22,41 @@ import numpy as np from ..defaults import HEAD_SIZE_DEFAULT, _handle_default -from ..utils import (verbose, logger, warn, - _check_preload, _validate_type, fill_doc, _check_option, - _get_stim_channel, _check_fname, _check_dict_keys, - _on_missing, legacy) +from ..utils import ( + verbose, + logger, + warn, + _check_preload, + _validate_type, + fill_doc, + _check_option, + _get_stim_channel, + _check_fname, + _check_dict_keys, + _on_missing, + legacy, +) from ..io.constants import FIFF -from ..io.meas_info import (anonymize_info, Info, MontageMixin, create_info, - _rename_comps) -from ..io.pick import (channel_type, pick_info, pick_types, _picks_by_type, - _check_excludes_includes, _contains_ch_type, - channel_indices_by_type, pick_channels, _picks_to_idx, - get_channel_type_constants, - _pick_data_channels) +from ..io.meas_info import ( + anonymize_info, + Info, + MontageMixin, + create_info, + _rename_comps, +) +from ..io.pick import ( + channel_type, + pick_info, + pick_types, + _picks_by_type, + _check_excludes_includes, + _contains_ch_type, + channel_indices_by_type, + pick_channels, + _picks_to_idx, + get_channel_type_constants, + _pick_data_channels, +) from ..io.tag import _rename_list from ..io.write import DATE_NONE from ..io.proj import setup_proj @@ -42,39 +65,40 @@ def _get_meg_system(info): """Educated guess for the helmet type based on channels.""" have_helmet = True - for ch in info['chs']: - if ch['kind'] == FIFF.FIFFV_MEG_CH: + for ch in info["chs"]: + if ch["kind"] == FIFF.FIFFV_MEG_CH: # Only take first 16 bits, as higher bits store CTF grad comp order - coil_type = ch['coil_type'] & 0xFFFF - nmag = np.sum( - [c['kind'] == FIFF.FIFFV_MEG_CH for c in info['chs']]) + coil_type = ch["coil_type"] & 0xFFFF + nmag = np.sum([c["kind"] == FIFF.FIFFV_MEG_CH for c in info["chs"]]) if coil_type == FIFF.FIFFV_COIL_NM_122: - system = '122m' + system = "122m" break elif coil_type // 1000 == 3: # All Vectorview coils are 30xx - system = '306m' + system = "306m" break - elif (coil_type == FIFF.FIFFV_COIL_MAGNES_MAG or - coil_type == FIFF.FIFFV_COIL_MAGNES_GRAD): - system = 'Magnes_3600wh' if nmag > 150 else 'Magnes_2500wh' + elif ( + coil_type == 
FIFF.FIFFV_COIL_MAGNES_MAG + or coil_type == FIFF.FIFFV_COIL_MAGNES_GRAD + ): + system = "Magnes_3600wh" if nmag > 150 else "Magnes_2500wh" break elif coil_type == FIFF.FIFFV_COIL_CTF_GRAD: - system = 'CTF_275' + system = "CTF_275" break elif coil_type == FIFF.FIFFV_COIL_KIT_GRAD: - system = 'KIT' + system = "KIT" # Our helmet does not match very well, so let's just create it have_helmet = False break elif coil_type == FIFF.FIFFV_COIL_BABY_GRAD: - system = 'BabySQUID' + system = "BabySQUID" break elif coil_type == FIFF.FIFFV_COIL_ARTEMIS123_GRAD: - system = 'ARTEMIS123' + system = "ARTEMIS123" have_helmet = False break else: - system = 'unknown' + system = "unknown" have_helmet = False return system, have_helmet @@ -86,11 +110,24 @@ def _get_ch_type(inst, ch_type, allow_ref_meg=False): then grads, then ... to plot. """ if ch_type is None: - allowed_types = ['mag', 'grad', 'planar1', 'planar2', 'eeg', 'csd', - 'fnirs_cw_amplitude', 'fnirs_fd_ac_amplitude', - 'fnirs_fd_phase', 'fnirs_od', 'hbo', 'hbr', - 'ecog', 'seeg', 'dbs'] - allowed_types += ['ref_meg'] if allow_ref_meg else [] + allowed_types = [ + "mag", + "grad", + "planar1", + "planar2", + "eeg", + "csd", + "fnirs_cw_amplitude", + "fnirs_fd_ac_amplitude", + "fnirs_fd_phase", + "fnirs_od", + "hbo", + "hbr", + "ecog", + "seeg", + "dbs", + ] + allowed_types += ["ref_meg"] if allow_ref_meg else [] for type_ in allowed_types: if isinstance(inst, Info): if _contains_ch_type(inst, type_): @@ -100,7 +137,7 @@ def _get_ch_type(inst, ch_type, allow_ref_meg=False): ch_type = type_ break else: - raise RuntimeError('No plottable channel types found') + raise RuntimeError("No plottable channel types found") return ch_type @@ -147,16 +184,26 @@ def equalize_channels(instances, copy=True, verbose=None): # Instances need to have a `ch_names` attribute and a `pick_channels` # method that supports `ordered=True`. 
- allowed_types = (BaseRaw, BaseEpochs, Evoked, _BaseTFR, Forward, - Covariance, CrossSpectralDensity, Info) - allowed_types_str = ("Raw, Epochs, Evoked, TFR, Forward, Covariance, " - "CrossSpectralDensity or Info") + allowed_types = ( + BaseRaw, + BaseEpochs, + Evoked, + _BaseTFR, + Forward, + Covariance, + CrossSpectralDensity, + Info, + ) + allowed_types_str = ( + "Raw, Epochs, Evoked, TFR, Forward, Covariance, " "CrossSpectralDensity or Info" + ) for inst in instances: - _validate_type(inst, allowed_types, "Instances to be modified", - allowed_types_str) + _validate_type( + inst, allowed_types, "Instances to be modified", allowed_types_str + ) chan_template = instances[0].ch_names - logger.info('Identifying common channels ...') + logger.info("Identifying common channels ...") channels = [set(inst.ch_names) for inst in instances] common_channels = set(chan_template).intersection(*channels) all_channels = set(chan_template).union(*channels) @@ -173,8 +220,9 @@ def equalize_channels(instances, copy=True, verbose=None): # Only perform picking when needed if inst.ch_names != common_channels: if isinstance(inst, Info): - sel = pick_channels(inst.ch_names, common_channels, exclude=[], - ordered=True) + sel = pick_channels( + inst.ch_names, common_channels, exclude=[], ordered=True + ) inst = pick_info(inst, sel, copy=copy, verbose=False) else: if copy: @@ -185,47 +233,59 @@ def equalize_channels(instances, copy=True, verbose=None): equalized_instances.append(inst) if dropped: - logger.info('Dropped the following channels:\n%s' % dropped) + logger.info("Dropped the following channels:\n%s" % dropped) elif reordered: - logger.info('Channels have been re-ordered.') + logger.info("Channels have been re-ordered.") return equalized_instances channel_type_constants = get_channel_type_constants(include_defaults=True) -_human2fiff = {k: v.get('kind', FIFF.FIFFV_COIL_NONE) for k, v in - channel_type_constants.items()} -_human2unit = {k: v.get('unit', FIFF.FIFF_UNIT_NONE) for k, v in - channel_type_constants.items()} -_unit2human = {FIFF.FIFF_UNIT_V: 'V', - FIFF.FIFF_UNIT_T: 'T', - FIFF.FIFF_UNIT_T_M: 'T/m', - FIFF.FIFF_UNIT_MOL: 'M', - FIFF.FIFF_UNIT_NONE: 'NA', - FIFF.FIFF_UNIT_CEL: 'C', - FIFF.FIFF_UNIT_S: 'S', - FIFF.FIFF_UNIT_PX: 'px'} +_human2fiff = { + k: v.get("kind", FIFF.FIFFV_COIL_NONE) for k, v in channel_type_constants.items() +} +_human2unit = { + k: v.get("unit", FIFF.FIFF_UNIT_NONE) for k, v in channel_type_constants.items() +} +_unit2human = { + FIFF.FIFF_UNIT_V: "V", + FIFF.FIFF_UNIT_T: "T", + FIFF.FIFF_UNIT_T_M: "T/m", + FIFF.FIFF_UNIT_MOL: "M", + FIFF.FIFF_UNIT_NONE: "NA", + FIFF.FIFF_UNIT_CEL: "C", + FIFF.FIFF_UNIT_S: "S", + FIFF.FIFF_UNIT_PX: "px", +} def _check_set(ch, projs, ch_type): """Ensure type change is compatible with projectors.""" new_kind = _human2fiff[ch_type] - if ch['kind'] != new_kind: + if ch["kind"] != new_kind: for proj in projs: - if ch['ch_name'] in proj['data']['col_names']: - raise RuntimeError('Cannot change channel type for channel %s ' - 'in projector "%s"' - % (ch['ch_name'], proj['desc'])) - ch['kind'] = new_kind + if ch["ch_name"] in proj["data"]["col_names"]: + raise RuntimeError( + "Cannot change channel type for channel %s " + 'in projector "%s"' % (ch["ch_name"], proj["desc"]) + ) + ch["kind"] = new_kind class SetChannelsMixin(MontageMixin): """Mixin class for Raw, Evoked, Epochs.""" @verbose - def set_eeg_reference(self, ref_channels='average', projection=False, - ch_type='auto', forward=None, *, joint=False, - verbose=None): + def 
set_eeg_reference( + self, + ref_channels="average", + projection=False, + ch_type="auto", + forward=None, + *, + joint=False, + verbose=None, + ): """Specify which reference to use for EEG data. Use this function to explicitly specify the desired reference for EEG. @@ -251,9 +311,16 @@ def set_eeg_reference(self, ref_channels='average', projection=False, %(set_eeg_reference_see_also_notes)s """ from ..io.reference import set_eeg_reference - return set_eeg_reference(self, ref_channels=ref_channels, copy=False, - projection=projection, ch_type=ch_type, - forward=forward, joint=joint)[0] + + return set_eeg_reference( + self, + ref_channels=ref_channels, + copy=False, + projection=projection, + ch_type=ch_type, + forward=forward, + joint=joint, + )[0] def _get_channel_positions(self, picks=None): """Get channel locations from info. @@ -268,12 +335,13 @@ def _get_channel_positions(self, picks=None): .. versionadded:: 0.9.0 """ picks = _picks_to_idx(self.info, picks) - chs = self.info['chs'] - pos = np.array([chs[k]['loc'][:3] for k in picks]) + chs = self.info["chs"] + pos = np.array([chs[k]["loc"][:3] for k in picks]) n_zero = np.sum(np.sum(np.abs(pos), axis=1) == 0) if n_zero > 1: # XXX some systems have origin (0, 0, 0) - raise ValueError('Could not extract channel positions for ' - '{} channels'.format(n_zero)) + raise ValueError( + "Could not extract channel positions for " "{} channels".format(n_zero) + ) return pos def _set_channel_positions(self, pos, names): @@ -291,24 +359,25 @@ def _set_channel_positions(self, pos, names): .. versionadded:: 0.9.0 """ if len(pos) != len(names): - raise ValueError('Number of channel positions not equal to ' - 'the number of names given.') + raise ValueError( + "Number of channel positions not equal to " "the number of names given." + ) pos = np.asarray(pos, dtype=np.float64) if pos.shape[-1] != 3 or pos.ndim != 2: - msg = ('Channel positions must have the shape (n_points, 3) ' - 'not %s.' % (pos.shape,)) + msg = "Channel positions must have the shape (n_points, 3) " "not %s." % ( + pos.shape, + ) raise ValueError(msg) for name, p in zip(names, pos): if name in self.ch_names: idx = self.ch_names.index(name) - self.info['chs'][idx]['loc'][:3] = p + self.info["chs"][idx]["loc"][:3] = p else: - msg = ('%s was not found in the info. Cannot be updated.' - % name) + msg = "%s was not found in the info. Cannot be updated." % name raise ValueError(msg) @verbose - def set_channel_types(self, mapping, *, on_unit_change='warn', verbose=None): + def set_channel_types(self, mapping, *, on_unit_change="warn", verbose=None): """Specify the sensor types of channels. Parameters @@ -342,64 +411,66 @@ def set_channel_types(self, mapping, *, on_unit_change='warn', verbose=None): .. versionadded:: 0.9.0 """ - ch_names = self.info['ch_names'] + ch_names = self.info["ch_names"] # first check and assemble clean mappings of index and name unit_changes = dict() for ch_name, ch_type in mapping.items(): if ch_name not in ch_names: - raise ValueError("This channel name (%s) doesn't exist in " - "info." % ch_name) + raise ValueError( + "This channel name (%s) doesn't exist in " "info." % ch_name + ) c_ind = ch_names.index(ch_name) if ch_type not in _human2fiff: - raise ValueError('This function cannot change to this ' - 'channel type: %s. Accepted channel types ' - 'are %s.' - % (ch_type, - ", ".join(sorted(_human2unit.keys())))) + raise ValueError( + "This function cannot change to this " + "channel type: %s. Accepted channel types " + "are %s." 
% (ch_type, ", ".join(sorted(_human2unit.keys()))) + ) # Set sensor type - _check_set(self.info['chs'][c_ind], self.info['projs'], ch_type) - unit_old = self.info['chs'][c_ind]['unit'] + _check_set(self.info["chs"][c_ind], self.info["projs"], ch_type) + unit_old = self.info["chs"][c_ind]["unit"] unit_new = _human2unit[ch_type] if unit_old not in _unit2human: - raise ValueError("Channel '%s' has unknown unit (%s). Please " - "fix the measurement info of your data." - % (ch_name, unit_old)) + raise ValueError( + "Channel '%s' has unknown unit (%s). Please " + "fix the measurement info of your data." % (ch_name, unit_old) + ) if unit_old != _human2unit[ch_type]: this_change = (_unit2human[unit_old], _unit2human[unit_new]) if this_change not in unit_changes: unit_changes[this_change] = list() unit_changes[this_change].append(ch_name) - self.info['chs'][c_ind]['unit'] = _human2unit[ch_type] - if ch_type in ['eeg', 'seeg', 'ecog', 'dbs']: + self.info["chs"][c_ind]["unit"] = _human2unit[ch_type] + if ch_type in ["eeg", "seeg", "ecog", "dbs"]: coil_type = FIFF.FIFFV_COIL_EEG - elif ch_type == 'hbo': + elif ch_type == "hbo": coil_type = FIFF.FIFFV_COIL_FNIRS_HBO - elif ch_type == 'hbr': + elif ch_type == "hbr": coil_type = FIFF.FIFFV_COIL_FNIRS_HBR - elif ch_type == 'fnirs_cw_amplitude': + elif ch_type == "fnirs_cw_amplitude": coil_type = FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE - elif ch_type == 'fnirs_fd_ac_amplitude': + elif ch_type == "fnirs_fd_ac_amplitude": coil_type = FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE - elif ch_type == 'fnirs_fd_phase': + elif ch_type == "fnirs_fd_phase": coil_type = FIFF.FIFFV_COIL_FNIRS_FD_PHASE - elif ch_type == 'fnirs_od': + elif ch_type == "fnirs_od": coil_type = FIFF.FIFFV_COIL_FNIRS_OD - elif ch_type == 'eyetrack_pos': + elif ch_type == "eyetrack_pos": coil_type = FIFF.FIFFV_COIL_EYETRACK_POS - elif ch_type == 'eyetrack_pupil': + elif ch_type == "eyetrack_pupil": coil_type = FIFF.FIFFV_COIL_EYETRACK_PUPIL else: coil_type = FIFF.FIFFV_COIL_NONE - self.info['chs'][c_ind]['coil_type'] = coil_type + self.info["chs"][c_ind]["coil_type"] = coil_type msg = "The unit for channel(s) {0} has changed from {1} to {2}." 
for this_change, names in unit_changes.items(): _on_missing( on_missing=on_unit_change, msg=msg.format(", ".join(sorted(names)), *this_change), - name='on_unit_change', + name="on_unit_change", ) return self @@ -427,13 +498,13 @@ def rename_channels(self, mapping, allow_duplicates=False, verbose=None): """ from ..io import BaseRaw - ch_names_orig = list(self.info['ch_names']) + ch_names_orig = list(self.info["ch_names"]) rename_channels(self.info, mapping, allow_duplicates) # Update self._orig_units for Raw if isinstance(self, BaseRaw): # whatever mapping was provided, now we can just use a dict - mapping = dict(zip(ch_names_orig, self.info['ch_names'])) + mapping = dict(zip(ch_names_orig, self.info["ch_names"])) for old_name, new_name in mapping.items(): if old_name in self._orig_units: self._orig_units[new_name] = self._orig_units.pop(old_name) @@ -444,10 +515,20 @@ def rename_channels(self, mapping, allow_duplicates=False, verbose=None): return self @verbose - def plot_sensors(self, kind='topomap', ch_type=None, title=None, - show_names=False, ch_groups=None, to_sphere=True, - axes=None, block=False, show=True, sphere=None, - verbose=None): + def plot_sensors( + self, + kind="topomap", + ch_type=None, + title=None, + show_names=False, + ch_groups=None, + to_sphere=True, + axes=None, + block=False, + show=True, + sphere=None, + verbose=None, + ): """Plot sensor positions. Parameters @@ -518,10 +599,21 @@ def plot_sensors(self, kind='topomap', ch_type=None, title=None, .. versionadded:: 0.12.0 """ from ..viz.utils import plot_sensors - return plot_sensors(self.info, kind=kind, ch_type=ch_type, title=title, - show_names=show_names, ch_groups=ch_groups, - to_sphere=to_sphere, axes=axes, block=block, - show=show, sphere=sphere, verbose=verbose) + + return plot_sensors( + self.info, + kind=kind, + ch_type=ch_type, + title=title, + show_names=show_names, + ch_groups=ch_groups, + to_sphere=to_sphere, + axes=axes, + block=block, + show=show, + sphere=sphere, + verbose=verbose, + ) @verbose def anonymize(self, daysback=None, keep_his=False, verbose=None): @@ -544,9 +636,8 @@ def anonymize(self, daysback=None, keep_his=False, verbose=None): .. versionadded:: 0.13.0 """ - anonymize_info(self.info, daysback=daysback, keep_his=keep_his, - verbose=verbose) - self.set_meas_date(self.info['meas_date']) # unify annot update + anonymize_info(self.info, daysback=daysback, keep_his=keep_his, verbose=verbose) + self.set_meas_date(self.info["meas_date"]) # unify annot update return self def set_meas_date(self, meas_date): @@ -580,25 +671,26 @@ def set_meas_date(self, meas_date): .. 
versionadded:: 0.20 """ from ..annotations import _handle_meas_date + meas_date = _handle_meas_date(meas_date) with self.info._unlock(): - self.info['meas_date'] = meas_date + self.info["meas_date"] = meas_date # clear file_id and meas_id if needed if meas_date is None: - for key in ('file_id', 'meas_id'): + for key in ("file_id", "meas_id"): value = self.info.get(key) if value is not None: - assert 'msecs' not in value - value['secs'] = DATE_NONE[0] - value['usecs'] = DATE_NONE[1] + assert "msecs" not in value + value["secs"] = DATE_NONE[0] + value["usecs"] = DATE_NONE[1] # The following copy is needed for a test CTF dataset # otherwise value['machid'][:] = 0 would suffice - _tmp = value['machid'].copy() + _tmp = value["machid"].copy() _tmp[:] = 0 - value['machid'] = _tmp + value["machid"] = _tmp - if hasattr(self, 'annotations'): + if hasattr(self, "annotations"): self.annotations._orig_time = meas_date return self @@ -607,14 +699,39 @@ class UpdateChannelsMixin: """Mixin class for Raw, Evoked, Epochs, Spectrum, AverageTFR.""" @verbose - @legacy(alt='inst.pick(...)') - def pick_types(self, meg=False, eeg=False, stim=False, eog=False, - ecg=False, emg=False, ref_meg='auto', *, misc=False, - resp=False, chpi=False, exci=False, ias=False, syst=False, - seeg=False, dipole=False, gof=False, bio=False, - ecog=False, fnirs=False, csd=False, dbs=False, - temperature=False, gsr=False, eyetrack=False, - include=(), exclude='bads', selection=None, verbose=None): + @legacy(alt="inst.pick(...)") + def pick_types( + self, + meg=False, + eeg=False, + stim=False, + eog=False, + ecg=False, + emg=False, + ref_meg="auto", + *, + misc=False, + resp=False, + chpi=False, + exci=False, + ias=False, + syst=False, + seeg=False, + dipole=False, + gof=False, + bio=False, + ecog=False, + fnirs=False, + csd=False, + dbs=False, + temperature=False, + gsr=False, + eyetrack=False, + include=(), + exclude="bads", + selection=None, + verbose=None, + ): """Pick some channels by type and names. Parameters @@ -636,24 +753,47 @@ def pick_types(self, meg=False, eeg=False, stim=False, eog=False, .. 
versionadded:: 0.9.0 """ idx = pick_types( - self.info, meg=meg, eeg=eeg, stim=stim, eog=eog, ecg=ecg, emg=emg, - ref_meg=ref_meg, misc=misc, resp=resp, chpi=chpi, exci=exci, - ias=ias, syst=syst, seeg=seeg, dipole=dipole, gof=gof, bio=bio, - ecog=ecog, fnirs=fnirs, csd=csd, dbs=dbs, temperature=temperature, - gsr=gsr, eyetrack=eyetrack, include=include, exclude=exclude, - selection=selection) + self.info, + meg=meg, + eeg=eeg, + stim=stim, + eog=eog, + ecg=ecg, + emg=emg, + ref_meg=ref_meg, + misc=misc, + resp=resp, + chpi=chpi, + exci=exci, + ias=ias, + syst=syst, + seeg=seeg, + dipole=dipole, + gof=gof, + bio=bio, + ecog=ecog, + fnirs=fnirs, + csd=csd, + dbs=dbs, + temperature=temperature, + gsr=gsr, + eyetrack=eyetrack, + include=include, + exclude=exclude, + selection=selection, + ) self._pick_drop_channels(idx) # remove dropped channel types from reject and flat - if getattr(self, 'reject', None) is not None: + if getattr(self, "reject", None) is not None: # use list(self.reject) to avoid RuntimeError for changing # dictionary size during iteration for ch_type in list(self.reject): if ch_type not in self: del self.reject[ch_type] - if getattr(self, 'flat', None) is not None: + if getattr(self, "flat", None) is not None: for ch_type in list(self.flat): if ch_type not in self: del self.flat[ch_type] @@ -661,7 +801,7 @@ def pick_types(self, meg=False, eeg=False, stim=False, eog=False, return self @verbose - @legacy(alt='inst.pick(...)') + @legacy(alt="inst.pick(...)") def pick_channels(self, ch_names, ordered=None, *, verbose=None): """Pick some channels. @@ -693,7 +833,7 @@ def pick_channels(self, ch_names, ordered=None, *, verbose=None): .. versionadded:: 0.9.0 """ - picks = pick_channels(self.info['ch_names'], ch_names, ordered=ordered) + picks = pick_channels(self.info["ch_names"], ch_names, ordered=ordered) return self._pick_drop_channels(picks) @verbose @@ -715,8 +855,7 @@ def pick(self, picks, exclude=(), *, verbose=None): inst : instance of Raw, Epochs, or Evoked The modified instance. """ - picks = _picks_to_idx(self.info, picks, 'all', exclude, - allow_empty=False) + picks = _picks_to_idx(self.info, picks, "all", exclude, allow_empty=False) return self._pick_drop_channels(picks) def reorder_channels(self, ch_names): @@ -750,12 +889,12 @@ def reorder_channels(self, ch_names): for ch_name in ch_names: ii = self.ch_names.index(ch_name) if ii in idx: - raise ValueError('Channel name repeated: %s' % (ch_name,)) + raise ValueError("Channel name repeated: %s" % (ch_name,)) idx.append(ii) return self._pick_drop_channels(idx) @fill_doc - def drop_channels(self, ch_names, on_missing='raise'): + def drop_channels(self, ch_names, on_missing="raise"): """Drop channel(s). Parameters @@ -785,20 +924,23 @@ def drop_channels(self, ch_names, on_missing='raise'): try: all_str = all([isinstance(ch, str) for ch in ch_names]) except TypeError: - raise ValueError("'ch_names' must be iterable, got " - "type {} ({}).".format(type(ch_names), ch_names)) + raise ValueError( + "'ch_names' must be iterable, got " + "type {} ({}).".format(type(ch_names), ch_names) + ) if not all_str: - raise ValueError("Each element in 'ch_names' must be str, got " - "{}.".format([type(ch) for ch in ch_names])) + raise ValueError( + "Each element in 'ch_names' must be str, got " + "{}.".format([type(ch) for ch in ch_names]) + ) missing = [ch for ch in ch_names if ch not in self.ch_names] if len(missing) > 0: msg = "Channel(s) {0} not found, nothing dropped." 
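
# A short usage sketch of the on_missing path above (synthetic data; "Oz" is
# deliberately absent so the "not found, nothing dropped" message is emitted
# as a warning rather than raised as an error):
import numpy as np
import mne

info = mne.create_info(["Fz", "Cz", "Pz"], sfreq=128.0, ch_types="eeg")
raw = mne.io.RawArray(np.zeros((3, 128)), info)
raw.drop_channels(["Pz", "Oz"], on_missing="warn")  # warns about "Oz", drops "Pz"
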
_on_missing(on_missing, msg.format(", ".join(missing))) - bad_idx = [self.ch_names.index(ch) for ch in ch_names - if ch in self.ch_names] + bad_idx = [self.ch_names.index(ch) for ch in ch_names if ch in self.ch_names] idx = np.setdiff1d(np.arange(len(self.ch_names)), bad_idx) return self._pick_drop_channels(idx) @@ -809,45 +951,45 @@ def _pick_drop_channels(self, idx, *, verbose=None): from ..time_frequency import AverageTFR, EpochsTFR from ..time_frequency.spectrum import BaseSpectrum - msg = 'adding, dropping, or reordering channels' + msg = "adding, dropping, or reordering channels" if isinstance(self, BaseRaw): if self._projector is not None: - _check_preload(self, f'{msg} after calling .apply_proj()') + _check_preload(self, f"{msg} after calling .apply_proj()") else: _check_preload(self, msg) - if getattr(self, 'picks', None) is not None: + if getattr(self, "picks", None) is not None: self.picks = self.picks[idx] - if getattr(self, '_read_picks', None) is not None: + if getattr(self, "_read_picks", None) is not None: self._read_picks = [r[idx] for r in self._read_picks] - if hasattr(self, '_cals'): + if hasattr(self, "_cals"): self._cals = self._cals[idx] pick_info(self.info, idx, copy=False) - for key in ('_comp', '_projector'): + for key in ("_comp", "_projector"): mat = getattr(self, key, None) if mat is not None: setattr(self, key, mat[idx][:, idx]) if isinstance(self, BaseSpectrum): - axis = self._dims.index('channel') + axis = self._dims.index("channel") elif isinstance(self, (AverageTFR, EpochsTFR)): axis = -3 else: # All others (Evoked, Epochs, Raw) have chs axis=-2 axis = -2 - if hasattr(self, '_data'): # skip non-preloaded Raw + if hasattr(self, "_data"): # skip non-preloaded Raw self._data = self._data.take(idx, axis=axis) else: assert isinstance(self, BaseRaw) and not self.preload if isinstance(self, BaseRaw): - self.annotations._prune_ch_names(self.info, on_missing='ignore') + self.annotations._prune_ch_names(self.info, on_missing="ignore") self._orig_units = { - k: v for k, v in self._orig_units.items() - if k in self.ch_names} + k: v for k, v in self._orig_units.items() if k in self.ch_names + } self._pick_projs() return self @@ -855,14 +997,14 @@ def _pick_drop_channels(self, idx, *, verbose=None): def _pick_projs(self): """Keep only projectors which apply to at least 1 data channel.""" drop_idx = [] - for idx, proj in enumerate(self.info['projs']): - if not set(self.info['ch_names']) & set(proj['data']['col_names']): + for idx, proj in enumerate(self.info["projs"]): + if not set(self.info["ch_names"]) & set(proj["data"]["col_names"]): drop_idx.append(idx) for idx in drop_idx: logger.info(f"Removing projector {self.info['projs'][idx]}") - if drop_idx and hasattr(self, 'del_proj'): + if drop_idx and hasattr(self, "del_proj"): self.del_proj(drop_idx) return self @@ -900,7 +1042,7 @@ def add_channels(self, add_list, force_update_info=False): from ..io import BaseRaw, _merge_info from ..epochs import BaseEpochs - _validate_type(add_list, (list, tuple), 'Input') + _validate_type(add_list, (list, tuple), "Input") # Object-specific checks for inst in add_list + [self]: @@ -915,7 +1057,7 @@ def add_channels(self, add_list, force_update_info=False): con_axis = 0 comp_class = type(self) for inst in add_list: - _validate_type(inst, comp_class, 'All input') + _validate_type(inst, comp_class, "All input") data = [inst._data for inst in [self] + add_list] # Make sure that all dimensions other than channel axis are the same @@ -924,8 +1066,9 @@ def add_channels(self, add_list, 
force_update_info=False): for shape in shapes: if not ((shapes[0] - shape) == 0).all(): raise ValueError( - 'All data dimensions except channels must match, got ' - f'{shapes[0]} != {shape}') + "All data dimensions except channels must match, got " + f"{shapes[0]} != {shape}" + ) del shapes # Create final data / info objects @@ -933,43 +1076,50 @@ def add_channels(self, add_list, force_update_info=False): new_info = _merge_info(infos, force_update_to_first=force_update_info) # Now update the attributes - if isinstance(self._data, np.memmap) and con_axis == 0 and \ - sys.platform != 'darwin': # resizing not available--no mremap + if ( + isinstance(self._data, np.memmap) + and con_axis == 0 + and sys.platform != "darwin" + ): # resizing not available--no mremap # Use a resize and fill in other ones out_shape = (sum(d.shape[0] for d in data),) + data[0].shape[1:] n_bytes = np.prod(out_shape) * self._data.dtype.itemsize self._data.flush() self._data.base.resize(n_bytes) - self._data = np.memmap(self._data.filename, mode='r+', - dtype=self._data.dtype, shape=out_shape) + self._data = np.memmap( + self._data.filename, mode="r+", dtype=self._data.dtype, shape=out_shape + ) assert self._data.shape == out_shape assert self._data.nbytes == n_bytes offset = len(data[0]) for d in data[1:]: this_len = len(d) - self._data[offset:offset + this_len] = d + self._data[offset : offset + this_len] = d offset += this_len else: self._data = np.concatenate(data, axis=con_axis) self.info = new_info if isinstance(self, BaseRaw): - self._cals = np.concatenate([getattr(inst, '_cals') - for inst in [self] + add_list]) + self._cals = np.concatenate( + [getattr(inst, "_cals") for inst in [self] + add_list] + ) # We should never use these since data are preloaded, let's just # set it to something large and likely to break (2 ** 31 - 1) - extra_idx = [2147483647] * sum(info['nchan'] for info in infos[1:]) - assert all(len(r) == infos[0]['nchan'] for r in self._read_picks) + extra_idx = [2147483647] * sum(info["nchan"] for info in infos[1:]) + assert all(len(r) == infos[0]["nchan"] for r in self._read_picks) self._read_picks = [ - np.concatenate([r, extra_idx]) for r in self._read_picks] - assert all(len(r) == self.info['nchan'] for r in self._read_picks) + np.concatenate([r, extra_idx]) for r in self._read_picks + ] + assert all(len(r) == self.info["nchan"] for r in self._read_picks) for other in add_list: self._orig_units.update(other._orig_units) elif isinstance(self, BaseEpochs): self.picks = np.arange(self._data.shape[1]) - if hasattr(self, '_projector'): + if hasattr(self, "_projector"): activate = False if self._do_delayed_proj else self.proj - self._projector, self.info = setup_proj(self.info, False, - activate=activate) + self._projector, self.info = setup_proj( + self.info, False, activate=activate + ) return self @@ -999,9 +1149,15 @@ class InterpolationMixin: """Mixin class for Raw, Evoked, Epochs.""" @verbose - def interpolate_bads(self, reset_bads=True, mode='accurate', - origin='auto', method=None, exclude=(), - verbose=None): + def interpolate_bads( + self, + reset_bads=True, + mode="accurate", + origin="auto", + method=None, + exclude=(), + verbose=None, + ): """Interpolate bad MEG and EEG channels. Operates in place. @@ -1052,34 +1208,37 @@ def interpolate_bads(self, reset_bads=True, mode='accurate', .. 
versionadded:: 0.9.0 """ from ..bem import _check_origin - from .interpolation import _interpolate_bads_eeg,\ - _interpolate_bads_meeg, _interpolate_bads_nirs + from .interpolation import ( + _interpolate_bads_eeg, + _interpolate_bads_meeg, + _interpolate_bads_nirs, + ) _check_preload(self, "interpolation") - method = _handle_default('interpolation_method', method) + method = _handle_default("interpolation_method", method) for key in method: - _check_option('method[key]', key, ('meg', 'eeg', 'fnirs')) - _check_option("method['eeg']", method['eeg'], ('spline', 'MNE')) - _check_option("method['meg']", method['meg'], ('MNE',)) - _check_option("method['fnirs']", method['fnirs'], ('nearest',)) + _check_option("method[key]", key, ("meg", "eeg", "fnirs")) + _check_option("method['eeg']", method["eeg"], ("spline", "MNE")) + _check_option("method['meg']", method["meg"], ("MNE",)) + _check_option("method['fnirs']", method["fnirs"], ("nearest",)) - if len(self.info['bads']) == 0: - warn('No bad channels to interpolate. Doing nothing...') + if len(self.info["bads"]) == 0: + warn("No bad channels to interpolate. Doing nothing...") return self - logger.info('Interpolating bad channels') + logger.info("Interpolating bad channels") origin = _check_origin(origin, self.info) - if method['eeg'] == 'spline': + if method["eeg"] == "spline": _interpolate_bads_eeg(self, origin=origin, exclude=exclude) eeg_mne = False else: eeg_mne = True - _interpolate_bads_meeg(self, mode=mode, origin=origin, eeg=eeg_mne, - exclude=exclude) + _interpolate_bads_meeg( + self, mode=mode, origin=origin, eeg=eeg_mne, exclude=exclude + ) _interpolate_bads_nirs(self, exclude=exclude) if reset_bads is True: - self.info['bads'] = \ - [ch for ch in self.info['bads'] if ch in exclude] + self.info["bads"] = [ch for ch in self.info["bads"] if ch in exclude] return self @@ -1094,27 +1253,30 @@ def rename_channels(info, mapping, allow_duplicates=False, verbose=None): %(mapping_rename_channels_duplicates)s %(verbose)s """ - _validate_type(info, Info, 'info') + _validate_type(info, Info, "info") info._check_consistency() - bads = list(info['bads']) # make our own local copies - ch_names = list(info['ch_names']) + bads = list(info["bads"]) # make our own local copies + ch_names = list(info["ch_names"]) # first check and assemble clean mappings of index and name if isinstance(mapping, dict): - _check_dict_keys(mapping, ch_names, key_description="channel name(s)", - valid_key_source="info") - new_names = [(ch_names.index(ch_name), new_name) - for ch_name, new_name in mapping.items()] + _check_dict_keys( + mapping, + ch_names, + key_description="channel name(s)", + valid_key_source="info", + ) + new_names = [ + (ch_names.index(ch_name), new_name) for ch_name, new_name in mapping.items() + ] elif callable(mapping): - new_names = [(ci, mapping(ch_name)) - for ci, ch_name in enumerate(ch_names)] + new_names = [(ci, mapping(ch_name)) for ci, ch_name in enumerate(ch_names)] else: - raise ValueError('mapping must be callable or dict, not %s' - % (type(mapping),)) + raise ValueError("mapping must be callable or dict, not %s" % (type(mapping),)) # check we got all strings out of the mapping for new_name in new_names: - _validate_type(new_name[1], 'str', 'New channel mappings') + _validate_type(new_name[1], "str", "New channel mappings") # do the remapping locally for c_ind, new_name in new_names: @@ -1125,20 +1287,21 @@ def rename_channels(info, mapping, allow_duplicates=False, verbose=None): # check that all the channel names are unique if len(ch_names) 
!= len(np.unique(ch_names)) and not allow_duplicates: - raise ValueError('New channel names are not unique, renaming failed') + raise ValueError("New channel names are not unique, renaming failed") # do the remapping in info - info['bads'] = bads + info["bads"] = bads ch_names_mapping = dict() - for ch, ch_name in zip(info['chs'], ch_names): - ch_names_mapping[ch['ch_name']] = ch_name - ch['ch_name'] = ch_name + for ch, ch_name in zip(info["chs"], ch_names): + ch_names_mapping[ch["ch_name"]] = ch_name + ch["ch_name"] = ch_name # .get b/c fwd info omits it - _rename_comps(info.get('comps', []), ch_names_mapping) - if 'projs' in info: # fwd might omit it - for proj in info['projs']: - proj['data']['col_names'][:] = \ - _rename_list(proj['data']['col_names'], ch_names_mapping) + _rename_comps(info.get("comps", []), ch_names_mapping) + if "projs" in info: # fwd might omit it + for proj in info["projs"]: + proj["data"]["col_names"][:] = _rename_list( + proj["data"]["col_names"], ch_names_mapping + ) info._update_redundant() info._check_consistency() @@ -1160,244 +1323,277 @@ class _BuiltinChannelAdjacency: _ft_neighbor_url_t = string.Template( - 'https://github.com/fieldtrip/fieldtrip/raw/master/' - 'template/neighbours/$fname' + "https://github.com/fieldtrip/fieldtrip/raw/master/" "template/neighbours/$fname" ) _BUILTIN_CHANNEL_ADJACENCIES = [ _BuiltinChannelAdjacency( - name='biosemi16', - description='Biosemi 16-electrode cap', - fname='biosemi16_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='biosemi16_neighb.mat'), + name="biosemi16", + description="Biosemi 16-electrode cap", + fname="biosemi16_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="biosemi16_neighb.mat"), ), _BuiltinChannelAdjacency( - name='biosemi32', - description='Biosemi 32-electrode cap', - fname='biosemi32_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='biosemi32_neighb.mat'), + name="biosemi32", + description="Biosemi 32-electrode cap", + fname="biosemi32_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="biosemi32_neighb.mat"), ), _BuiltinChannelAdjacency( - name='biosemi64', - description='Biosemi 64-electrode cap', - fname='biosemi64_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='biosemi64_neighb.mat'), + name="biosemi64", + description="Biosemi 64-electrode cap", + fname="biosemi64_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="biosemi64_neighb.mat"), ), _BuiltinChannelAdjacency( - name='bti148', - description='BTI 148-channel system', - fname='bti148_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='bti148_neighb.mat'), + name="bti148", + description="BTI 148-channel system", + fname="bti148_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="bti148_neighb.mat"), ), _BuiltinChannelAdjacency( - name='bti248', - description='BTI 248-channel system', - fname='bti248_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='bti248_neighb.mat'), + name="bti248", + description="BTI 248-channel system", + fname="bti248_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="bti248_neighb.mat"), ), _BuiltinChannelAdjacency( - name='bti248grad', - description='BTI 248 gradiometer system', - fname='bti248grad_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='bti248grad_neighb.mat'), # noqa: E501 + name="bti248grad", + description="BTI 248 gradiometer system", + fname="bti248grad_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="bti248grad_neighb.mat" + ), # noqa: E501 ), 
_BuiltinChannelAdjacency( - name='ctf64', - description='CTF 64 axial gradiometer', - fname='ctf64_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='ctf64_neighb.mat'), + name="ctf64", + description="CTF 64 axial gradiometer", + fname="ctf64_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="ctf64_neighb.mat"), ), _BuiltinChannelAdjacency( - name='ctf151', - description='CTF 151 axial gradiometer', - fname='ctf151_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='ctf151_neighb.mat'), + name="ctf151", + description="CTF 151 axial gradiometer", + fname="ctf151_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="ctf151_neighb.mat"), ), _BuiltinChannelAdjacency( - name='ctf275', - description='CTF 275 axial gradiometer', - fname='ctf275_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='ctf275_neighb.mat'), + name="ctf275", + description="CTF 275 axial gradiometer", + fname="ctf275_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="ctf275_neighb.mat"), ), _BuiltinChannelAdjacency( - name='easycap32ch-avg', - description='', - fname='easycap32ch-avg_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycap32ch-avg_neighb.mat'), # noqa: E501 + name="easycap32ch-avg", + description="", + fname="easycap32ch-avg_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="easycap32ch-avg_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='easycap64ch-avg', - description='', - fname='easycap64ch-avg_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycap64ch-avg_neighb.mat'), # noqa: E501 + name="easycap64ch-avg", + description="", + fname="easycap64ch-avg_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="easycap64ch-avg_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='easycap128ch-avg', - description='', - fname='easycap128ch-avg_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycap128ch-avg_neighb.mat'), # noqa: E501 + name="easycap128ch-avg", + description="", + fname="easycap128ch-avg_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="easycap128ch-avg_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='easycapM1', - description='Easycap M1', - fname='easycapM1_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycapM1_neighb.mat'), + name="easycapM1", + description="Easycap M1", + fname="easycapM1_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="easycapM1_neighb.mat"), ), _BuiltinChannelAdjacency( - name='easycapM11', - description='Easycap M11', - fname='easycapM11_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycapM11_neighb.mat'), # noqa: E501 + name="easycapM11", + description="Easycap M11", + fname="easycapM11_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="easycapM11_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='easycapM14', - description='Easycap M14', - fname='easycapM14_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycapM14_neighb.mat'), # noqa: E501 + name="easycapM14", + description="Easycap M14", + fname="easycapM14_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="easycapM14_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='easycapM15', - description='Easycap M15', - fname='easycapM15_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='easycapM15_neighb.mat'), # noqa: E501 + name="easycapM15", + description="Easycap M15", + 
fname="easycapM15_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="easycapM15_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='KIT-157', - description='', - fname='KIT-157_neighb.mat', + name="KIT-157", + description="", + fname="KIT-157_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='KIT-208', - description='', - fname='KIT-208_neighb.mat', + name="KIT-208", + description="", + fname="KIT-208_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='KIT-NYU-2019', - description='', - fname='KIT-NYU-2019_neighb.mat', + name="KIT-NYU-2019", + description="", + fname="KIT-NYU-2019_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='KIT-UMD-1', - description='', - fname='KIT-UMD-1_neighb.mat', + name="KIT-UMD-1", + description="", + fname="KIT-UMD-1_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='KIT-UMD-2', - description='', - fname='KIT-UMD-2_neighb.mat', + name="KIT-UMD-2", + description="", + fname="KIT-UMD-2_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='KIT-UMD-3', - description='', - fname='KIT-UMD-3_neighb.mat', + name="KIT-UMD-3", + description="", + fname="KIT-UMD-3_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='KIT-UMD-4', - description='', - fname='KIT-UMD-4_neighb.mat', + name="KIT-UMD-4", + description="", + fname="KIT-UMD-4_neighb.mat", source_url=None, ), _BuiltinChannelAdjacency( - name='neuromag306mag', - description='Neuromag306, only magnetometers', - fname='neuromag306mag_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='neuromag306mag_neighb.mat'), # noqa: E501 + name="neuromag306mag", + description="Neuromag306, only magnetometers", + fname="neuromag306mag_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="neuromag306mag_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='neuromag306planar', - description='Neuromag306, only planar gradiometers', - fname='neuromag306planar_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='neuromag306planar_neighb.mat'), # noqa: E501 + name="neuromag306planar", + description="Neuromag306, only planar gradiometers", + fname="neuromag306planar_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="neuromag306planar_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='neuromag122cmb', - description='Neuromag122, only combined planar gradiometers', - fname='neuromag122cmb_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='neuromag122cmb_neighb.mat'), # noqa: E501 + name="neuromag122cmb", + description="Neuromag122, only combined planar gradiometers", + fname="neuromag122cmb_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="neuromag122cmb_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='neuromag306cmb', - description='Neuromag306, only combined planar gradiometers', - fname='neuromag306cmb_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='neuromag306cmb_neighb.mat'), # noqa: E501 + name="neuromag306cmb", + description="Neuromag306, only combined planar gradiometers", + fname="neuromag306cmb_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="neuromag306cmb_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='ecog256', - description='ECOG 256channels, average referenced', - fname='ecog256_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='ecog256_neighb.mat'), # noqa: E501 + name="ecog256", + description="ECOG 256channels, average 
referenced", + fname="ecog256_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="ecog256_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='ecog256bipolar', - description='ECOG 256channels, bipolar referenced', - fname='ecog256bipolar_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='ecog256bipolar_neighb.mat'), # noqa: E501 + name="ecog256bipolar", + description="ECOG 256channels, bipolar referenced", + fname="ecog256bipolar_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="ecog256bipolar_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='eeg1010_neighb', - description='', - fname='eeg1010_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='eeg1010_neighb.mat'), + name="eeg1010_neighb", + description="", + fname="eeg1010_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="eeg1010_neighb.mat"), ), _BuiltinChannelAdjacency( - name='elec1005', - description='Standard 10-05 system', - fname='elec1005_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='elec1005_neighb.mat'), + name="elec1005", + description="Standard 10-05 system", + fname="elec1005_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="elec1005_neighb.mat"), ), _BuiltinChannelAdjacency( - name='elec1010', - description='Standard 10-10 system', - fname='elec1010_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='elec1010_neighb.mat'), + name="elec1010", + description="Standard 10-10 system", + fname="elec1010_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="elec1010_neighb.mat"), ), _BuiltinChannelAdjacency( - name='elec1020', - description='Standard 10-20 system', - fname='elec1020_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='elec1020_neighb.mat'), + name="elec1020", + description="Standard 10-20 system", + fname="elec1020_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="elec1020_neighb.mat"), ), _BuiltinChannelAdjacency( - name='itab28', - description='ITAB 28-channel system', - fname='itab28_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='itab28_neighb.mat'), + name="itab28", + description="ITAB 28-channel system", + fname="itab28_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="itab28_neighb.mat"), ), _BuiltinChannelAdjacency( - name='itab153', - description='ITAB 153-channel system', - fname='itab153_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='itab153_neighb.mat'), + name="itab153", + description="ITAB 153-channel system", + fname="itab153_neighb.mat", + source_url=_ft_neighbor_url_t.substitute(fname="itab153_neighb.mat"), ), _BuiltinChannelAdjacency( - name='language29ch-avg', - description='MPI for Psycholinguistic: Averaged 29-channel cap', - fname='language29ch-avg_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='language29ch-avg_neighb.mat'), # noqa: E501 + name="language29ch-avg", + description="MPI for Psycholinguistic: Averaged 29-channel cap", + fname="language29ch-avg_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="language29ch-avg_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='mpi_59_channels', - description='MPI for Psycholinguistic: 59-channel cap', - fname='mpi_59_channels_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='mpi_59_channels_neighb.mat'), # noqa: E501 + name="mpi_59_channels", + description="MPI for Psycholinguistic: 59-channel cap", + fname="mpi_59_channels_neighb.mat", + 
source_url=_ft_neighbor_url_t.substitute( + fname="mpi_59_channels_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='yokogawa160', - description='', - fname='yokogawa160_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='yokogawa160_neighb.mat'), # noqa: E501 + name="yokogawa160", + description="", + fname="yokogawa160_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="yokogawa160_neighb.mat" + ), # noqa: E501 ), _BuiltinChannelAdjacency( - name='yokogawa440', - description='', - fname='yokogawa440_neighb.mat', - source_url=_ft_neighbor_url_t.substitute(fname='yokogawa440_neighb.mat'), # noqa: E501 + name="yokogawa440", + description="", + fname="yokogawa440_neighb.mat", + source_url=_ft_neighbor_url_t.substitute( + fname="yokogawa440_neighb.mat" + ), # noqa: E501 ), ] @@ -1433,13 +1629,10 @@ def get_builtin_ch_adjacencies(*, descriptions=False): if descriptions: return sorted( [(m.name, m.description) for m in _BUILTIN_CHANNEL_ADJACENCIES], - key=lambda x: x[0].casefold() # only sort based on name + key=lambda x: x[0].casefold(), # only sort based on name ) else: - return sorted( - [m.name for m in _BUILTIN_CHANNEL_ADJACENCIES], - key=str.casefold - ) + return sorted([m.name for m in _BUILTIN_CHANNEL_ADJACENCIES], key=str.casefold) @fill_doc @@ -1488,6 +1681,7 @@ def read_ch_adjacency(fname, picks=None): to pass to the eventual function. """ from scipy.io import loadmat + if op.isabs(fname): fname = str( _check_fname( @@ -1499,20 +1693,19 @@ def read_ch_adjacency(fname, picks=None): else: # built-in FieldTrip neighbors ch_adj_name = fname del fname - if ch_adj_name.endswith('_neighb.mat'): # backward-compat - ch_adj_name = ch_adj_name.replace('_neighb.mat', '') + if ch_adj_name.endswith("_neighb.mat"): # backward-compat + ch_adj_name = ch_adj_name.replace("_neighb.mat", "") if ch_adj_name not in get_builtin_ch_adjacencies(): raise ValueError( - f'No built-in channel adjacency matrix found with name: ' - f'{ch_adj_name}. Valid names are: ' + f"No built-in channel adjacency matrix found with name: " + f"{ch_adj_name}. 
Valid names are: " f'{", ".join(get_builtin_ch_adjacencies())}' ) - ch_adj = [a for a in _BUILTIN_CHANNEL_ADJACENCIES - if a.name == ch_adj_name][0] + ch_adj = [a for a in _BUILTIN_CHANNEL_ADJACENCIES if a.name == ch_adj_name][0] fname = ch_adj.fname - templates_dir = Path(__file__).resolve().parent / 'data' / 'neighbors' + templates_dir = Path(__file__).resolve().parent / "data" / "neighbors" fname = str( _check_fname( # only needed to convert to a string fname=templates_dir / fname, @@ -1521,11 +1714,10 @@ def read_ch_adjacency(fname, picks=None): ) ) - nb = loadmat(fname)['neighbours'] - ch_names = _recursive_flatten(nb['label'], str) + nb = loadmat(fname)["neighbours"] + ch_names = _recursive_flatten(nb["label"], str) picks = _picks_to_idx(len(ch_names), picks) - neighbors = [_recursive_flatten(c, str) for c in - nb['neighblabel'].flatten()] + neighbors = [_recursive_flatten(c, str) for c in nb["neighblabel"].flatten()] assert len(ch_names) == len(neighbors) adjacency = _ch_neighbor_adjacency(ch_names, neighbors) # picking before constructing matrix is buggy @@ -1534,8 +1726,8 @@ def read_ch_adjacency(fname, picks=None): # make sure MEG channel names contain space after "MEG" for idx, ch_name in enumerate(ch_names): - if ch_name.startswith('MEG') and not ch_name[3] == ' ': - ch_name = ch_name.replace('MEG', 'MEG ') + if ch_name.startswith("MEG") and not ch_name[3] == " ": + ch_name = ch_name.replace("MEG", "MEG ") ch_names[idx] = ch_name return adjacency, ch_names @@ -1559,19 +1751,19 @@ def _ch_neighbor_adjacency(ch_names, neighbors): The adjacency matrix. """ from scipy import sparse + if len(ch_names) != len(neighbors): - raise ValueError('`ch_names` and `neighbors` must ' - 'have the same length') + raise ValueError("`ch_names` and `neighbors` must " "have the same length") set_neighbors = {c for d in neighbors for c in d} rest = set_neighbors - set(ch_names) if len(rest) > 0: - raise ValueError('Some of your neighbors are not present in the ' - 'list of channel names') + raise ValueError( + "Some of your neighbors are not present in the " "list of channel names" + ) for neigh in neighbors: - if (not isinstance(neigh, list) and - not all(isinstance(c, str) for c in neigh)): - raise ValueError('`neighbors` must be a list of lists of str') + if not isinstance(neigh, list) and not all(isinstance(c, str) for c in neigh): + raise ValueError("`neighbors` must be a list of lists of str") ch_adjacency = np.eye(len(ch_names), dtype=bool) for ii, neigbs in enumerate(neighbors): @@ -1634,49 +1826,64 @@ def find_ch_adjacency(info, ch_type): if ch_type is None: picks = channel_indices_by_type(info) if sum([len(p) != 0 for p in picks.values()]) != 1: - raise ValueError('info must contain only one channel type if ' - 'ch_type is None.') + raise ValueError( + "info must contain only one channel type if " "ch_type is None." 
+ ) ch_type = channel_type(info, 0) else: - _check_option('ch_type', ch_type, ['mag', 'grad', 'eeg']) - (has_vv_mag, has_vv_grad, is_old_vv, has_4D_mag, ctf_other_types, - has_CTF_grad, n_kit_grads, has_any_meg, has_eeg_coils, - has_eeg_coils_and_meg, has_eeg_coils_only, - has_neuromag_122_grad, has_csd_coils) = _get_ch_info(info) + _check_option("ch_type", ch_type, ["mag", "grad", "eeg"]) + ( + has_vv_mag, + has_vv_grad, + is_old_vv, + has_4D_mag, + ctf_other_types, + has_CTF_grad, + n_kit_grads, + has_any_meg, + has_eeg_coils, + has_eeg_coils_and_meg, + has_eeg_coils_only, + has_neuromag_122_grad, + has_csd_coils, + ) = _get_ch_info(info) conn_name = None - if has_vv_mag and ch_type == 'mag': - conn_name = 'neuromag306mag' - elif has_vv_grad and ch_type == 'grad': - conn_name = 'neuromag306planar' + if has_vv_mag and ch_type == "mag": + conn_name = "neuromag306mag" + elif has_vv_grad and ch_type == "grad": + conn_name = "neuromag306planar" elif has_4D_mag: - if 'MEG 248' in info['ch_names']: - idx = info['ch_names'].index('MEG 248') - grad = info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_MAGNES_GRAD - mag = info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_MAGNES_MAG - if ch_type == 'grad' and grad: - conn_name = 'bti248grad' - elif ch_type == 'mag' and mag: - conn_name = 'bti248' - elif 'MEG 148' in info['ch_names'] and ch_type == 'mag': - idx = info['ch_names'].index('MEG 148') - if info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_MAGNES_MAG: - conn_name = 'bti148' - elif has_CTF_grad and ch_type == 'mag': - if info['nchan'] < 100: - conn_name = 'ctf64' - elif info['nchan'] > 200: - conn_name = 'ctf275' + if "MEG 248" in info["ch_names"]: + idx = info["ch_names"].index("MEG 248") + grad = info["chs"][idx]["coil_type"] == FIFF.FIFFV_COIL_MAGNES_GRAD + mag = info["chs"][idx]["coil_type"] == FIFF.FIFFV_COIL_MAGNES_MAG + if ch_type == "grad" and grad: + conn_name = "bti248grad" + elif ch_type == "mag" and mag: + conn_name = "bti248" + elif "MEG 148" in info["ch_names"] and ch_type == "mag": + idx = info["ch_names"].index("MEG 148") + if info["chs"][idx]["coil_type"] == FIFF.FIFFV_COIL_MAGNES_MAG: + conn_name = "bti148" + elif has_CTF_grad and ch_type == "mag": + if info["nchan"] < 100: + conn_name = "ctf64" + elif info["nchan"] > 200: + conn_name = "ctf275" else: - conn_name = 'ctf151' + conn_name = "ctf151" elif n_kit_grads > 0: from ..io.kit.constants import KIT_NEIGHBORS - conn_name = KIT_NEIGHBORS.get(info['kit_system_id']) + + conn_name = KIT_NEIGHBORS.get(info["kit_system_id"]) if conn_name is not None: - logger.info(f'Reading adjacency matrix for {conn_name}.') + logger.info(f"Reading adjacency matrix for {conn_name}.") return read_ch_adjacency(conn_name) - logger.info('Could not find a adjacency matrix for the data. ' - 'Computing adjacency based on Delaunay triangulations.') + logger.info( + "Could not find a adjacency matrix for the data. " + "Computing adjacency based on Delaunay triangulations." + ) return _compute_ch_adjacency(info, ch_type) @@ -1702,21 +1909,24 @@ def _compute_ch_adjacency(info, ch_type): from scipy.spatial import Delaunay from .. 
import spatial_tris_adjacency from ..channels.layout import _find_topomap_coords, _pair_grad_sensors - combine_grads = (ch_type == 'grad' - and any([coil_type in [ch['coil_type'] - for ch in info['chs']] - for coil_type in - [FIFF.FIFFV_COIL_VV_PLANAR_T1, - FIFF.FIFFV_COIL_NM_122]])) + + combine_grads = ch_type == "grad" and any( + [ + coil_type in [ch["coil_type"] for ch in info["chs"]] + for coil_type in [FIFF.FIFFV_COIL_VV_PLANAR_T1, FIFF.FIFFV_COIL_NM_122] + ] + ) picks = dict(_picks_by_type(info, exclude=[]))[ch_type] - ch_names = [info['ch_names'][pick] for pick in picks] + ch_names = [info["ch_names"][pick] for pick in picks] if combine_grads: pairs = _pair_grad_sensors(info, topomap_coords=False, exclude=[]) if len(pairs) != len(picks): - raise RuntimeError('Cannot find a pair for some of the ' - 'gradiometers. Cannot compute adjacency ' - 'matrix.') + raise RuntimeError( + "Cannot find a pair for some of the " + "gradiometers. Cannot compute adjacency " + "matrix." + ) # only for one of the pair xy = _find_topomap_coords(info, picks[::2], sphere=HEAD_SIZE_DEFAULT) else: @@ -1774,26 +1984,26 @@ def fix_mag_coil_types(info, use_cal=False): old_mag_inds = _get_T1T2_mag_inds(info, use_cal) for ii in old_mag_inds: - info['chs'][ii]['coil_type'] = FIFF.FIFFV_COIL_VV_MAG_T3 - logger.info('%d of %d magnetometer types replaced with T3.' % - (len(old_mag_inds), - len(pick_types(info, meg='mag', exclude=[])))) + info["chs"][ii]["coil_type"] = FIFF.FIFFV_COIL_VV_MAG_T3 + logger.info( + "%d of %d magnetometer types replaced with T3." + % (len(old_mag_inds), len(pick_types(info, meg="mag", exclude=[]))) + ) info._check_consistency() def _get_T1T2_mag_inds(info, use_cal=False): """Find T1/T2 magnetometer coil types.""" - picks = pick_types(info, meg='mag', exclude=[]) + picks = pick_types(info, meg="mag", exclude=[]) old_mag_inds = [] # From email exchanges, systems with the larger T2 coil only use the cal # value of 2.09e-11. Newer T3 magnetometers use 4.13e-11 or 1.33e-10 # (Triux). So we can use a simple check for > 3e-11. 
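
# For context, this helper backs the public mne.channels.fix_mag_coil_types.
# A hedged sketch (the FIF path below is a placeholder, not a file shipped
# with this patch): with use_cal=True, channels tagged T1/T2 whose calibration
# exceeds 3e-11 are the ones relabeled as T3.
import mne

info = mne.io.read_info("neuromag_raw.fif")  # placeholder path to any Neuromag recording
mne.channels.fix_mag_coil_types(info, use_cal=True)
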
for ii in picks: - ch = info['chs'][ii] - if ch['coil_type'] in (FIFF.FIFFV_COIL_VV_MAG_T1, - FIFF.FIFFV_COIL_VV_MAG_T2): + ch = info["chs"][ii] + if ch["coil_type"] in (FIFF.FIFFV_COIL_VV_MAG_T1, FIFF.FIFFV_COIL_VV_MAG_T2): if use_cal: - if ch['cal'] > 3e-11: + if ch["cal"] > 3e-11: old_mag_inds.append(ii) else: old_mag_inds.append(ii) @@ -1802,47 +2012,72 @@ def _get_T1T2_mag_inds(info, use_cal=False): def _get_ch_info(info): """Get channel info for inferring acquisition device.""" - chs = info['chs'] + chs = info["chs"] # Only take first 16 bits, as higher bits store CTF comp order - coil_types = {ch['coil_type'] & 0xFFFF for ch in chs} - channel_types = {ch['kind'] for ch in chs} - - has_vv_mag = any(k in coil_types for k in - [FIFF.FIFFV_COIL_VV_MAG_T1, FIFF.FIFFV_COIL_VV_MAG_T2, - FIFF.FIFFV_COIL_VV_MAG_T3]) - has_vv_grad = any(k in coil_types for k in [FIFF.FIFFV_COIL_VV_PLANAR_T1, - FIFF.FIFFV_COIL_VV_PLANAR_T2, - FIFF.FIFFV_COIL_VV_PLANAR_T3]) - has_neuromag_122_grad = any(k in coil_types - for k in [FIFF.FIFFV_COIL_NM_122]) - - is_old_vv = ' ' in chs[0]['ch_name'] + coil_types = {ch["coil_type"] & 0xFFFF for ch in chs} + channel_types = {ch["kind"] for ch in chs} + + has_vv_mag = any( + k in coil_types + for k in [ + FIFF.FIFFV_COIL_VV_MAG_T1, + FIFF.FIFFV_COIL_VV_MAG_T2, + FIFF.FIFFV_COIL_VV_MAG_T3, + ] + ) + has_vv_grad = any( + k in coil_types + for k in [ + FIFF.FIFFV_COIL_VV_PLANAR_T1, + FIFF.FIFFV_COIL_VV_PLANAR_T2, + FIFF.FIFFV_COIL_VV_PLANAR_T3, + ] + ) + has_neuromag_122_grad = any(k in coil_types for k in [FIFF.FIFFV_COIL_NM_122]) + + is_old_vv = " " in chs[0]["ch_name"] has_4D_mag = FIFF.FIFFV_COIL_MAGNES_MAG in coil_types - ctf_other_types = (FIFF.FIFFV_COIL_CTF_REF_MAG, - FIFF.FIFFV_COIL_CTF_REF_GRAD, - FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD) - has_CTF_grad = (FIFF.FIFFV_COIL_CTF_GRAD in coil_types or - (FIFF.FIFFV_MEG_CH in channel_types and - any(k in ctf_other_types for k in coil_types))) + ctf_other_types = ( + FIFF.FIFFV_COIL_CTF_REF_MAG, + FIFF.FIFFV_COIL_CTF_REF_GRAD, + FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD, + ) + has_CTF_grad = FIFF.FIFFV_COIL_CTF_GRAD in coil_types or ( + FIFF.FIFFV_MEG_CH in channel_types + and any(k in ctf_other_types for k in coil_types) + ) # hack due to MNE-C bug in IO of CTF # only take first 16 bits, as higher bits store CTF comp order - n_kit_grads = sum(ch['coil_type'] & 0xFFFF == FIFF.FIFFV_COIL_KIT_GRAD - for ch in chs) - - has_any_meg = any([has_vv_mag, has_vv_grad, has_4D_mag, has_CTF_grad, - n_kit_grads]) - has_eeg_coils = (FIFF.FIFFV_COIL_EEG in coil_types and - FIFF.FIFFV_EEG_CH in channel_types) + n_kit_grads = sum( + ch["coil_type"] & 0xFFFF == FIFF.FIFFV_COIL_KIT_GRAD for ch in chs + ) + + has_any_meg = any([has_vv_mag, has_vv_grad, has_4D_mag, has_CTF_grad, n_kit_grads]) + has_eeg_coils = ( + FIFF.FIFFV_COIL_EEG in coil_types and FIFF.FIFFV_EEG_CH in channel_types + ) has_eeg_coils_and_meg = has_eeg_coils and has_any_meg has_eeg_coils_only = has_eeg_coils and not has_any_meg - has_csd_coils = (FIFF.FIFFV_COIL_EEG_CSD in coil_types and - FIFF.FIFFV_EEG_CH in channel_types) - - return (has_vv_mag, has_vv_grad, is_old_vv, has_4D_mag, ctf_other_types, - has_CTF_grad, n_kit_grads, has_any_meg, has_eeg_coils, - has_eeg_coils_and_meg, has_eeg_coils_only, has_neuromag_122_grad, - has_csd_coils) + has_csd_coils = ( + FIFF.FIFFV_COIL_EEG_CSD in coil_types and FIFF.FIFFV_EEG_CH in channel_types + ) + + return ( + has_vv_mag, + has_vv_grad, + is_old_vv, + has_4D_mag, + ctf_other_types, + has_CTF_grad, + n_kit_grads, + has_any_meg, + 
has_eeg_coils, + has_eeg_coils_and_meg, + has_eeg_coils_only, + has_neuromag_122_grad, + has_csd_coils, + ) @fill_doc @@ -1884,6 +2119,7 @@ def make_1020_channel_selections(info, midline="z", *, return_ch_names=False): try: from .layout import find_layout + layout = find_layout(info) pos = layout.pos ch_names = layout.names @@ -1905,8 +2141,10 @@ def make_1020_channel_selections(info, midline="z", *, return_ch_names=False): if pos is not None: # sort channels from front to center # (y-coordinate of the position info in the layout) - selections = {selection: np.array(picks)[pos[picks, 1].argsort()] - for selection, picks in selections.items()} + selections = { + selection: np.array(picks)[pos[picks, 1].argsort()] + for selection, picks in selections.items() + } # convert channel indices to names if requested if return_ch_names: @@ -1917,8 +2155,9 @@ def make_1020_channel_selections(info, midline="z", *, return_ch_names=False): @verbose -def combine_channels(inst, groups, method='mean', keep_stim=False, - drop_bad=False, verbose=None): +def combine_channels( + inst, groups, method="mean", keep_stim=False, drop_bad=False, verbose=None +): """Combine channels based on specified channel grouping. Parameters @@ -1969,8 +2208,8 @@ def combine_channels(inst, groups, method='mean', keep_stim=False, from .. import BaseEpochs, EpochsArray, Evoked, EvokedArray ch_axis = 1 if isinstance(inst, BaseEpochs) else 0 - ch_idx = list(range(inst.info['nchan'])) - ch_names = inst.info['ch_names'] + ch_idx = list(range(inst.info["nchan"])) + ch_names = inst.info["ch_names"] ch_types = inst.get_channel_types() inst_data = inst.data if isinstance(inst, Evoked) else inst.get_data() groups = OrderedDict(deepcopy(groups)) @@ -1978,99 +2217,121 @@ def combine_channels(inst, groups, method='mean', keep_stim=False, # Convert string values of ``method`` into callables # XXX Possibly de-duplicate with _make_combine_callable of mne/viz/utils.py if isinstance(method, str): - method_dict = {key: partial(getattr(np, key), axis=ch_axis) - for key in ('mean', 'median', 'std')} + method_dict = { + key: partial(getattr(np, key), axis=ch_axis) + for key in ("mean", "median", "std") + } try: method = method_dict[method] except KeyError: - raise ValueError('"method" must be a callable, or one of "mean", ' - f'"median", or "std"; got "{method}".') + raise ValueError( + '"method" must be a callable, or one of "mean", ' + f'"median", or "std"; got "{method}".' 
+ ) # Instantiate channel info and data new_ch_names, new_ch_types, new_data = [], [], [] if not isinstance(keep_stim, bool): - raise TypeError('"keep_stim" must be of type bool, not ' - f'{type(keep_stim)}.') + raise TypeError('"keep_stim" must be of type bool, not ' f"{type(keep_stim)}.") if keep_stim: stim_ch_idx = list(pick_types(inst.info, meg=False, stim=True)) if stim_ch_idx: new_ch_names = [ch_names[idx] for idx in stim_ch_idx] new_ch_types = [ch_types[idx] for idx in stim_ch_idx] - new_data = [np.take(inst_data, idx, axis=ch_axis) - for idx in stim_ch_idx] + new_data = [np.take(inst_data, idx, axis=ch_axis) for idx in stim_ch_idx] else: - warn('Could not find stimulus channels.') + warn("Could not find stimulus channels.") # Get indices of bad channels ch_idx_bad = [] if not isinstance(drop_bad, bool): - raise TypeError('"drop_bad" must be of type bool, not ' - f'{type(drop_bad)}.') - if drop_bad and inst.info['bads']: - ch_idx_bad = pick_channels(ch_names, inst.info['bads']) + raise TypeError('"drop_bad" must be of type bool, not ' f"{type(drop_bad)}.") + if drop_bad and inst.info["bads"]: + ch_idx_bad = pick_channels(ch_names, inst.info["bads"]) # Check correctness of combinations for this_group, this_picks in groups.items(): # Check if channel indices are out of bounds if not all(idx in ch_idx for idx in this_picks): - raise ValueError('Some channel indices are out of bounds.') + raise ValueError("Some channel indices are out of bounds.") # Check if heterogeneous sensor type combinations this_ch_type = np.array(ch_types)[this_picks] if len(set(this_ch_type)) > 1: - types = ', '.join(set(this_ch_type)) - raise ValueError('Cannot combine sensors of different types; ' - f'"{this_group}" contains types {types}.') + types = ", ".join(set(this_ch_type)) + raise ValueError( + "Cannot combine sensors of different types; " + f'"{this_group}" contains types {types}.' + ) # Remove bad channels these_bads = [idx for idx in this_picks if idx in ch_idx_bad] this_picks = [idx for idx in this_picks if idx not in ch_idx_bad] if these_bads: - logger.info('Dropped the following channels in group ' - f'{this_group}: {these_bads}') + logger.info( + "Dropped the following channels in group " f"{this_group}: {these_bads}" + ) # Check if combining less than 2 channel if len(set(this_picks)) < 2: - warn(f'Less than 2 channels in group "{this_group}" when ' - f'combining by method "{method}".') + warn( + f'Less than 2 channels in group "{this_group}" when ' + f'combining by method "{method}".' 
+ ) # If all good create more detailed dict without bad channels groups[this_group] = dict(picks=this_picks, ch_type=this_ch_type[0]) # Combine channels and add them to the new instance for this_group, this_group_dict in groups.items(): new_ch_names.append(this_group) - new_ch_types.append(this_group_dict['ch_type']) - this_picks = this_group_dict['picks'] + new_ch_types.append(this_group_dict["ch_type"]) + this_picks = this_group_dict["picks"] this_data = np.take(inst_data, this_picks, axis=ch_axis) new_data.append(method(this_data)) new_data = np.swapaxes(new_data, 0, ch_axis) - info = create_info(sfreq=inst.info['sfreq'], ch_names=new_ch_names, - ch_types=new_ch_types) + info = create_info( + sfreq=inst.info["sfreq"], ch_names=new_ch_names, ch_types=new_ch_types + ) # create new instances and make sure to copy important attributes if isinstance(inst, BaseRaw): combined_inst = RawArray(new_data, info, first_samp=inst.first_samp) elif isinstance(inst, BaseEpochs): - combined_inst = EpochsArray(new_data, info, events=inst.events, - tmin=inst.times[0], baseline=inst.baseline) + combined_inst = EpochsArray( + new_data, + info, + events=inst.events, + tmin=inst.times[0], + baseline=inst.baseline, + ) if inst.metadata is not None: combined_inst.metadata = inst.metadata.copy() elif isinstance(inst, Evoked): - combined_inst = EvokedArray(new_data, info, tmin=inst.times[0], - baseline=inst.baseline) + combined_inst = EvokedArray( + new_data, info, tmin=inst.times[0], baseline=inst.baseline + ) return combined_inst # NeuroMag channel groupings -_SELECTIONS = ['Vertex', 'Left-temporal', 'Right-temporal', 'Left-parietal', - 'Right-parietal', 'Left-occipital', 'Right-occipital', - 'Left-frontal', 'Right-frontal'] -_EEG_SELECTIONS = ['EEG 1-32', 'EEG 33-64', 'EEG 65-96', 'EEG 97-128'] +_SELECTIONS = [ + "Vertex", + "Left-temporal", + "Right-temporal", + "Left-parietal", + "Right-parietal", + "Left-occipital", + "Right-occipital", + "Left-frontal", + "Right-frontal", +] +_EEG_SELECTIONS = ["EEG 1-32", "EEG 33-64", "EEG 65-96", "EEG 97-128"] def _divide_to_regions(info, add_stim=True): """Divide channels to regions by positions.""" from scipy.stats import zscore + picks = _pick_data_channels(info, exclude=[]) chs_in_lobe = len(picks) // 4 - pos = np.array([ch['loc'][:3] for ch in info['chs']]) + pos = np.array([ch["loc"][:3] for ch in info["chs"]]) x, y, z = pos.T frontal = picks[np.argsort(y[picks])[-chs_in_lobe:]] @@ -2090,14 +2351,14 @@ def _divide_to_regions(info, add_stim=True): # Because of the way the sides are divided, there may be outliers in the # temporal lobes. Here we switch the sides for these outliers. For other # lobes it is not a big problem because of the vicinity of the lobes. 
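
# The |z| > 2 rule described in the comment above, verified in isolation:
# nine left-hemisphere x-positions plus one stray right-hemisphere sensor
# (synthetic values in meters, not taken from any real montage).
import numpy as np
from scipy.stats import zscore

x = np.array([-0.05] * 9 + [0.07])
print(np.where(np.abs(zscore(x)) > 2.0)[0])  # [9] -> the stray sensor is switched to the other side
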
- with np.errstate(invalid='ignore'): # invalid division, greater compare + with np.errstate(invalid="ignore"): # invalid division, greater compare zs = np.abs(zscore(x[rt])) - outliers = np.array(rt)[np.where(zs > 2.)[0]] + outliers = np.array(rt)[np.where(zs > 2.0)[0]] rt = list(np.setdiff1d(rt, outliers)) - with np.errstate(invalid='ignore'): # invalid division, greater compare + with np.errstate(invalid="ignore"): # invalid division, greater compare zs = np.abs(zscore(x[lt])) - outliers = np.append(outliers, (np.array(lt)[np.where(zs > 2.)[0]])) + outliers = np.append(outliers, (np.array(lt)[np.where(zs > 2.0)[0]])) lt = list(np.setdiff1d(lt, outliers)) l_mean = np.mean(x[lt]) @@ -2112,11 +2373,19 @@ def _divide_to_regions(info, add_stim=True): stim_ch = _get_stim_channel(None, info, raise_error=False) if len(stim_ch) > 0: for region in [lf, rf, lo, ro, lp, rp, lt, rt]: - region.append(info['ch_names'].index(stim_ch[0])) - return OrderedDict([('Left-frontal', lf), ('Right-frontal', rf), - ('Left-parietal', lp), ('Right-parietal', rp), - ('Left-occipital', lo), ('Right-occipital', ro), - ('Left-temporal', lt), ('Right-temporal', rt)]) + region.append(info["ch_names"].index(stim_ch[0])) + return OrderedDict( + [ + ("Left-frontal", lf), + ("Right-frontal", rf), + ("Left-parietal", lp), + ("Right-parietal", rp), + ("Left-occipital", lo), + ("Right-occipital", ro), + ("Left-temporal", lt), + ("Right-temporal", rt), + ] + ) def _divide_side(lobe, x): @@ -2165,42 +2434,44 @@ def read_vectorview_selection(name, fname=None, info=None, verbose=None): name = [name] if isinstance(info, Info): picks = pick_types(info, meg=True, exclude=()) - if len(picks) > 0 and ' ' not in info['ch_names'][picks[0]]: - spacing = 'new' + if len(picks) > 0 and " " not in info["ch_names"][picks[0]]: + spacing = "new" else: - spacing = 'old' + spacing = "old" elif info is not None: - raise TypeError('info must be an instance of Info or None, not %s' - % (type(info),)) + raise TypeError( + "info must be an instance of Info or None, not %s" % (type(info),) + ) else: # info is None - spacing = 'old' + spacing = "old" # use built-in selections by default if fname is None: - fname = op.join(op.dirname(__file__), '..', 'data', 'mne_analyze.sel') + fname = op.join(op.dirname(__file__), "..", "data", "mne_analyze.sel") fname = str(_check_fname(fname, must_exist=True, overwrite="read")) # use this to make sure we find at least one match for each name name_found = {n: False for n in name} - with open(fname, 'r') as fid: + with open(fname, "r") as fid: sel = [] for line in fid: line = line.strip() # skip blank lines and comments - if len(line) == 0 or line[0] == '#': + if len(line) == 0 or line[0] == "#": continue # get the name of the selection in the file - pos = line.find(':') + pos = line.find(":") if pos < 0: - logger.info('":" delimiter not found in selections file, ' - 'skipping line') + logger.info( + '":" delimiter not found in selections file, ' "skipping line" + ) continue sel_name_file = line[:pos] # search for substring match with name provided for n in name: if sel_name_file.find(n) >= 0: - sel.extend(line[pos + 1:].split('|')) + sel.extend(line[pos + 1 :].split("|")) name_found[n] = True break @@ -2212,6 +2483,6 @@ def read_vectorview_selection(name, fname=None, info=None, verbose=None): # make the selection a sorted list with unique elements sel = list(set(sel)) sel.sort() - if spacing == 'new': # "new" or "old" by now, "old" is default - sel = [s.replace('MEG ', 'MEG') for s in sel] + if spacing == "new": # 
"new" or "old" by now, "old" is default + sel = [s.replace("MEG ", "MEG") for s in sel] return sel diff --git a/mne/channels/interpolation.py b/mne/channels/interpolation.py index d8c0a2be78a..f9dc0319992 100644 --- a/mne/channels/interpolation.py +++ b/mne/channels/interpolation.py @@ -26,9 +26,10 @@ def _calc_h(cosang, stiffness=4, n_legendre_terms=50): n_legendre_terms : int number of Legendre terms to evaluate. """ - factors = [(2 * n + 1) / - (n ** (stiffness - 1) * (n + 1) ** (stiffness - 1) * 4 * np.pi) - for n in range(1, n_legendre_terms + 1)] + factors = [ + (2 * n + 1) / (n ** (stiffness - 1) * (n + 1) ** (stiffness - 1) * 4 * np.pi) + for n in range(1, n_legendre_terms + 1) + ] return legval(cosang, [0] + factors) @@ -50,9 +51,10 @@ def _calc_g(cosang, stiffness=4, n_legendre_terms=50): G : np.ndrarray of float, shape(n_channels, n_channels) The G matrix. """ - factors = [(2 * n + 1) / (n ** stiffness * (n + 1) ** stiffness * - 4 * np.pi) - for n in range(1, n_legendre_terms + 1)] + factors = [ + (2 * n + 1) / (n**stiffness * (n + 1) ** stiffness * 4 * np.pi) + for n in range(1, n_legendre_terms + 1) + ] return legval(cosang, [0] + factors) @@ -83,6 +85,7 @@ def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5): Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7. """ from scipy import linalg + pos_from = pos_from.copy() pos_to = pos_to.copy() n_from = pos_from.shape[0] @@ -101,10 +104,14 @@ def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5): assert G_to_from.shape == (n_to, n_from) if alpha is not None: - G_from.flat[::len(G_from) + 1] += alpha - - C = np.vstack([np.hstack([G_from, np.ones((n_from, 1))]), - np.hstack([np.ones((1, n_from)), [[0]]])]) + G_from.flat[:: len(G_from) + 1] += alpha + + C = np.vstack( + [ + np.hstack([G_from, np.ones((n_from, 1))]), + np.hstack([np.ones((1, n_from)), [[0]]]), + ] + ) C_inv = linalg.pinv(C) interpolation = np.hstack([G_to_from, np.ones((n_to, 1))]) @ C_inv[:, :-1] @@ -117,9 +124,11 @@ def _do_interp_dots(inst, interpolation, goods_idx, bads_idx): from ..io.base import BaseRaw from ..epochs import BaseEpochs from ..evoked import Evoked - _validate_type(inst, (BaseRaw, BaseEpochs, Evoked), 'inst') + + _validate_type(inst, (BaseRaw, BaseEpochs, Evoked), "inst") inst._data[..., bads_idx, :] = np.matmul( - interpolation, inst._data[..., goods_idx, :]) + interpolation, inst._data[..., goods_idx, :] + ) @verbose @@ -131,7 +140,7 @@ def _interpolate_bads_eeg(inst, origin, exclude=None, verbose=None): picks = pick_types(inst.info, meg=False, eeg=True, exclude=exclude) inst.info._check_consistency() - bads_idx[picks] = [inst.ch_names[ch] in inst.info['bads'] for ch in picks] + bads_idx[picks] = [inst.ch_names[ch] in inst.info["bads"] for ch in picks] if len(picks) == 0 or bads_idx.sum() == 0: return @@ -148,30 +157,43 @@ def _interpolate_bads_eeg(inst, origin, exclude=None, verbose=None): # test spherical fit distance = np.linalg.norm(pos - origin, axis=-1) distance = np.mean(distance / np.mean(distance)) - if np.abs(1. - distance) > 0.1: - warn('Your spherical fit is poor, interpolation results are ' - 'likely to be inaccurate.') + if np.abs(1.0 - distance) > 0.1: + warn( + "Your spherical fit is poor, interpolation results are " + "likely to be inaccurate." 
+ ) pos_good = pos[goods_idx_pos] - origin pos_bad = pos[bads_idx_pos] - origin - logger.info('Computing interpolation matrix from {} sensor ' - 'positions'.format(len(pos_good))) + logger.info( + "Computing interpolation matrix from {} sensor " + "positions".format(len(pos_good)) + ) interpolation = _make_interpolation_matrix(pos_good, pos_bad) - logger.info('Interpolating {} sensors'.format(len(pos_bad))) + logger.info("Interpolating {} sensors".format(len(pos_bad))) _do_interp_dots(inst, interpolation, goods_idx, bads_idx) -def _interpolate_bads_meg(inst, mode='accurate', origin=(0., 0., 0.04), - verbose=None, ref_meg=False): +def _interpolate_bads_meg( + inst, mode="accurate", origin=(0.0, 0.0, 0.04), verbose=None, ref_meg=False +): return _interpolate_bads_meeg( - inst, mode, origin, ref_meg=ref_meg, eeg=False, verbose=verbose) + inst, mode, origin, ref_meg=ref_meg, eeg=False, verbose=verbose + ) @verbose -def _interpolate_bads_meeg(inst, mode='accurate', origin=(0., 0., 0.04), - meg=True, eeg=True, ref_meg=False, - exclude=(), verbose=None): +def _interpolate_bads_meeg( + inst, + mode="accurate", + origin=(0.0, 0.0, 0.04), + meg=True, + eeg=True, + ref_meg=False, + exclude=(), + verbose=None, +): bools = dict(meg=meg, eeg=eeg) info = _simplify_info(inst.info) for ch_type, do in bools.items(): @@ -180,15 +202,14 @@ def _interpolate_bads_meeg(inst, mode='accurate', origin=(0., 0., 0.04), kw = dict(meg=False, eeg=False) kw[ch_type] = True picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **kw) - picks_good = pick_types(info, ref_meg=ref_meg, exclude='bads', **kw) - use_ch_names = [inst.info['ch_names'][p] for p in picks_type] - bads_type = [ch for ch in inst.info['bads'] if ch in use_ch_names] + picks_good = pick_types(info, ref_meg=ref_meg, exclude="bads", **kw) + use_ch_names = [inst.info["ch_names"][p] for p in picks_type] + bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names] if len(bads_type) == 0 or len(picks_type) == 0: continue # select the bad channels to be interpolated - picks_bad = pick_channels(inst.info['ch_names'], bads_type, - exclude=[]) - if ch_type == 'eeg': + picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[]) + if ch_type == "eeg": picks_to = picks_type bad_sel = np.in1d(picks_type, picks_bad) else: @@ -196,14 +217,13 @@ def _interpolate_bads_meeg(inst, mode='accurate', origin=(0., 0., 0.04), bad_sel = slice(None) info_from = pick_info(inst.info, picks_good) info_to = pick_info(inst.info, picks_to) - mapping = _map_meg_or_eeg_channels( - info_from, info_to, mode=mode, origin=origin) + mapping = _map_meg_or_eeg_channels(info_from, info_to, mode=mode, origin=origin) mapping = mapping[bad_sel] _do_interp_dots(inst, mapping, picks_good, picks_bad) @verbose -def _interpolate_bads_nirs(inst, method='nearest', exclude=(), verbose=None): +def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None): from scipy.spatial.distance import pdist, squareform from mne.preprocessing.nirs import _validate_nirs_info @@ -212,21 +232,20 @@ def _interpolate_bads_nirs(inst, method='nearest', exclude=(), verbose=None): # Returns pick of all nirs and ensures channels are correctly ordered picks_nirs = _validate_nirs_info(inst.info) - nirs_ch_names = [inst.info['ch_names'][p] for p in picks_nirs] + nirs_ch_names = [inst.info["ch_names"][p] for p in picks_nirs] nirs_ch_names = [ch for ch in nirs_ch_names if ch not in exclude] - bads_nirs = [ch for ch in inst.info['bads'] if ch in nirs_ch_names] + bads_nirs = [ch for ch in 
inst.info["bads"] if ch in nirs_ch_names] if len(bads_nirs) == 0: return - picks_bad = pick_channels(inst.info['ch_names'], bads_nirs, exclude=[]) + picks_bad = pick_channels(inst.info["ch_names"], bads_nirs, exclude=[]) bads_mask = [p in picks_bad for p in picks_nirs] - chs = [inst.info['chs'][i] for i in picks_nirs] - locs3d = np.array([ch['loc'][:3] for ch in chs]) - - _check_option('fnirs_method', method, ['nearest']) + chs = [inst.info["chs"][i] for i in picks_nirs] + locs3d = np.array([ch["loc"][:3] for ch in chs]) - if method == 'nearest': + _check_option("fnirs_method", method, ["nearest"]) + if method == "nearest": dist = pdist(locs3d) dist = squareform(dist) @@ -240,6 +259,6 @@ def _interpolate_bads_nirs(inst, method='nearest', exclude=(), verbose=None): closest_idx = np.argmin(dists_to_bad) + (bad % 2) inst._data[bad] = inst._data[closest_idx] - inst.info['bads'] = [ch for ch in inst.info['bads'] if ch in exclude] + inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude] return inst diff --git a/mne/channels/layout.py b/mne/channels/layout.py index e59bb80a2a1..9be79f33102 100644 --- a/mne/channels/layout.py +++ b/mne/channels/layout.py @@ -20,8 +20,16 @@ from ..io.pick import pick_types, _picks_to_idx, _FNIRS_CH_TYPES_SPLIT from ..io.constants import FIFF from ..io.meas_info import Info -from ..utils import (_clean_names, warn, _check_ch_locs, fill_doc, - _check_fname, _check_option, _check_sphere, logger) +from ..utils import ( + _clean_names, + warn, + _check_ch_locs, + fill_doc, + _check_fname, + _check_option, + _check_sphere, + logger, +) from .channels import _get_ch_info @@ -74,26 +82,32 @@ def save(self, fname, overwrite=False): height = self.pos[:, 3] fname = _check_fname(fname, overwrite=overwrite, name=fname) if fname.suffix == ".lout": - out_str = '%8.2f %8.2f %8.2f %8.2f\n' % self.box + out_str = "%8.2f %8.2f %8.2f %8.2f\n" % self.box elif fname.suffix == ".lay": - out_str = '' + out_str = "" else: - raise ValueError('Unknown layout type. Should be of type ' - '.lout or .lay.') + raise ValueError("Unknown layout type. Should be of type " ".lout or .lay.") for ii in range(x.shape[0]): - out_str += ('%03d %8.2f %8.2f %8.2f %8.2f %s\n' - % (self.ids[ii], x[ii], y[ii], - width[ii], height[ii], self.names[ii])) + out_str += "%03d %8.2f %8.2f %8.2f %8.2f %s\n" % ( + self.ids[ii], + x[ii], + y[ii], + width[ii], + height[ii], + self.names[ii], + ) - f = open(fname, 'w') + f = open(fname, "w") f.write(out_str) f.close() def __repr__(self): """Return the string representation.""" - return '' % (self.kind, - ', '.join(self.names[:3])) + return "" % ( + self.kind, + ", ".join(self.names[:3]), + ) @fill_doc def plot(self, picks=None, show_axes=False, show=True): @@ -117,6 +131,7 @@ def plot(self, picks=None, show_axes=False, show=True): .. 
versionadded:: 0.12.0 """ from ..viz.topomap import plot_layout + return plot_layout(self, picks=picks, show_axes=show_axes, show=show) @@ -130,7 +145,7 @@ def _read_lout(fname): splits = line.split() if len(splits) == 7: cid, x, y, dx, dy, chkind, nb = splits - name = chkind + ' ' + nb + name = chkind + " " + nb else: cid, x, y, dx, dy, name = splits pos.append(np.array([x, y, dx, dy], dtype=np.float64)) @@ -151,7 +166,7 @@ def _read_lay(fname): splits = line.split() if len(splits) == 7: cid, x, y, dx, dy, chkind, nb = splits - name = chkind + ' ' + nb + name = chkind + " " + nb else: cid, x, y, dx, dy, name = splits pos.append(np.array([x, y, dx, dy], dtype=np.float64)) @@ -263,22 +278,14 @@ def read_layout(fname=None, path="", scale=True, *, kind=None): # kind should be the name as a string, but let's consider the case # where the path to the file is provided instead. kind = Path(kind) - if ( - len(kind.suffix) == 0 - and (path / kind.with_suffix(".lout")).exists() - ): + if len(kind.suffix) == 0 and (path / kind.with_suffix(".lout")).exists(): kind = kind.with_suffix(".lout") - elif ( - len(kind.suffix) == 0 - and (path / kind.with_suffix(".lay")).exists() - ): + elif len(kind.suffix) == 0 and (path / kind.with_suffix(".lay")).exists(): kind = kind.with_suffix(".lay") fname = kind if kind.exists() else path / kind.name if fname.suffix not in (".lout", ".lay"): - raise ValueError( - "Unknown layout type. Should be of type .lout or .lay." - ) + raise ValueError("Unknown layout type. Should be of type .lout or .lay.") kind = fname.stem else: # to be removed along the deprecated argument @@ -317,8 +324,9 @@ def read_layout(fname=None, path="", scale=True, *, kind=None): @fill_doc -def make_eeg_layout(info, radius=0.5, width=None, height=None, exclude='bads', - csd=False): +def make_eeg_layout( + info, radius=0.5, width=None, height=None, exclude="bads", csd=False +): """Create .lout file from EEG electrode digitization. Parameters @@ -348,18 +356,18 @@ def make_eeg_layout(info, radius=0.5, width=None, height=None, exclude='bads', make_grid_layout, generate_2d_layout """ if not (0 <= radius <= 0.5): - raise ValueError('The radius parameter should be between 0 and 0.5.') + raise ValueError("The radius parameter should be between 0 and 0.5.") if width is not None and not (0 <= width <= 1.0): - raise ValueError('The width parameter should be between 0 and 1.') + raise ValueError("The width parameter should be between 0 and 1.") if height is not None and not (0 <= height <= 1.0): - raise ValueError('The height parameter should be between 0 and 1.') + raise ValueError("The height parameter should be between 0 and 1.") pick_kwargs = dict(meg=False, eeg=True, ref_meg=False, exclude=exclude) if csd: pick_kwargs.update(csd=True, eeg=False) picks = pick_types(info, **pick_kwargs) loc2d = _find_topomap_coords(info, picks) - names = [info['chs'][i]['ch_name'] for i in picks] + names = [info["chs"][i]["ch_name"] for i in picks] # Scale [x, y] to be in the range [-0.5, 0.5] # Don't mess with the origin or aspect ratio @@ -376,7 +384,7 @@ def make_eeg_layout(info, radius=0.5, width=None, height=None, exclude='bads', # Some subplot centers will be at the figure edge. Shrink everything so it # fits in the figure. - scaling = min(1 / (1. + width), 1 / (1. 
+ height)) + scaling = min(1 / (1.0 + width), 1 / (1.0 + height)) loc2d *= scaling width *= scaling height *= scaling @@ -385,14 +393,16 @@ def make_eeg_layout(info, radius=0.5, width=None, height=None, exclude='bads', loc2d += 0.5 n_channels = loc2d.shape[0] - pos = np.c_[loc2d[:, 0] - 0.5 * width, - loc2d[:, 1] - 0.5 * height, - width * np.ones(n_channels), - height * np.ones(n_channels)] + pos = np.c_[ + loc2d[:, 0] - 0.5 * width, + loc2d[:, 1] - 0.5 * height, + width * np.ones(n_channels), + height * np.ones(n_channels), + ] box = (0, 1, 0, 1) ids = 1 + np.arange(n_channels) - layout = Layout(box=box, pos=pos, names=names, kind='EEG', ids=ids) + layout = Layout(box=box, pos=pos, names=names, kind="EEG", ids=ids) return layout @@ -416,12 +426,12 @@ def make_grid_layout(info, picks=None, n_col=None): -------- make_eeg_layout, generate_2d_layout """ - picks = _picks_to_idx(info, picks, 'misc') + picks = _picks_to_idx(info, picks, "misc") - names = [info['chs'][k]['ch_name'] for k in picks] + names = [info["chs"][k]["ch_name"] for k in picks] if not names: - raise ValueError('No misc data channels found.') + raise ValueError("No misc data channels found.") ids = list(range(len(picks))) size = len(picks) @@ -439,16 +449,15 @@ def make_grid_layout(info, picks=None, n_col=None): n_row = int(np.ceil(size / float(n_col))) # setup position grid - x, y = np.meshgrid(np.linspace(-0.5, 0.5, n_col), - np.linspace(-0.5, 0.5, n_row)) + x, y = np.meshgrid(np.linspace(-0.5, 0.5, n_col), np.linspace(-0.5, 0.5, n_row)) x, y = x.ravel()[:size], y.ravel()[:size] width, height = _box_size(np.c_[x, y], padding=0.1) # Some axes will be at the figure edge. Shrink everything so it fits in the # figure. Add 0.01 border around everything border_x, border_y = (0.01, 0.01) - x_scaling = 1 / (1. + width + border_x) - y_scaling = 1 / (1. + height + border_y) + x_scaling = 1 / (1.0 + width + border_x) + y_scaling = 1 / (1.0 + height + border_y) x = x * x_scaling y = y * y_scaling width *= x_scaling @@ -459,16 +468,17 @@ def make_grid_layout(info, picks=None, n_col=None): y += 0.5 # calculate pos - pos = np.c_[x - 0.5 * width, y - 0.5 * height, - width * np.ones(size), height * np.ones(size)] + pos = np.c_[ + x - 0.5 * width, y - 0.5 * height, width * np.ones(size), height * np.ones(size) + ] box = (0, 1, 0, 1) - layout = Layout(box=box, pos=pos, names=names, kind='grid-misc', ids=ids) + layout = Layout(box=box, pos=pos, names=names, kind="grid-misc", ids=ids) return layout @fill_doc -def find_layout(info, ch_type=None, exclude='bads'): +def find_layout(info, ch_type=None, exclude="bads"): """Choose a layout based on the channels in the info 'chs' field. Parameters @@ -488,57 +498,70 @@ def find_layout(info, ch_type=None, exclude='bads'): layout : Layout instance | None None if layout not found. 
""" - _check_option('ch_type', ch_type, [None, 'mag', 'grad', 'meg', 'eeg', - 'csd']) - - (has_vv_mag, has_vv_grad, is_old_vv, has_4D_mag, ctf_other_types, - has_CTF_grad, n_kit_grads, has_any_meg, has_eeg_coils, - has_eeg_coils_and_meg, has_eeg_coils_only, - has_neuromag_122_grad, has_csd_coils) = _get_ch_info(info) + _check_option("ch_type", ch_type, [None, "mag", "grad", "meg", "eeg", "csd"]) + + ( + has_vv_mag, + has_vv_grad, + is_old_vv, + has_4D_mag, + ctf_other_types, + has_CTF_grad, + n_kit_grads, + has_any_meg, + has_eeg_coils, + has_eeg_coils_and_meg, + has_eeg_coils_only, + has_neuromag_122_grad, + has_csd_coils, + ) = _get_ch_info(info) has_vv_meg = has_vv_mag and has_vv_grad has_vv_only_mag = has_vv_mag and not has_vv_grad has_vv_only_grad = has_vv_grad and not has_vv_mag if ch_type == "meg" and not has_any_meg: - raise RuntimeError('No MEG channels present. Cannot find MEG layout.') + raise RuntimeError("No MEG channels present. Cannot find MEG layout.") if ch_type == "eeg" and not has_eeg_coils: - raise RuntimeError('No EEG channels present. Cannot find EEG layout.') + raise RuntimeError("No EEG channels present. Cannot find EEG layout.") layout_name = None - if ((has_vv_meg and ch_type is None) or - (any([has_vv_mag, has_vv_grad]) and ch_type == 'meg')): - layout_name = 'Vectorview-all' - elif has_vv_only_mag or (has_vv_meg and ch_type == 'mag'): - layout_name = 'Vectorview-mag' - elif has_vv_only_grad or (has_vv_meg and ch_type == 'grad'): - if info['ch_names'][0].endswith('X'): - layout_name = 'Vectorview-grad_norm' + if (has_vv_meg and ch_type is None) or ( + any([has_vv_mag, has_vv_grad]) and ch_type == "meg" + ): + layout_name = "Vectorview-all" + elif has_vv_only_mag or (has_vv_meg and ch_type == "mag"): + layout_name = "Vectorview-mag" + elif has_vv_only_grad or (has_vv_meg and ch_type == "grad"): + if info["ch_names"][0].endswith("X"): + layout_name = "Vectorview-grad_norm" else: - layout_name = 'Vectorview-grad' + layout_name = "Vectorview-grad" elif has_neuromag_122_grad: - layout_name = 'Neuromag_122' - elif ((has_eeg_coils_only and ch_type in [None, 'eeg']) or - (has_eeg_coils_and_meg and ch_type == 'eeg')): + layout_name = "Neuromag_122" + elif (has_eeg_coils_only and ch_type in [None, "eeg"]) or ( + has_eeg_coils_and_meg and ch_type == "eeg" + ): if not isinstance(info, (dict, Info)): - raise RuntimeError('Cannot make EEG layout, no measurement info ' - 'was passed to `find_layout`') + raise RuntimeError( + "Cannot make EEG layout, no measurement info " + "was passed to `find_layout`" + ) return make_eeg_layout(info, exclude=exclude) - elif has_csd_coils and ch_type in [None, 'csd']: + elif has_csd_coils and ch_type in [None, "csd"]: return make_eeg_layout(info, exclude=exclude, csd=True) elif has_4D_mag: - layout_name = 'magnesWH3600' + layout_name = "magnesWH3600" elif has_CTF_grad: - layout_name = 'CTF-275' + layout_name = "CTF-275" elif n_kit_grads > 0: layout_name = _find_kit_layout(info, n_kit_grads) # If no known layout is found, fall back on automatic layout if layout_name is None: - picks = _picks_to_idx(info, 'data', exclude=(), with_ref_meg=False) - ch_names = [info['ch_names'][pick] for pick in picks] + picks = _picks_to_idx(info, "data", exclude=(), with_ref_meg=False) + ch_names = [info["ch_names"][pick] for pick in picks] xy = _find_topomap_coords(info, picks=picks, ignore_overlap=True) - return generate_2d_layout(xy, ch_names=ch_names, name='custom', - normalize=True) + return generate_2d_layout(xy, ch_names=ch_names, name="custom", 
normalize=True) layout = read_layout(fname=layout_name) if not is_old_vv: @@ -547,8 +570,8 @@ def find_layout(info, ch_type=None, exclude='bads'): layout.names = _clean_names(layout.names, before_dash=True) # Apply mask for excluded channels. - if exclude == 'bads': - exclude = info['bads'] + if exclude == "bads": + exclude = info["bads"] idx = [ii for ii, name in enumerate(layout.names) if name not in exclude] layout.names = [layout.names[ii] for ii in idx] layout.pos = layout.pos[idx] @@ -572,34 +595,69 @@ def _find_kit_layout(info, n_grads): kit_layout : str | None String naming the detected KIT layout or ``None`` if layout is missing. """ - if info['kit_system_id'] is not None: + if info["kit_system_id"] is not None: # avoid circular import from ..io.kit.constants import KIT_LAYOUT - return KIT_LAYOUT.get(info['kit_system_id']) + + return KIT_LAYOUT.get(info["kit_system_id"]) elif n_grads == 160: - return 'KIT-160' + return "KIT-160" elif n_grads == 125: - return 'KIT-125' + return "KIT-125" elif n_grads > 157: - return 'KIT-AD' + return "KIT-AD" # channels which are on the left hemisphere for NY and right for UMD - test_chs = ('MEG 13', 'MEG 14', 'MEG 15', 'MEG 16', 'MEG 25', - 'MEG 26', 'MEG 27', 'MEG 28', 'MEG 29', 'MEG 30', - 'MEG 31', 'MEG 32', 'MEG 57', 'MEG 60', 'MEG 61', - 'MEG 62', 'MEG 63', 'MEG 64', 'MEG 73', 'MEG 90', - 'MEG 93', 'MEG 95', 'MEG 96', 'MEG 105', 'MEG 112', - 'MEG 120', 'MEG 121', 'MEG 122', 'MEG 123', 'MEG 124', - 'MEG 125', 'MEG 126', 'MEG 142', 'MEG 144', 'MEG 153', - 'MEG 154', 'MEG 155', 'MEG 156') - x = [ch['loc'][0] < 0 for ch in info['chs'] if ch['ch_name'] in test_chs] + test_chs = ( + "MEG 13", + "MEG 14", + "MEG 15", + "MEG 16", + "MEG 25", + "MEG 26", + "MEG 27", + "MEG 28", + "MEG 29", + "MEG 30", + "MEG 31", + "MEG 32", + "MEG 57", + "MEG 60", + "MEG 61", + "MEG 62", + "MEG 63", + "MEG 64", + "MEG 73", + "MEG 90", + "MEG 93", + "MEG 95", + "MEG 96", + "MEG 105", + "MEG 112", + "MEG 120", + "MEG 121", + "MEG 122", + "MEG 123", + "MEG 124", + "MEG 125", + "MEG 126", + "MEG 142", + "MEG 144", + "MEG 153", + "MEG 154", + "MEG 155", + "MEG 156", + ) + x = [ch["loc"][0] < 0 for ch in info["chs"] if ch["ch_name"] in test_chs] if np.all(x): - return 'KIT-157' # KIT-NY + return "KIT-157" # KIT-NY elif np.all(np.invert(x)): - raise NotImplementedError("Guessing sensor layout for legacy UMD " - "files is not implemented. Please convert " - "your files using MNE-Python 0.13 or " - "higher.") + raise NotImplementedError( + "Guessing sensor layout for legacy UMD " + "files is not implemented. Please convert " + "your files using MNE-Python 0.13 or " + "higher." + ) else: raise RuntimeError("KIT system could not be determined for data") @@ -660,8 +718,7 @@ def ydiff(a, b): if height is None: # Find all axes that could potentially overlap horizontally. hdist = pdist(points, xdiff) - candidates = [all_combinations[i] for i, d in enumerate(hdist) - if d < width] + candidates = [all_combinations[i] for i, d in enumerate(hdist) if d < width] if len(candidates) == 0: # No axes overlap, take all the height you want. @@ -674,8 +731,7 @@ def ydiff(a, b): elif width is None: # Find all axes that could potentially overlap vertically. vdist = pdist(points, ydiff) - candidates = [all_combinations[i] for i, d in enumerate(vdist) - if d < height] + candidates = [all_combinations[i] for i, d in enumerate(vdist) if d < height] if len(candidates) == 0: # No axes overlap, take all the width you want. 
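The find_layout changes above are formatting-only; behavior is unchanged. For orientation, a short usage sketch of this public entry point; the channel names and montage below are illustrative choices, not fixtures from the test suite:

    import mne

    # Three 10-20 EEG channels with positions taken from a built-in montage.
    info = mne.create_info(["Fz", "Cz", "Pz"], sfreq=1000.0, ch_types="eeg")
    info.set_montage("standard_1020")

    # With EEG coils only and no known MEG system, find_layout falls through
    # to make_eeg_layout() instead of returning a named built-in layout.
    layout = mne.channels.find_layout(info, ch_type="eeg")
    print(layout.names)  # ['Fz', 'Cz', 'Pz']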
@@ -693,8 +749,9 @@ def ydiff(a, b): @fill_doc -def _find_topomap_coords(info, picks, layout=None, ignore_overlap=False, - to_sphere=True, sphere=None): +def _find_topomap_coords( + info, picks, layout=None, ignore_overlap=False, to_sphere=True, sphere=None +): """Guess the E/MEG layout and return appropriate topomap coordinates. Parameters @@ -714,16 +771,20 @@ def _find_topomap_coords(info, picks, layout=None, ignore_overlap=False, coords : array, shape = (n_chs, 2) 2 dimensional coordinates for each sensor for a topomap plot. """ - picks = _picks_to_idx(info, picks, 'all', exclude=(), allow_empty=False) + picks = _picks_to_idx(info, picks, "all", exclude=(), allow_empty=False) if layout is not None: - chs = [info['chs'][i] for i in picks] - pos = [layout.pos[layout.names.index(ch['ch_name'])] for ch in chs] + chs = [info["chs"][i] for i in picks] + pos = [layout.pos[layout.names.index(ch["ch_name"])] for ch in chs] pos = np.asarray(pos) else: pos = _auto_topomap_coords( - info, picks, ignore_overlap=ignore_overlap, to_sphere=to_sphere, - sphere=sphere) + info, + picks, + ignore_overlap=ignore_overlap, + to_sphere=to_sphere, + sphere=sphere, + ) return pos @@ -756,50 +817,64 @@ def _auto_topomap_coords(info, picks, ignore_overlap, to_sphere, sphere): An array of positions of the 2 dimensional map. """ from scipy.spatial.distance import pdist, squareform + sphere = _check_sphere(sphere, info) - logger.debug(f'Generating coords using: {sphere}') + logger.debug(f"Generating coords using: {sphere}") - picks = _picks_to_idx(info, picks, 'all', exclude=(), allow_empty=False) - chs = [info['chs'][i] for i in picks] + picks = _picks_to_idx(info, picks, "all", exclude=(), allow_empty=False) + chs = [info["chs"][i] for i in picks] # Use channel locations if available - locs3d = np.array([ch['loc'][:3] for ch in chs]) + locs3d = np.array([ch["loc"][:3] for ch in chs]) # If electrode locations are not available, use digization points if not _check_ch_locs(info=info, picks=picks): - logging.warning('Did not find any electrode locations (in the info ' - 'object), will attempt to use digitization points ' - 'instead. However, if digitization points do not ' - 'correspond to the EEG electrodes, this will lead to ' - 'bad results. Please verify that the sensor locations ' - 'in the plot are accurate.') + logging.warning( + "Did not find any electrode locations (in the info " + "object), will attempt to use digitization points " + "instead. However, if digitization points do not " + "correspond to the EEG electrodes, this will lead to " + "bad results. Please verify that the sensor locations " + "in the plot are accurate." + ) # MEG/EOG/ECG sensors don't have digitization points; all requested # channels must be EEG for ch in chs: - if ch['kind'] != FIFF.FIFFV_EEG_CH: - raise ValueError("Cannot determine location of MEG/EOG/ECG " - "channels using digitization points.") - - eeg_ch_names = [ch['ch_name'] for ch in info['chs'] - if ch['kind'] == FIFF.FIFFV_EEG_CH] + if ch["kind"] != FIFF.FIFFV_EEG_CH: + raise ValueError( + "Cannot determine location of MEG/EOG/ECG " + "channels using digitization points." 
+ ) + + eeg_ch_names = [ + ch["ch_name"] for ch in info["chs"] if ch["kind"] == FIFF.FIFFV_EEG_CH + ] # Get EEG digitization points - if info['dig'] is None or len(info['dig']) == 0: - raise RuntimeError('No digitization points found.') - - locs3d = np.array([point['r'] for point in info['dig'] - if point['kind'] == FIFF.FIFFV_POINT_EEG]) + if info["dig"] is None or len(info["dig"]) == 0: + raise RuntimeError("No digitization points found.") + + locs3d = np.array( + [ + point["r"] + for point in info["dig"] + if point["kind"] == FIFF.FIFFV_POINT_EEG + ] + ) if len(locs3d) == 0: - raise RuntimeError('Did not find any digitization points of ' - 'kind FIFFV_POINT_EEG (%d) in the info.' - % FIFF.FIFFV_POINT_EEG) + raise RuntimeError( + "Did not find any digitization points of " + "kind FIFFV_POINT_EEG (%d) in the info." % FIFF.FIFFV_POINT_EEG + ) if len(locs3d) != len(eeg_ch_names): - raise ValueError("Number of EEG digitization points (%d) " - "doesn't match the number of EEG channels " - "(%d)" % (len(locs3d), len(eeg_ch_names))) + raise ValueError( + "Number of EEG digitization points (%d) " + "doesn't match the number of EEG channels " + "(%d)" % (len(locs3d), len(eeg_ch_names)) + ) # We no longer center digitization points on head origin, as we work # in head coordinates always @@ -807,22 +882,24 @@ def _auto_topomap_coords(info, picks, ignore_overlap, to_sphere, sphere): # Match the digitization points with the requested # channels. eeg_ch_locs = dict(zip(eeg_ch_names, locs3d)) - locs3d = np.array([eeg_ch_locs[ch['ch_name']] for ch in chs]) + locs3d = np.array([eeg_ch_locs[ch["ch_name"]] for ch in chs]) # Sometimes we can get nans - locs3d[~np.isfinite(locs3d)] = 0. + locs3d[~np.isfinite(locs3d)] = 0.0 # Duplicate points cause all kinds of trouble during visualization dist = pdist(locs3d) if len(locs3d) > 1 and np.min(dist) < 1e-10 and not ignore_overlap: problematic_electrodes = [ - chs[elec_i]['ch_name'] + chs[elec_i]["ch_name"] for elec_i in squareform(dist < 1e-10).any(axis=0).nonzero()[0] ] - raise ValueError('The following electrodes have overlapping positions,' - ' which causes problems during visualization:\n' + - ', '.join(problematic_electrodes)) + raise ValueError( + "The following electrodes have overlapping positions," + " which causes problems during visualization:\n" + + ", ".join(problematic_electrodes) + ) if to_sphere: # translate to sphere origin, transform/flatten Z, translate back @@ -831,7 +908,7 @@ def _auto_topomap_coords(info, picks, ignore_overlap, to_sphere, sphere): cart_coords = _cart_to_sph(locs3d) out = _pol_to_cart(cart_coords[:, 1:][:, ::-1]) # scale from radians to mm - out *= cart_coords[:, [0]] / (np.pi / 2.) + out *= cart_coords[:, [0]] / (np.pi / 2.0) out += sphere[:2] else: out = _pol_to_cart(_cart_to_sph(locs3d)) @@ -862,18 +939,19 @@ def _topo_to_sphere(pos, eegs): xs += 0.5 - np.mean(xs[eegs]) # Center the points ys += 0.5 - np.mean(ys[eegs]) - xs = xs * 2. - 1. # Values ranging from -1 to 1 - ys = ys * 2. - 1. + xs = xs * 2.0 - 1.0 # Values ranging from -1 to 1 + ys = ys * 2.0 - 1.0 - rs = np.clip(np.sqrt(xs ** 2 + ys ** 2), 0., 1.) + rs = np.clip(np.sqrt(xs**2 + ys**2), 0.0, 1.0) alphas = np.arccos(rs) zs = np.sin(alphas) return np.column_stack([xs, ys, zs]) @fill_doc -def _pair_grad_sensors(info, layout=None, topomap_coords=True, exclude='bads', - raise_error=True): +def _pair_grad_sensors( + info, layout=None, topomap_coords=True, exclude="bads", raise_error=True +): """Find the picks for pairing grad channels. 
Parameters @@ -901,18 +979,18 @@ def _pair_grad_sensors(info, layout=None, topomap_coords=True, exclude='bads', """ # find all complete pairs of grad channels pairs = defaultdict(list) - grad_picks = pick_types(info, meg='grad', ref_meg=False, exclude=exclude) + grad_picks = pick_types(info, meg="grad", ref_meg=False, exclude=exclude) _, has_vv_grad, *_, has_neuromag_122_grad, _ = _get_ch_info(info) for i in grad_picks: - ch = info['chs'][i] - name = ch['ch_name'] - if has_vv_grad and name.startswith('MEG'): - if name.endswith(('2', '3')): + ch = info["chs"][i] + name = ch["ch_name"] + if has_vv_grad and name.startswith("MEG"): + if name.endswith(("2", "3")): key = name[-4:-1] pairs[key].append(ch) - if has_neuromag_122_grad and name.startswith('MEG'): + if has_neuromag_122_grad and name.startswith("MEG"): key = (int(name[-3:]) - 1) // 2 pairs[key].append(ch) @@ -926,13 +1004,12 @@ def _pair_grad_sensors(info, layout=None, topomap_coords=True, exclude='bads', # find the picks corresponding to the grad channels grad_chs = sum(pairs, []) - ch_names = info['ch_names'] - picks = [ch_names.index(c['ch_name']) for c in grad_chs] + ch_names = info["ch_names"] + picks = [ch_names.index(c["ch_name"]) for c in grad_chs] if topomap_coords: shape = (len(pairs), 2, -1) - coords = (_find_topomap_coords(info, picks, layout) - .reshape(shape).mean(axis=1)) + coords = _find_topomap_coords(info, picks, layout).reshape(shape).mean(axis=1) return picks, coords else: return picks @@ -955,8 +1032,8 @@ def _pair_grad_sensors_ch_names_vectorview(ch_names): """ pairs = defaultdict(list) for i, name in enumerate(ch_names): - if name.startswith('MEG'): - if name.endswith(('2', '3')): + if name.startswith("MEG"): + if name.endswith(("2", "3")): key = name[-4:-1] pairs[key].append(i) @@ -983,7 +1060,7 @@ def _pair_grad_sensors_ch_names_neuromag122(ch_names): """ pairs = defaultdict(list) for i, name in enumerate(ch_names): - if name.startswith('MEG'): + if name.startswith("MEG"): key = (int(name[-3:]) - 1) // 2 pairs[key].append(i) @@ -993,7 +1070,7 @@ def _pair_grad_sensors_ch_names_neuromag122(ch_names): return grad_chs -def _merge_ch_data(data, ch_type, names, method='rms'): +def _merge_ch_data(data, ch_type, names, method="rms"): """Merge data from channel pairs. Parameters @@ -1014,7 +1091,7 @@ def _merge_ch_data(data, ch_type, names, method='rms'): names : list List of channel names. """ - if ch_type == 'grad': + if ch_type == "grad": data = _merge_grad_data(data, method) else: assert ch_type in _FNIRS_CH_TYPES_SPLIT @@ -1022,7 +1099,7 @@ def _merge_ch_data(data, ch_type, names, method='rms'): return data, names -def _merge_grad_data(data, method='rms'): +def _merge_grad_data(data, method="rms"): """Merge data from channel pairs using the RMS or mean. Parameters @@ -1038,10 +1115,10 @@ def _merge_grad_data(data, method='rms'): The root mean square or mean for each pair. """ data, orig_shape = data.reshape((len(data) // 2, 2, -1)), data.shape - if method == 'mean': + if method == "mean": data = np.mean(data, axis=1) - elif method == 'rms': - data = np.sqrt(np.sum(data ** 2, axis=1) / 2) + elif method == "rms": + data = np.sqrt(np.sum(data**2, axis=1) / 2) else: raise ValueError('method must be "rms" or "mean", got %s.' 
% method) return data.reshape(data.shape[:1] + orig_shape[1:]) @@ -1070,7 +1147,7 @@ def _merge_nirs_data(data, merged_names): """ to_remove = np.empty(0, dtype=np.int32) for idx, ch in enumerate(merged_names): - if 'x' in ch: + if "x" in ch: indices = np.empty(0, dtype=np.int32) channels = ch.split("x") for sub_ch in channels[1:]: @@ -1084,9 +1161,17 @@ def _merge_nirs_data(data, merged_names): return data, merged_names -def generate_2d_layout(xy, w=.07, h=.05, pad=.02, ch_names=None, - ch_indices=None, name='ecog', bg_image=None, - normalize=True): +def generate_2d_layout( + xy, + w=0.07, + h=0.05, + pad=0.02, + ch_names=None, + ch_indices=None, + name="ecog", + bg_image=None, + normalize=True, +): """Generate a custom 2D layout from xy points. Generates a 2-D layout for plotting with plot_topo methods and @@ -1137,15 +1222,16 @@ def generate_2d_layout(xy, w=.07, h=.05, pad=.02, ch_names=None, .. versionadded:: 0.9.0 """ import matplotlib.pyplot as plt + if ch_indices is None: ch_indices = np.arange(xy.shape[0]) if ch_names is None: - ch_names = ['{}'.format(i) for i in ch_indices] + ch_names = ["{}".format(i) for i in ch_indices] if len(ch_names) != len(ch_indices): - raise ValueError('# channel names and indices must be equal') + raise ValueError("# channel names and indices must be equal") if len(ch_names) != len(xy): - raise ValueError('# channel names and xy vals must be equal') + raise ValueError("# channel names and xy vals must be equal") x, y = xy.copy().astype(float).T @@ -1159,7 +1245,7 @@ def generate_2d_layout(xy, w=.07, h=.05, pad=.02, ch_names=None, # Normalize x and y by their maxes for i_dim in [x, y]: i_dim -= i_dim.min(0) - i_dim /= (i_dim.max(0) - i_dim.min(0)) + i_dim /= i_dim.max(0) - i_dim.min(0) # Create box and pos variable box = _box_size(np.vstack([x, y]).T, padding=pad) diff --git a/mne/channels/montage.py b/mne/channels/montage.py index 178557bc520..11d48a099c8 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -22,22 +22,46 @@ from ..defaults import HEAD_SIZE_DEFAULT from .._freesurfer import get_mni_fiducials from ..viz import plot_montage -from ..transforms import (apply_trans, get_ras_to_neuromag_trans, _sph_to_cart, - _topo_to_sph, _frame_to_str, Transform, - _verbose_frames, _fit_matched_points, - _quat_to_affine, _ensure_trans) -from ..io._digitization import (_count_points_by_type, _ensure_fiducials_head, - _get_dig_eeg, _make_dig_points, write_dig, - _read_dig_fif, _format_dig_points, - _get_fid_coords, _coord_frame_const, - _get_data_as_dict_from_dig) +from ..transforms import ( + apply_trans, + get_ras_to_neuromag_trans, + _sph_to_cart, + _topo_to_sph, + _frame_to_str, + Transform, + _verbose_frames, + _fit_matched_points, + _quat_to_affine, + _ensure_trans, +) +from ..io._digitization import ( + _count_points_by_type, + _ensure_fiducials_head, + _get_dig_eeg, + _make_dig_points, + write_dig, + _read_dig_fif, + _format_dig_points, + _get_fid_coords, + _coord_frame_const, + _get_data_as_dict_from_dig, +) from ..io.meas_info import create_info from ..io.open import fiff_open from ..io.pick import pick_types, _picks_to_idx, channel_type from ..io.constants import FIFF, CHANNEL_LOC_ALIASES -from ..utils import (warn, copy_function_doc_to_method_doc, _pl, verbose, - _check_option, _validate_type, _check_fname, _on_missing, - fill_doc, _docdict) +from ..utils import ( + warn, + copy_function_doc_to_method_doc, + _pl, + verbose, + _check_option, + _validate_type, + _check_fname, + _on_missing, + fill_doc, + _docdict, +) from 
._dig_montage_utils import _read_dig_montage_egi from ._dig_montage_utils import _parse_brainvision_dig_montage @@ -51,132 +75,133 @@ class _BuiltinStandardMontage: _BUILTIN_STANDARD_MONTAGES = [ _BuiltinStandardMontage( - name='standard_1005', - description='Electrodes are named and positioned according to the ' - 'international 10-05 system (343+3 locations)', + name="standard_1005", + description="Electrodes are named and positioned according to the " + "international 10-05 system (343+3 locations)", ), _BuiltinStandardMontage( - name='standard_1020', - description='Electrodes are named and positioned according to the ' - 'international 10-20 system (94+3 locations)', + name="standard_1020", + description="Electrodes are named and positioned according to the " + "international 10-20 system (94+3 locations)", ), _BuiltinStandardMontage( - name='standard_alphabetic', - description='Electrodes are named with LETTER-NUMBER combinations ' - '(A1, B2, F4, …) (65+3 locations)', + name="standard_alphabetic", + description="Electrodes are named with LETTER-NUMBER combinations " + "(A1, B2, F4, …) (65+3 locations)", ), _BuiltinStandardMontage( - name='standard_postfixed', - description='Electrodes are named according to the international ' - '10-20 system using postfixes for intermediate positions ' - '(100+3 locations)', + name="standard_postfixed", + description="Electrodes are named according to the international " + "10-20 system using postfixes for intermediate positions " + "(100+3 locations)", ), _BuiltinStandardMontage( - name='standard_prefixed', - description='Electrodes are named according to the international ' - '10-20 system using prefixes for intermediate positions ' - '(74+3 locations)', + name="standard_prefixed", + description="Electrodes are named according to the international " + "10-20 system using prefixes for intermediate positions " + "(74+3 locations)", ), _BuiltinStandardMontage( - name='standard_primed', + name="standard_primed", description="Electrodes are named according to the international " - "10-20 system using prime marks (' and '') for " - "intermediate positions (100+3 locations)", + "10-20 system using prime marks (' and '') for " + "intermediate positions (100+3 locations)", ), _BuiltinStandardMontage( - name='biosemi16', - description='BioSemi cap with 16 electrodes (16+3 locations)', + name="biosemi16", + description="BioSemi cap with 16 electrodes (16+3 locations)", ), _BuiltinStandardMontage( - name='biosemi32', - description='BioSemi cap with 32 electrodes (32+3 locations)', + name="biosemi32", + description="BioSemi cap with 32 electrodes (32+3 locations)", ), _BuiltinStandardMontage( - name='biosemi64', - description='BioSemi cap with 64 electrodes (64+3 locations)', + name="biosemi64", + description="BioSemi cap with 64 electrodes (64+3 locations)", ), _BuiltinStandardMontage( - name='biosemi128', - description='BioSemi cap with 128 electrodes (128+3 locations)', + name="biosemi128", + description="BioSemi cap with 128 electrodes (128+3 locations)", ), _BuiltinStandardMontage( - name='biosemi160', - description='BioSemi cap with 160 electrodes (160+3 locations)', + name="biosemi160", + description="BioSemi cap with 160 electrodes (160+3 locations)", ), _BuiltinStandardMontage( - name='biosemi256', - description='BioSemi cap with 256 electrodes (256+3 locations)', + name="biosemi256", + description="BioSemi cap with 256 electrodes (256+3 locations)", ), _BuiltinStandardMontage( - name='easycap-M1', - description='EasyCap with 10-05 electrode names (74 
locations)', + name="easycap-M1", + description="EasyCap with 10-05 electrode names (74 locations)", ), _BuiltinStandardMontage( - name='easycap-M10', - description='EasyCap with numbered electrodes (61 locations)', + name="easycap-M10", + description="EasyCap with numbered electrodes (61 locations)", ), _BuiltinStandardMontage( - name='EGI_256', - description='Geodesic Sensor Net (256 locations)', + name="EGI_256", + description="Geodesic Sensor Net (256 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-32', - description='HydroCel Geodesic Sensor Net and Cz (33+3 locations)', + name="GSN-HydroCel-32", + description="HydroCel Geodesic Sensor Net and Cz (33+3 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-64_1.0', - description='HydroCel Geodesic Sensor Net (64+3 locations)', + name="GSN-HydroCel-64_1.0", + description="HydroCel Geodesic Sensor Net (64+3 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-65_1.0', - description='HydroCel Geodesic Sensor Net and Cz (65+3 locations)', + name="GSN-HydroCel-65_1.0", + description="HydroCel Geodesic Sensor Net and Cz (65+3 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-128', - description='HydroCel Geodesic Sensor Net (128+3 locations)', + name="GSN-HydroCel-128", + description="HydroCel Geodesic Sensor Net (128+3 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-129', - description='HydroCel Geodesic Sensor Net and Cz (129+3 locations)', + name="GSN-HydroCel-129", + description="HydroCel Geodesic Sensor Net and Cz (129+3 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-256', - description='HydroCel Geodesic Sensor Net (256+3 locations)', + name="GSN-HydroCel-256", + description="HydroCel Geodesic Sensor Net (256+3 locations)", ), _BuiltinStandardMontage( - name='GSN-HydroCel-257', - description='HydroCel Geodesic Sensor Net and Cz (257+3 locations)', + name="GSN-HydroCel-257", + description="HydroCel Geodesic Sensor Net and Cz (257+3 locations)", ), _BuiltinStandardMontage( - name='mgh60', - description='The (older) 60-channel cap used at MGH (60+3 locations)', + name="mgh60", + description="The (older) 60-channel cap used at MGH (60+3 locations)", ), _BuiltinStandardMontage( - name='mgh70', - description='The (newer) 70-channel BrainVision cap used at MGH ' - '(70+3 locations)', + name="mgh70", + description="The (newer) 70-channel BrainVision cap used at MGH " + "(70+3 locations)", ), _BuiltinStandardMontage( - name='artinis-octamon', - description='Artinis OctaMon fNIRS (8 sources, 2 detectors)', + name="artinis-octamon", + description="Artinis OctaMon fNIRS (8 sources, 2 detectors)", ), _BuiltinStandardMontage( - name='artinis-brite23', - description='Artinis Brite23 fNIRS (11 sources, 7 detectors)', + name="artinis-brite23", + description="Artinis Brite23 fNIRS (11 sources, 7 detectors)", ), _BuiltinStandardMontage( - name='brainproducts-RNP-BA-128', - description='Brain Products with 10-10 electrode names (128 channels)', - ) + name="brainproducts-RNP-BA-128", + description="Brain Products with 10-10 electrode names (128 channels)", + ), ] def _check_get_coord_frame(dig): - dig_coord_frames = sorted(set(d['coord_frame'] for d in dig)) + dig_coord_frames = sorted(set(d["coord_frame"] for d in dig)) if len(dig_coord_frames) != 1: raise RuntimeError( - 'Only a single coordinate frame in dig is supported, got ' - f'{dig_coord_frames}') + "Only a single coordinate frame in dig is supported, got " + f"{dig_coord_frames}" + ) return 
_frame_to_str[dig_coord_frames.pop()] if dig_coord_frames else None @@ -205,15 +230,20 @@ def get_builtin_montages(*, descriptions=False): If ``descriptions=True``, a list of tuples ``(name, description)``. """ if descriptions: - return [ - (m.name, m.description) for m in _BUILTIN_STANDARD_MONTAGES - ] + return [(m.name, m.description) for m in _BUILTIN_STANDARD_MONTAGES] else: return [m.name for m in _BUILTIN_STANDARD_MONTAGES] -def make_dig_montage(ch_pos=None, nasion=None, lpa=None, rpa=None, - hsp=None, hpi=None, coord_frame='unknown'): +def make_dig_montage( + ch_pos=None, + nasion=None, + lpa=None, + rpa=None, + hsp=None, + hpi=None, + coord_frame="unknown", +): r"""Make montage from arrays. Parameters @@ -263,14 +293,19 @@ def make_dig_montage(ch_pos=None, nasion=None, lpa=None, rpa=None, read_dig_localite read_dig_polhemus_isotrak """ - _validate_type(ch_pos, (dict, None), 'ch_pos') + _validate_type(ch_pos, (dict, None), "ch_pos") if ch_pos is None: ch_names = None else: ch_names = list(ch_pos) dig = _make_dig_points( - nasion=nasion, lpa=lpa, rpa=rpa, hpi=hpi, extra_points=hsp, - dig_ch_pos=ch_pos, coord_frame=coord_frame + nasion=nasion, + lpa=lpa, + rpa=rpa, + hpi=hpi, + extra_points=hsp, + dig_ch_pos=ch_pos, + coord_frame=coord_frame, ) return DigMontage(dig=dig, ch_names=ch_names) @@ -308,13 +343,13 @@ class DigMontage: def __init__(self, *, dig=None, ch_names=None): dig = list() if dig is None else dig - _validate_type(item=dig, types=list, item_name='dig') + _validate_type(item=dig, types=list, item_name="dig") ch_names = list() if ch_names is None else ch_names - n_eeg = sum([1 for d in dig if d['kind'] == FIFF.FIFFV_POINT_EEG]) + n_eeg = sum([1 for d in dig if d["kind"] == FIFF.FIFFV_POINT_EEG]) if n_eeg != len(ch_names): raise ValueError( - 'The number of EEG channels (%d) does not match the number' - ' of channel names provided (%d)' % (n_eeg, len(ch_names)) + "The number of EEG channels (%d) does not match the number" + " of channel names provided (%d)" % (n_eeg, len(ch_names)) ) self.dig = dig @@ -323,15 +358,32 @@ def __init__(self, *, dig=None, ch_names=None): def __repr__(self): """Return string representation.""" n_points = _count_points_by_type(self.dig) - return ('').format(**n_points) + return ( + "" + ).format(**n_points) @copy_function_doc_to_method_doc(plot_montage) - def plot(self, scale_factor=20, show_names=True, kind='topomap', show=True, - sphere=None, *, axes=None, verbose=None): - return plot_montage(self, scale_factor=scale_factor, - show_names=show_names, kind=kind, show=show, - sphere=sphere, axes=axes) + def plot( + self, + scale_factor=20, + show_names=True, + kind="topomap", + show=True, + sphere=None, + *, + axes=None, + verbose=None, + ): + return plot_montage( + self, + scale_factor=scale_factor, + show_names=show_names, + kind=kind, + show=show, + sphere=sphere, + axes=axes, + ) @fill_doc def rename_channels(self, mapping, allow_duplicates=False): @@ -347,9 +399,10 @@ def rename_channels(self, mapping, allow_duplicates=False): The instance. Operates in-place. 
""" from .channels import rename_channels - temp_info = create_info(list(self._get_ch_pos()), 1000., 'eeg') + + temp_info = create_info(list(self._get_ch_pos()), 1000.0, "eeg") rename_channels(temp_info, mapping, allow_duplicates) - self.ch_names = temp_info['ch_names'] + self.ch_names = temp_info["ch_names"] @verbose def save(self, fname, *, overwrite=False, verbose=None): @@ -374,20 +427,19 @@ def __iadd__(self, other): and if fiducials are present they should share the same coordinate system and location values. """ + def is_fid_defined(fid): - return not ( - fid.nasion is None and fid.lpa is None and fid.rpa is None - ) + return not (fid.nasion is None and fid.lpa is None and fid.rpa is None) # Check for none duplicated ch_names ch_names_intersection = set(self.ch_names).intersection(other.ch_names) if ch_names_intersection: - raise RuntimeError(( - "Cannot add two DigMontage objects if they contain duplicated" - " channel names. Duplicated channel(s) found: {}." - ).format( - ', '.join(['%r' % v for v in sorted(ch_names_intersection)]) - )) + raise RuntimeError( + ( + "Cannot add two DigMontage objects if they contain duplicated" + " channel names. Duplicated channel(s) found: {}." + ).format(", ".join(["%r" % v for v in sorted(ch_names_intersection)])) + ) # Check for unique matching fiducials self_fid, self_coord = _get_fid_coords(self.dig) @@ -395,20 +447,24 @@ def is_fid_defined(fid): if is_fid_defined(self_fid) and is_fid_defined(other_fid): if self_coord != other_coord: - raise RuntimeError('Cannot add two DigMontage objects if ' - 'fiducial locations are not in the same ' - 'coordinate system.') + raise RuntimeError( + "Cannot add two DigMontage objects if " + "fiducial locations are not in the same " + "coordinate system." + ) for kk in self_fid: if not np.array_equal(self_fid[kk], other_fid[kk]): - raise RuntimeError('Cannot add two DigMontage objects if ' - 'fiducial locations do not match ' - '(%s)' % kk) + raise RuntimeError( + "Cannot add two DigMontage objects if " + "fiducial locations do not match " + "(%s)" % kk + ) # keep self self.dig = _format_dig_points( - self.dig + [d for d in other.dig - if d['kind'] != FIFF.FIFFV_POINT_CARDINAL] + self.dig + + [d for d in other.dig if d["kind"] != FIFF.FIFFV_POINT_CARDINAL] ) else: self.dig = _format_dig_points(self.dig + other.dig) @@ -442,13 +498,13 @@ def __eq__(self, other): return self.dig == other.dig and self.ch_names == other.ch_names def _get_ch_pos(self): - pos = [d['r'] for d in _get_dig_eeg(self.dig)] + pos = [d["r"] for d in _get_dig_eeg(self.dig)] assert len(self.ch_names) == len(pos) return OrderedDict(zip(self.ch_names, pos)) def _get_dig_names(self): NAMED_KIND = (FIFF.FIFFV_POINT_EEG,) - is_eeg = np.array([d['kind'] in NAMED_KIND for d in self.dig]) + is_eeg = np.array([d["kind"] in NAMED_KIND for d in self.dig]) assert len(self.ch_names) == is_eeg.sum() dig_names = [None] * len(self.dig) for ch_name_idx, dig_idx in enumerate(np.where(is_eeg)[0]): @@ -509,16 +565,15 @@ def apply_trans(self, trans, verbose=None): The transformation matrix to be applied. 
%(verbose)s """ - _validate_type(trans, Transform, 'trans') - coord_frame = self.get_positions()['coord_frame'] - trans = _ensure_trans(trans, fro=coord_frame, to=trans['to']) + _validate_type(trans, Transform, "trans") + coord_frame = self.get_positions()["coord_frame"] + trans = _ensure_trans(trans, fro=coord_frame, to=trans["to"]) for d in self.dig: - d['r'] = apply_trans(trans, d['r']) - d['coord_frame'] = trans['to'] + d["r"] = apply_trans(trans, d["r"]) + d["coord_frame"] = trans["to"] @verbose - def add_estimated_fiducials(self, subject, subjects_dir=None, - verbose=None): + def add_estimated_fiducials(self, subject, subjects_dir=None, verbose=None): """Estimate fiducials based on FreeSurfer ``fsaverage`` subject. This takes a montage with the ``mri`` coordinate frame, @@ -558,8 +613,9 @@ def add_estimated_fiducials(self, subject, subjects_dir=None, if montage_bunch.coord_frame != FIFF.FIFFV_COORD_MRI: raise RuntimeError( f'Montage should be in the "mri" coordinate frame ' - f'to use `add_estimated_fiducials`. The current coordinate ' - f'frame is {montage_bunch.coord_frame}') + f"to use `add_estimated_fiducials`. The current coordinate " + f"frame is {montage_bunch.coord_frame}" + ) # estimate LPA, nasion, RPA from FreeSurfer fsaverage fids_mri = list(get_mni_fiducials(subject, subjects_dir)) @@ -598,14 +654,15 @@ def add_mni_fiducials(self, subjects_dir=None, verbose=None): if montage_bunch.coord_frame != FIFF.FIFFV_MNE_COORD_MNI_TAL: raise RuntimeError( f'Montage should be in the "mni_tal" coordinate frame ' - f'to use `add_estimated_fiducials`. The current coordinate ' - f'frame is {montage_bunch.coord_frame}') + f"to use `add_estimated_fiducials`. The current coordinate " + f"frame is {montage_bunch.coord_frame}" + ) - fids_mni = get_mni_fiducials('fsaverage', subjects_dir) + fids_mni = get_mni_fiducials("fsaverage", subjects_dir) for fid in fids_mni: # "mri" and "mni_tal" are equivalent for fsaverage - assert fid['coord_frame'] == FIFF.FIFFV_COORD_MRI - fid['coord_frame'] = FIFF.FIFFV_MNE_COORD_MNI_TAL + assert fid["coord_frame"] == FIFF.FIFFV_COORD_MRI + fid["coord_frame"] = FIFF.FIFFV_MNE_COORD_MNI_TAL self.dig = fids_mni + self.dig return self @@ -632,7 +689,7 @@ def remove_fiducials(self, verbose=None): should not be changed by removing fiducials. """ for d in self.dig.copy(): - if d['kind'] == FIFF.FIFFV_POINT_CARDINAL: + if d["kind"] == FIFF.FIFFV_POINT_CARDINAL: self.dig.remove(d) return self @@ -641,7 +698,7 @@ def remove_fiducials(self, verbose=None): def _check_unit_and_get_scaling(unit): - _check_option('unit', unit, sorted(VALID_SCALES.keys())) + _check_option("unit", unit, sorted(VALID_SCALES.keys())) return VALID_SCALES[unit] @@ -677,11 +734,11 @@ def transform_to_head(montage): # Get fiducial points and their coord_frame native_head_t = compute_native_head_t(montage) montage = montage.copy() # to avoid inplace modification - if native_head_t['from'] != FIFF.FIFFV_COORD_HEAD: + if native_head_t["from"] != FIFF.FIFFV_COORD_HEAD: for d in montage.dig: - if d['coord_frame'] == native_head_t['from']: - d['r'] = apply_trans(native_head_t, d['r']) - d['coord_frame'] = FIFF.FIFFV_COORD_HEAD + if d["coord_frame"] == native_head_t["from"]: + d["r"] = apply_trans(native_head_t, d["r"]) + d["coord_frame"] = FIFF.FIFFV_COORD_HEAD _ensure_fiducials_head(montage.dig) return montage @@ -722,9 +779,10 @@ def read_dig_dat(fname): a plain text editor. 
""" from ._standard_montage_utils import _check_dupes_odict - fname = _check_fname(fname, overwrite='read', must_exist=True) - with open(fname, 'r') as fid: + fname = _check_fname(fname, overwrite="read", must_exist=True) + + with open(fname, "r") as fid: lines = fid.readlines() ch_names, poss = list(), list() @@ -736,16 +794,17 @@ def read_dig_dat(fname): elif len(items) != 5: raise ValueError( "Error reading %s, line %s has unexpected number of entries:\n" - "%s" % (fname, i, line.rstrip())) + "%s" % (fname, i, line.rstrip()) + ) num = items[1] - if num == '67': + if num == "67": continue # centroid pos = np.array([float(item) for item in items[2:]]) - if num == '78': + if num == "78": nasion = pos - elif num == '76': + elif num == "76": lpa = pos - elif num == '82': + elif num == "82": rpa = pos else: ch_names.append(items[0]) @@ -782,7 +841,7 @@ def read_dig_fif(fname): read_dig_localite make_dig_montage """ - _check_fname(fname, overwrite='read', must_exist=True) + _check_fname(fname, overwrite="read", must_exist=True) # Load the dig data f, tree = fiff_open(fname)[:2] with f as fid: @@ -790,14 +849,14 @@ def read_dig_fif(fname): ch_names = [] for d in dig: - if d['kind'] == FIFF.FIFFV_POINT_EEG: - ch_names.append('EEG%03d' % d['ident']) + if d["kind"] == FIFF.FIFFV_POINT_EEG: + ch_names.append("EEG%03d" % d["ident"]) montage = DigMontage(dig=dig, ch_names=ch_names) return montage -def read_dig_hpts(fname, unit='mm'): +def read_dig_hpts(fname, unit="mm"): """Read historical ``.hpts`` MNE-C files. Parameters @@ -867,26 +926,27 @@ def read_dig_hpts(fname, unit='mm'): """ from ._standard_montage_utils import _str_names, _str - fname = _check_fname(fname, overwrite='read', must_exist=True) + fname = _check_fname(fname, overwrite="read", must_exist=True) _scale = _check_unit_and_get_scaling(unit) - out = np.genfromtxt(fname, comments='#', - dtype=(_str, _str, 'f8', 'f8', 'f8')) - kind, label = _str_names(out['f0']), _str_names(out['f1']) + out = np.genfromtxt(fname, comments="#", dtype=(_str, _str, "f8", "f8", "f8")) + kind, label = _str_names(out["f0"]), _str_names(out["f1"]) kind = [k.lower() for k in kind] - xyz = np.array([out['f%d' % ii] for ii in range(2, 5)]).T + xyz = np.array([out["f%d" % ii] for ii in range(2, 5)]).T xyz *= _scale del _scale - fid_idx_to_label = {'1': 'lpa', '2': 'nasion', '3': 'rpa'} - fid = {fid_idx_to_label[label[ii]]: this_xyz - for ii, this_xyz in enumerate(xyz) if kind[ii] == 'cardinal'} - ch_pos = {label[ii]: this_xyz - for ii, this_xyz in enumerate(xyz) if kind[ii] == 'eeg'} - hpi = np.array([this_xyz for ii, this_xyz in enumerate(xyz) - if kind[ii] == 'hpi']) + fid_idx_to_label = {"1": "lpa", "2": "nasion", "3": "rpa"} + fid = { + fid_idx_to_label[label[ii]]: this_xyz + for ii, this_xyz in enumerate(xyz) + if kind[ii] == "cardinal" + } + ch_pos = { + label[ii]: this_xyz for ii, this_xyz in enumerate(xyz) if kind[ii] == "eeg" + } + hpi = np.array([this_xyz for ii, this_xyz in enumerate(xyz) if kind[ii] == "hpi"]) hpi.shape = (-1, 3) # in case it's empty - hsp = np.array([this_xyz for ii, this_xyz in enumerate(xyz) - if kind[ii] == 'extra']) + hsp = np.array([this_xyz for ii, this_xyz in enumerate(xyz) if kind[ii] == "extra"]) hsp.shape = (-1, 3) # in case it's empty return make_dig_montage(ch_pos=ch_pos, **fid, hpi=hpi, hsp=hsp) @@ -915,12 +975,10 @@ def read_dig_egi(fname): read_dig_polhemus_isotrak make_dig_montage """ - _check_fname(fname, overwrite='read', must_exist=True) + _check_fname(fname, overwrite="read", must_exist=True) data = 
_read_dig_montage_egi( - fname=fname, - _scaling=1., - _all_data_kwargs_are_none=True + fname=fname, _scaling=1.0, _all_data_kwargs_are_none=True ) return make_dig_montage(**data) @@ -950,7 +1008,7 @@ def read_dig_captrak(fname): read_dig_polhemus_isotrak make_dig_montage """ - _check_fname(fname, overwrite='read', must_exist=True) + _check_fname(fname, overwrite="read", must_exist=True) data = _parse_brainvision_dig_montage(fname, scale=1e-3) return make_dig_montage(**data) @@ -1004,7 +1062,7 @@ def read_dig_localite(fname, nasion=None, lpa=None, rpa=None): def _get_montage_in_head(montage): - coords = set([d['coord_frame'] for d in montage.dig]) + coords = set([d["coord_frame"] for d in montage.dig]) montage = montage.copy() if len(coords) == 1 and coords.pop() == FIFF.FIFFV_COORD_HEAD: _ensure_fiducials_head(montage.dig) @@ -1023,33 +1081,33 @@ def _set_montage_fnirs(info, montage): place. """ from ..preprocessing.nirs import _validate_nirs_info + # Validate that the fNIRS info is correctly formatted picks = _validate_nirs_info(info) # Modify info['chs'][#]['loc'] in place num_ficiduals = len(montage.dig) - len(montage.ch_names) for ch_idx in picks: - ch = info['chs'][ch_idx]['ch_name'] - source, detector = ch.split(' ')[0].split('_') - source_pos = montage.dig[montage.ch_names.index(source) - + num_ficiduals]['r'] - detector_pos = montage.dig[montage.ch_names.index(detector) - + num_ficiduals]['r'] - - info['chs'][ch_idx]['loc'][3:6] = source_pos - info['chs'][ch_idx]['loc'][6:9] = detector_pos + ch = info["chs"][ch_idx]["ch_name"] + source, detector = ch.split(" ")[0].split("_") + source_pos = montage.dig[montage.ch_names.index(source) + num_ficiduals]["r"] + detector_pos = montage.dig[montage.ch_names.index(detector) + num_ficiduals][ + "r" + ] + + info["chs"][ch_idx]["loc"][3:6] = source_pos + info["chs"][ch_idx]["loc"][6:9] = detector_pos midpoint = (source_pos + detector_pos) / 2 - info['chs'][ch_idx]['loc'][:3] = midpoint - info['chs'][ch_idx]['coord_frame'] = FIFF.FIFFV_COORD_HEAD + info["chs"][ch_idx]["loc"][:3] = midpoint + info["chs"][ch_idx]["coord_frame"] = FIFF.FIFFV_COORD_HEAD # Modify info['dig'] in place with info._unlock(): - info['dig'] = montage.dig + info["dig"] = montage.dig @fill_doc -def _set_montage(info, montage, match_case=True, match_alias=False, - on_missing='raise'): +def _set_montage(info, montage, match_case=True, match_alias=False, on_missing="raise"): """Apply montage to data. With a DigMontage, this function will replace the digitizer info with @@ -1070,19 +1128,20 @@ def _set_montage(info, montage, match_case=True, match_alias=False, ----- This function will change the info variable in place. 
""" - _validate_type(montage, (DigMontage, None, str), 'montage') + _validate_type(montage, (DigMontage, None, str), "montage") if montage is None: # Next line modifies info['dig'] in place with info._unlock(): - info['dig'] = None - for ch in info['chs']: + info["dig"] = None + for ch in info["chs"]: # Next line modifies info['chs'][#]['loc'] in place - ch['loc'] = np.full(12, np.nan) + ch["loc"] = np.full(12, np.nan) return if isinstance(montage, str): # load builtin montage _check_option( - parameter='montage', value=montage, - allowed_values=[m.name for m in _BUILTIN_STANDARD_MONTAGES] + parameter="montage", + value=montage, + allowed_values=[m.name for m in _BUILTIN_STANDARD_MONTAGES], ) montage = make_standard_montage(montage) @@ -1100,66 +1159,72 @@ def _backcompat_value(pos, ref_pos): # only get the eeg, seeg, dbs, ecog channels picks = pick_types( - info, meg=False, eeg=True, seeg=True, dbs=True, ecog=True, - exclude=()) - non_picks = np.setdiff1d(np.arange(info['nchan']), picks) + info, meg=False, eeg=True, seeg=True, dbs=True, ecog=True, exclude=() + ) + non_picks = np.setdiff1d(np.arange(info["nchan"]), picks) # get the reference position from the loc[3:6] - chs = [info['chs'][ii] for ii in picks] - non_names = [info['chs'][ii]['ch_name'] for ii in non_picks] + chs = [info["chs"][ii] for ii in picks] + non_names = [info["chs"][ii]["ch_name"] for ii in non_picks] del picks - ref_pos = [ch['loc'][3:6] for ch in chs] + ref_pos = [ch["loc"][3:6] for ch in chs] # keep reference location from EEG-like channels if they # already exist and are all the same. custom_eeg_ref_dig = False # Note: ref position is an empty list for fieldtrip data if ref_pos: - if all([np.equal(ref_pos[0], pos).all() for pos in ref_pos]) \ - and not np.equal(ref_pos[0], [0, 0, 0]).all(): + if ( + all([np.equal(ref_pos[0], pos).all() for pos in ref_pos]) + and not np.equal(ref_pos[0], [0, 0, 0]).all() + ): eeg_ref_pos = ref_pos[0] # since we have an EEG reference position, we have # to add it into the info['dig'] as EEG000 custom_eeg_ref_dig = True if not custom_eeg_ref_dig: - refs = set(ch_pos) & {'EEG000', 'REF'} + refs = set(ch_pos) & {"EEG000", "REF"} assert len(refs) <= 1 eeg_ref_pos = np.zeros(3) if not refs else ch_pos.pop(refs.pop()) # This raises based on info being subset/superset of montage - info_names = [ch['ch_name'] for ch in chs] + info_names = [ch["ch_name"] for ch in chs] dig_names = mnt_head._get_dig_names() - ref_names = [None, 'EEG000', 'REF'] + ref_names = [None, "EEG000", "REF"] if match_case: info_names_use = info_names dig_names_use = dig_names non_names_use = non_names else: - ch_pos_use = OrderedDict( - (name.lower(), pos) for name, pos in ch_pos.items()) + ch_pos_use = OrderedDict((name.lower(), pos) for name, pos in ch_pos.items()) info_names_use = [name.lower() for name in info_names] - dig_names_use = [name.lower() if name is not None else name - for name in dig_names] + dig_names_use = [ + name.lower() if name is not None else name for name in dig_names + ] non_names_use = [name.lower() for name in non_names] - ref_names = [name.lower() if name is not None else name - for name in ref_names] + ref_names = [name.lower() if name is not None else name for name in ref_names] n_dup = len(ch_pos) - len(ch_pos_use) if n_dup: - raise ValueError('Cannot use match_case=False as %s montage ' - 'name(s) require case sensitivity' % n_dup) + raise ValueError( + "Cannot use match_case=False as %s montage " + "name(s) require case sensitivity" % n_dup + ) n_dup = len(info_names_use) - 
len(set(info_names_use)) if n_dup: - raise ValueError('Cannot use match_case=False as %s channel ' - 'name(s) require case sensitivity' % n_dup) + raise ValueError( + "Cannot use match_case=False as %s channel " + "name(s) require case sensitivity" % n_dup + ) ch_pos = ch_pos_use del ch_pos_use del dig_names # use lookup table to match unrecognized channel names to known aliases if match_alias: - alias_dict = (match_alias if isinstance(match_alias, dict) else - CHANNEL_LOC_ALIASES) + alias_dict = ( + match_alias if isinstance(match_alias, dict) else CHANNEL_LOC_ALIASES + ) if not match_case: alias_dict = { ch_name.lower(): ch_alias.lower() @@ -1168,16 +1233,11 @@ def _backcompat_value(pos, ref_pos): # excluded ch_alias not in info, to prevent unnecessary mapping and # warning messages based on aliases. - alias_dict = { - ch_name: ch_alias - for ch_name, ch_alias in alias_dict.items() - } + alias_dict = {ch_name: ch_alias for ch_name, ch_alias in alias_dict.items()} info_names_use = [ alias_dict.get(ch_name, ch_name) for ch_name in info_names_use ] - non_names_use = [ - alias_dict.get(ch_name, ch_name) for ch_name in non_names_use - ] + non_names_use = [alias_dict.get(ch_name, ch_name) for ch_name in non_names_use] # warn user if there is not a full overlap of montage with info_chs missing = np.where([use not in ch_pos for use in info_names_use])[0] @@ -1208,42 +1268,47 @@ def _backcompat_value(pos, ref_pos): # will have entries "D1" and "S1". extra = np.where([non in ch_pos for non in non_names_use])[0] if len(extra): - types = '/'.join(sorted(set( - channel_type(info, non_picks[ii]) for ii in extra))) + types = "/".join(sorted(set(channel_type(info, non_picks[ii]) for ii in extra))) names = [non_names[ii] for ii in extra] - warn(f'Not setting position{_pl(extra)} of {len(extra)} {types} ' - f'channel{_pl(extra)} found in montage:\n{names}\n' - 'Consider setting the channel types to be of ' - f'{_docdict["montage_types"]} ' - 'using inst.set_channel_types before calling inst.set_montage, ' - 'or omit these channels when creating your montage.') + warn( + f"Not setting position{_pl(extra)} of {len(extra)} {types} " + f"channel{_pl(extra)} found in montage:\n{names}\n" + "Consider setting the channel types to be of " + f'{_docdict["montage_types"]} ' + "using inst.set_channel_types before calling inst.set_montage, " + "or omit these channels when creating your montage." + ) for ch, use in zip(chs, info_names_use): # Next line modifies info['chs'][#]['loc'] in place if use in ch_pos: - ch['loc'][:6] = _backcompat_value(ch_pos[use], eeg_ref_pos) - ch['coord_frame'] = FIFF.FIFFV_COORD_HEAD + ch["loc"][:6] = _backcompat_value(ch_pos[use], eeg_ref_pos) + ch["coord_frame"] = FIFF.FIFFV_COORD_HEAD del ch_pos # XXX this is probably wrong as it uses the order from the montage # rather than the order of our info['ch_names'] ... 
digpoints = [ - mnt_head.dig[ii] for ii, name in enumerate(dig_names_use) - if name in (info_names_use + ref_names)] + mnt_head.dig[ii] + for ii, name in enumerate(dig_names_use) + if name in (info_names_use + ref_names) + ] # get a copy of the old dig - if info['dig'] is not None: - old_dig = info['dig'].copy() + if info["dig"] is not None: + old_dig = info["dig"].copy() else: old_dig = [] # determine if needed to add an extra EEG REF DigPoint if custom_eeg_ref_dig: # ref_name = 'EEG000' if match_case else 'eeg000' - ref_dig_dict = {'kind': FIFF.FIFFV_POINT_EEG, - 'r': eeg_ref_pos, - 'ident': 0, - 'coord_frame': info['dig'].pop()['coord_frame']} + ref_dig_dict = { + "kind": FIFF.FIFFV_POINT_EEG, + "r": eeg_ref_pos, + "ident": 0, + "coord_frame": info["dig"].pop()["coord_frame"], + } ref_dig_point = _format_dig_points([ref_dig_dict])[0] # only append the reference dig point if it was already # in the old dig @@ -1251,7 +1316,7 @@ def _backcompat_value(pos, ref_pos): digpoints.append(ref_dig_point) # Next line modifies info['dig'] in place with info._unlock(): - info['dig'] = _format_dig_points(digpoints, enforce_order=True) + info["dig"] = _format_dig_points(digpoints, enforce_order=True) del digpoints # TODO: Ideally we would have a check like this, but read_raw_bids for ECoG @@ -1267,7 +1332,7 @@ def _backcompat_value(pos, ref_pos): # 'not happen. Please contact MNE-Python developers.') # Handle fNIRS with source, detector and channel - fnirs_picks = _picks_to_idx(info, 'fnirs', allow_empty=True) + fnirs_picks = _picks_to_idx(info, "fnirs", allow_empty=True) if len(fnirs_picks) > 0: _set_montage_fnirs(info, mnt_head) @@ -1292,13 +1357,16 @@ def _read_isotrak_elp_points(fname): with open(fname) as fid: file_str = fid.read() - points_str = [m.groups() for m in re.finditer(coord_pattern, file_str, - re.MULTILINE)] + points_str = [ + m.groups() for m in re.finditer(coord_pattern, file_str, re.MULTILINE) + ] points = np.array(points_str, dtype=float) return { - 'nasion': points[0], 'lpa': points[1], 'rpa': points[2], - 'points': points[3:] + "nasion": points[0], + "lpa": points[1], + "rpa": points[2], + "points": points[3:], } @@ -1316,12 +1384,13 @@ def _read_isotrak_hsp_points(fname): The dictionary containing locations for 'nasion', 'lpa', 'rpa' and 'points'. """ + def get_hsp_fiducial(line): - return np.fromstring(line.replace('%F', ''), dtype=float, sep='\t') + return np.fromstring(line.replace("%F", ""), dtype=float, sep="\t") with open(fname) as ff: for line in ff: - if 'position of fiducials' in line.lower(): + if "position of fiducials" in line.lower(): break nasion = get_hsp_fiducial(ff.readline()) @@ -1331,20 +1400,20 @@ def get_hsp_fiducial(line): _ = ff.readline() line = ff.readline() if line: - n_points, n_cols = np.fromstring(line, dtype=int, sep='\t') + n_points, n_cols = np.fromstring(line, dtype=int, sep="\t") points = np.fromstring( - string=ff.read(), dtype=float, sep='\t', + string=ff.read(), + dtype=float, + sep="\t", ).reshape(-1, n_cols) assert points.shape[0] == n_points else: points = np.empty((0, 3)) - return { - 'nasion': nasion, 'lpa': lpa, 'rpa': rpa, 'points': points - } + return {"nasion": nasion, "lpa": lpa, "rpa": rpa, "points": points} -def read_dig_polhemus_isotrak(fname, ch_names=None, unit='m'): +def read_dig_polhemus_isotrak(fname, ch_names=None, unit="m"): """Read Polhemus digitizer data from a file. 
Parameters @@ -1377,14 +1446,14 @@ def read_dig_polhemus_isotrak(fname, ch_names=None, unit='m'): read_dig_fif read_dig_localite """ - VALID_FILE_EXT = ('.hsp', '.elp', '.eeg') + VALID_FILE_EXT = (".hsp", ".elp", ".eeg") fname = str(_check_fname(fname, overwrite="read", must_exist=True)) _scale = _check_unit_and_get_scaling(unit) _, ext = op.splitext(fname) - _check_option('fname', ext, VALID_FILE_EXT) + _check_option("fname", ext, VALID_FILE_EXT) - if ext == '.elp': + if ext == ".elp": data = _read_isotrak_elp_points(fname) else: # Default case we read points as hsp since is the most likely scenario @@ -1396,39 +1465,40 @@ def read_dig_polhemus_isotrak(fname, ch_names=None, unit='m'): pass # noqa if ch_names is None: - keyword = 'hpi' if ext == '.elp' else 'hsp' - data[keyword] = data.pop('points') + keyword = "hpi" if ext == ".elp" else "hsp" + data[keyword] = data.pop("points") else: - points = data.pop('points') + points = data.pop("points") if points.shape[0] == len(ch_names): - data['ch_pos'] = OrderedDict(zip(ch_names, points)) + data["ch_pos"] = OrderedDict(zip(ch_names, points)) else: - raise ValueError(( - "Length of ``ch_names`` does not match the number of points" - " in {fname}. Expected ``ch_names`` length {n_points:d}," - " given {n_chnames:d}" - ).format( - fname=fname, n_points=points.shape[0], n_chnames=len(ch_names) - )) + raise ValueError( + ( + "Length of ``ch_names`` does not match the number of points" + " in {fname}. Expected ``ch_names`` length {n_points:d}," + " given {n_chnames:d}" + ).format(fname=fname, n_points=points.shape[0], n_chnames=len(ch_names)) + ) return make_dig_montage(**data) def _is_polhemus_fastscan(fname): - header = '' - with open(fname, 'r') as fid: + header = "" + with open(fname, "r") as fid: for line in fid: - if not line.startswith('%'): + if not line.startswith("%"): break header += line - return 'FastSCAN' in header + return "FastSCAN" in header @verbose -def read_polhemus_fastscan(fname, unit='mm', on_header_missing='raise', *, - verbose=None): +def read_polhemus_fastscan( + fname, unit="mm", on_header_missing="raise", *, verbose=None +): """Read Polhemus FastSCAN digitizer data from a ``.txt`` file. Parameters @@ -1451,18 +1521,18 @@ def read_polhemus_fastscan(fname, unit='mm', on_header_missing='raise', *, read_dig_polhemus_isotrak make_dig_montage """ - VALID_FILE_EXT = ['.txt'] + VALID_FILE_EXT = [".txt"] fname = str(_check_fname(fname, overwrite="read", must_exist=True)) _scale = _check_unit_and_get_scaling(unit) _, ext = op.splitext(fname) - _check_option('fname', ext, VALID_FILE_EXT) + _check_option("fname", ext, VALID_FILE_EXT) if not _is_polhemus_fastscan(fname): msg = "%s does not contain a valid Polhemus FastSCAN header" % fname _on_missing(on_header_missing, msg) - points = _scale * np.loadtxt(fname, comments='%', ndmin=2) + points = _scale * np.loadtxt(fname, comments="%", ndmin=2) _check_dig_shape(points) return points @@ -1521,66 +1591,75 @@ def read_custom_montage(fname, head_size=HEAD_SIZE_DEFAULT, coord_frame=None): :func:`make_dig_montage` that takes arrays as input. 
""" from ._standard_montage_utils import ( - _read_theta_phi_in_degrees, _read_sfp, _read_csd, _read_elc, - _read_elp_besa, _read_brainvision, _read_xyz + _read_theta_phi_in_degrees, + _read_sfp, + _read_csd, + _read_elc, + _read_elp_besa, + _read_brainvision, + _read_xyz, ) + SUPPORTED_FILE_EXT = { - 'eeglab': ('.loc', '.locs', '.eloc', ), - 'hydrocel': ('.sfp', ), - 'matlab': ('.csd', ), - 'asa electrode': ('.elc', ), - 'generic (Theta-phi in degrees)': ('.txt', ), - 'standard BESA spherical': ('.elp', ), # NB: not same as polhemus elp - 'brainvision': ('.bvef', ), - 'xyz': ('.csv', '.tsv', '.xyz'), + "eeglab": ( + ".loc", + ".locs", + ".eloc", + ), + "hydrocel": (".sfp",), + "matlab": (".csd",), + "asa electrode": (".elc",), + "generic (Theta-phi in degrees)": (".txt",), + "standard BESA spherical": (".elp",), # NB: not same as polhemus elp + "brainvision": (".bvef",), + "xyz": (".csv", ".tsv", ".xyz"), } fname = str(_check_fname(fname, overwrite="read", must_exist=True)) _, ext = op.splitext(fname) - _check_option('fname', ext, list(sum(SUPPORTED_FILE_EXT.values(), ()))) + _check_option("fname", ext, list(sum(SUPPORTED_FILE_EXT.values(), ()))) - if ext in SUPPORTED_FILE_EXT['eeglab']: + if ext in SUPPORTED_FILE_EXT["eeglab"]: if head_size is None: - raise ValueError( - "``head_size`` cannot be None for '{}'".format(ext)) + raise ValueError("``head_size`` cannot be None for '{}'".format(ext)) ch_names, pos = _read_eeglab_locations(fname) scale = head_size / np.median(np.linalg.norm(pos, axis=-1)) pos *= scale montage = make_dig_montage( ch_pos=OrderedDict(zip(ch_names, pos)), - coord_frame='head', + coord_frame="head", ) - elif ext in SUPPORTED_FILE_EXT['hydrocel']: + elif ext in SUPPORTED_FILE_EXT["hydrocel"]: montage = _read_sfp(fname, head_size=head_size) - elif ext in SUPPORTED_FILE_EXT['matlab']: + elif ext in SUPPORTED_FILE_EXT["matlab"]: montage = _read_csd(fname, head_size=head_size) - elif ext in SUPPORTED_FILE_EXT['asa electrode']: + elif ext in SUPPORTED_FILE_EXT["asa electrode"]: montage = _read_elc(fname, head_size=head_size) - elif ext in SUPPORTED_FILE_EXT['generic (Theta-phi in degrees)']: + elif ext in SUPPORTED_FILE_EXT["generic (Theta-phi in degrees)"]: if head_size is None: - raise ValueError( - "``head_size`` cannot be None for '{}'".format(ext)) - montage = _read_theta_phi_in_degrees(fname, head_size=head_size, - fid_names=('Nz', 'LPA', 'RPA')) + raise ValueError("``head_size`` cannot be None for '{}'".format(ext)) + montage = _read_theta_phi_in_degrees( + fname, head_size=head_size, fid_names=("Nz", "LPA", "RPA") + ) - elif ext in SUPPORTED_FILE_EXT['standard BESA spherical']: + elif ext in SUPPORTED_FILE_EXT["standard BESA spherical"]: montage = _read_elp_besa(fname, head_size) - elif ext in SUPPORTED_FILE_EXT['brainvision']: + elif ext in SUPPORTED_FILE_EXT["brainvision"]: montage = _read_brainvision(fname, head_size) - elif ext in SUPPORTED_FILE_EXT['xyz']: + elif ext in SUPPORTED_FILE_EXT["xyz"]: montage = _read_xyz(fname) if coord_frame is not None: coord_frame = _coord_frame_const(coord_frame) for d in montage.dig: - d['coord_frame'] = coord_frame + d["coord_frame"] = coord_frame return montage @@ -1602,31 +1681,49 @@ def compute_dev_head_t(montage): """ _, coord_frame = _get_fid_coords(montage.dig) if coord_frame != FIFF.FIFFV_COORD_HEAD: - raise ValueError('montage should have been set to head coordinate ' - 'system with transform_to_head function.') + raise ValueError( + "montage should have been set to head coordinate " + "system with transform_to_head 
function." + ) hpi_head = np.array( - [d['r'] for d in montage.dig - if (d['kind'] == FIFF.FIFFV_POINT_HPI and - d['coord_frame'] == FIFF.FIFFV_COORD_HEAD)], float) + [ + d["r"] + for d in montage.dig + if ( + d["kind"] == FIFF.FIFFV_POINT_HPI + and d["coord_frame"] == FIFF.FIFFV_COORD_HEAD + ) + ], + float, + ) hpi_dev = np.array( - [d['r'] for d in montage.dig - if (d['kind'] == FIFF.FIFFV_POINT_HPI and - d['coord_frame'] == FIFF.FIFFV_COORD_DEVICE)], float) + [ + d["r"] + for d in montage.dig + if ( + d["kind"] == FIFF.FIFFV_POINT_HPI + and d["coord_frame"] == FIFF.FIFFV_COORD_DEVICE + ) + ], + float, + ) if not (len(hpi_head) == len(hpi_dev) and len(hpi_dev) > 0): - raise ValueError(( - "To compute Device-to-Head transformation, the same number of HPI" - " points in device and head coordinates is required. (Got {dev}" - " points in device and {head} points in head coordinate systems)" - ).format(dev=len(hpi_dev), head=len(hpi_head))) + raise ValueError( + ( + "To compute Device-to-Head transformation, the same number of HPI" + " points in device and head coordinates is required. (Got {dev}" + " points in device and {head} points in head coordinate systems)" + ).format(dev=len(hpi_dev), head=len(hpi_head)) + ) trans = _quat_to_affine(_fit_matched_points(hpi_dev, hpi_head)[0]) - return Transform(fro='meg', to='head', trans=trans) + return Transform(fro="meg", to="head", trans=trans) @verbose -def compute_native_head_t(montage, *, on_missing='warn', verbose=None): +def compute_native_head_t(montage, *, on_missing="warn", verbose=None): """Compute the native-to-head transformation for a montage. This uses the fiducials in the native space to transform to compute the @@ -1653,23 +1750,25 @@ def compute_native_head_t(montage, *, on_missing='warn', verbose=None): if coord_frame == FIFF.FIFFV_COORD_HEAD: native_head_t = np.eye(4) else: - fid_keys = ('nasion', 'lpa', 'rpa') + fid_keys = ("nasion", "lpa", "rpa") for key in fid_keys: this_coord = fid_coords[key] if this_coord is None or np.any(np.isnan(this_coord)): msg = ( - f'Fiducial point {key} not found, assuming identity ' - f'{_verbose_frames[coord_frame]} to head transformation') + f"Fiducial point {key} not found, assuming identity " + f"{_verbose_frames[coord_frame]} to head transformation" + ) _on_missing(on_missing, msg, error_klass=RuntimeError) native_head_t = np.eye(4) break else: native_head_t = get_ras_to_neuromag_trans( - *[fid_coords[key] for key in fid_keys]) - return Transform(coord_frame, 'head', native_head_t) + *[fid_coords[key] for key in fid_keys] + ) + return Transform(coord_frame, "head", native_head_t) -def make_standard_montage(kind, head_size='auto'): +def make_standard_montage(kind, head_size="auto"): """Read a generic (built-in) standard montage that ships with MNE-Python. Parameters @@ -1708,15 +1807,17 @@ def make_standard_montage(kind, head_size='auto'): .. 
versionadded:: 0.19.0 """ from ._standard_montage_utils import standard_montage_look_up_table - _validate_type(kind, str, 'kind') + + _validate_type(kind, str, "kind") _check_option( - parameter='kind', value=kind, - allowed_values=[m.name for m in _BUILTIN_STANDARD_MONTAGES] + parameter="kind", + value=kind, + allowed_values=[m.name for m in _BUILTIN_STANDARD_MONTAGES], ) - _validate_type(head_size, ('numeric', str, None), 'head_size') + _validate_type(head_size, ("numeric", str, None), "head_size") if isinstance(head_size, str): - _check_option('head_size', head_size, ('auto',), extra='when str') - if kind.startswith(('standard', 'mgh', 'artinis')): + _check_option("head_size", head_size, ("auto",), extra="when str") + if kind.startswith(("standard", "mgh", "artinis")): head_size = None else: head_size = HEAD_SIZE_DEFAULT @@ -1724,7 +1825,6 @@ def make_standard_montage(kind, head_size='auto'): def _check_dig_shape(pts): - _validate_type(pts, np.ndarray, 'points') + _validate_type(pts, np.ndarray, "points") if pts.ndim != 2 or pts.shape[-1] != 3: - raise ValueError( - f'Points must be of shape (n, 3) instead of {pts.shape}') + raise ValueError(f"Points must be of shape (n, 3) instead of {pts.shape}") diff --git a/mne/channels/tests/test_channels.py b/mne/channels/tests/test_channels.py index 04f07d84ec3..2b719b7e3af 100644 --- a/mne/channels/tests/test_channels.py +++ b/mne/channels/tests/test_channels.py @@ -13,19 +13,42 @@ from scipy.io import savemat from numpy.testing import assert_array_equal, assert_equal, assert_allclose -from mne.channels import (rename_channels, read_ch_adjacency, combine_channels, - find_ch_adjacency, make_1020_channel_selections, - read_custom_montage, equalize_channels, - get_builtin_ch_adjacencies) +from mne.channels import ( + rename_channels, + read_ch_adjacency, + combine_channels, + find_ch_adjacency, + make_1020_channel_selections, + read_custom_montage, + equalize_channels, + get_builtin_ch_adjacencies, +) from mne.channels.channels import ( - _ch_neighbor_adjacency, _compute_ch_adjacency, - _BUILTIN_CHANNEL_ADJACENCIES, _BuiltinChannelAdjacency + _ch_neighbor_adjacency, + _compute_ch_adjacency, + _BUILTIN_CHANNEL_ADJACENCIES, + _BuiltinChannelAdjacency, +) +from mne.io import ( + read_info, + read_raw_fif, + read_raw_ctf, + read_raw_bti, + read_raw_eeglab, + read_raw_kit, + RawArray, ) -from mne.io import (read_info, read_raw_fif, read_raw_ctf, read_raw_bti, - read_raw_eeglab, read_raw_kit, RawArray) from mne.io.constants import FIFF -from mne import (pick_types, pick_channels, EpochsArray, EvokedArray, - make_ad_hoc_cov, create_info, read_events, Epochs) +from mne import ( + pick_types, + pick_channels, + EpochsArray, + EvokedArray, + make_ad_hoc_cov, + create_info, + read_events, + Epochs, +) from mne.datasets import testing from mne.utils import requires_pandas, requires_version from mne.parallel import parallel_func @@ -38,8 +61,8 @@ testing_path = testing.data_path(download=False) -@pytest.mark.parametrize('preload', (True, False)) -@pytest.mark.parametrize('proj', (True, False)) +@pytest.mark.parametrize("preload", (True, False)) +@pytest.mark.parametrize("proj", (True, False)) def test_reorder_channels(preload, proj): """Test reordering of channels.""" raw = read_raw_fif(raw_fname).crop(0, 0.1).del_proj() @@ -49,7 +72,7 @@ def test_reorder_channels(preload, proj): raw.load_data() # with .reorder_channels if proj and not preload: - with pytest.raises(RuntimeError, match='load data'): + with pytest.raises(RuntimeError, match="load data"): 
raw.copy().reorder_channels(raw.ch_names[::-1]) return raw_new = raw.copy().reorder_channels(raw.ch_names[::-1]) @@ -63,7 +86,7 @@ def test_reorder_channels(preload, proj): raw_new.reorder_channels(raw_new.ch_names[::-1][1:-1]) raw.drop_channels(raw.ch_names[:1] + raw.ch_names[-1:]) assert_array_equal(raw[:][0], raw_new[:][0]) - with pytest.raises(ValueError, match='repeated'): + with pytest.raises(ValueError, match="repeated"): raw.reorder_channels(raw.ch_names[:1] + raw.ch_names[:1]) # and with .pick reord = [1, 0] + list(range(2, len(raw.ch_names))) @@ -77,41 +100,41 @@ def test_rename_channels(): info = read_info(raw_fname) # Error Tests # Test channel name exists in ch_names - mapping = {'EEG 160': 'EEG060'} + mapping = {"EEG 160": "EEG060"} pytest.raises(ValueError, rename_channels, info, mapping) # Test improper mapping configuration - mapping = {'MEG 2641': 1.0} + mapping = {"MEG 2641": 1.0} pytest.raises(TypeError, rename_channels, info, mapping) # Test non-unique mapping configuration - mapping = {'MEG 2641': 'MEG 2642'} + mapping = {"MEG 2641": "MEG 2642"} pytest.raises(ValueError, rename_channels, info, mapping) # Test bad input - pytest.raises(ValueError, rename_channels, info, 1.) - pytest.raises(ValueError, rename_channels, info, 1.) + pytest.raises(ValueError, rename_channels, info, 1.0) + pytest.raises(ValueError, rename_channels, info, 1.0) # Test successful changes # Test ch_name and ch_names are changed info2 = deepcopy(info) # for consistency at the start of each test - info2['bads'] = ['EEG 060', 'EOG 061'] - mapping = {'EEG 060': 'EEG060', 'EOG 061': 'EOG061'} + info2["bads"] = ["EEG 060", "EOG 061"] + mapping = {"EEG 060": "EEG060", "EOG 061": "EOG061"} rename_channels(info2, mapping) - assert info2['chs'][374]['ch_name'] == 'EEG060' - assert info2['ch_names'][374] == 'EEG060' - assert info2['chs'][375]['ch_name'] == 'EOG061' - assert info2['ch_names'][375] == 'EOG061' - assert_array_equal(['EEG060', 'EOG061'], info2['bads']) + assert info2["chs"][374]["ch_name"] == "EEG060" + assert info2["ch_names"][374] == "EEG060" + assert info2["chs"][375]["ch_name"] == "EOG061" + assert info2["ch_names"][375] == "EOG061" + assert_array_equal(["EEG060", "EOG061"], info2["bads"]) info2 = deepcopy(info) - rename_channels(info2, lambda x: x.replace(' ', '')) - assert info2['chs'][373]['ch_name'] == 'EEG059' + rename_channels(info2, lambda x: x.replace(" ", "")) + assert info2["chs"][373]["ch_name"] == "EEG059" info2 = deepcopy(info) - info2['bads'] = ['EEG 060', 'EEG 060'] + info2["bads"] = ["EEG 060", "EEG 060"] rename_channels(info2, mapping) - assert_array_equal(['EEG060', 'EEG060'], info2['bads']) + assert_array_equal(["EEG060", "EEG060"], info2["bads"]) # test that keys in Raw._orig_units will be renamed, too raw = read_raw_fif(raw_fname).crop(0, 0.1) - old, new = 'EEG 060', 'New' - raw._orig_units = {old: 'V'} + old, new = "EEG 060", "New" + raw._orig_units = {old: "V"} raw.rename_channels({old: new}) assert old not in raw._orig_units @@ -123,74 +146,81 @@ def test_set_channel_types(): raw = read_raw_fif(raw_fname) # Error Tests # Test channel name exists in ch_names - mapping = {'EEG 160': 'EEG060'} + mapping = {"EEG 160": "EEG060"} with pytest.raises(ValueError, match=r"name \(EEG 160\) doesn't exist"): raw.set_channel_types(mapping) # Test change to illegal channel type - mapping = {'EOG 061': 'xxx'} - with pytest.raises(ValueError, match='cannot change to this channel type'): + mapping = {"EOG 061": "xxx"} + with pytest.raises(ValueError, match="cannot change to this 
channel type"): raw.set_channel_types(mapping) # Test changing type if in proj - mapping = {'EEG 057': 'dbs', 'EEG 058': 'ecog', 'EEG 059': 'ecg', - 'EEG 060': 'eog', 'EOG 061': 'seeg', 'MEG 2441': 'eeg', - 'MEG 2443': 'eeg', 'MEG 2442': 'hbo', 'EEG 001': 'resp'} + mapping = { + "EEG 057": "dbs", + "EEG 058": "ecog", + "EEG 059": "ecg", + "EEG 060": "eog", + "EOG 061": "seeg", + "MEG 2441": "eeg", + "MEG 2443": "eeg", + "MEG 2442": "hbo", + "EEG 001": "resp", + } raw2 = read_raw_fif(raw_fname) - raw2.info['bads'] = ['EEG 059', 'EEG 060', 'EOG 061'] + raw2.info["bads"] = ["EEG 059", "EEG 060", "EOG 061"] with pytest.raises(RuntimeError, match='type .* in projector "PCA-v1"'): raw2.set_channel_types(mapping) # has prj raw2.add_proj([], remove_existing=True) # Should raise - with pytest.raises(ValueError, match='unit for channel.* has changed'): - raw2.copy().set_channel_types(mapping, on_unit_change='raise') + with pytest.raises(ValueError, match="unit for channel.* has changed"): + raw2.copy().set_channel_types(mapping, on_unit_change="raise") # Should warn - with pytest.warns(RuntimeWarning, match='unit for channel.* has changed'): + with pytest.warns(RuntimeWarning, match="unit for channel.* has changed"): raw2.copy().set_channel_types(mapping) # Shouldn't warn - raw2.set_channel_types(mapping, on_unit_change='ignore') + raw2.set_channel_types(mapping, on_unit_change="ignore") info = raw2.info - assert info['chs'][371]['ch_name'] == 'EEG 057' - assert info['chs'][371]['kind'] == FIFF.FIFFV_DBS_CH - assert info['chs'][371]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][371]['coil_type'] == FIFF.FIFFV_COIL_EEG - assert info['chs'][372]['ch_name'] == 'EEG 058' - assert info['chs'][372]['kind'] == FIFF.FIFFV_ECOG_CH - assert info['chs'][372]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][372]['coil_type'] == FIFF.FIFFV_COIL_EEG - assert info['chs'][373]['ch_name'] == 'EEG 059' - assert info['chs'][373]['kind'] == FIFF.FIFFV_ECG_CH - assert info['chs'][373]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][373]['coil_type'] == FIFF.FIFFV_COIL_NONE - assert info['chs'][374]['ch_name'] == 'EEG 060' - assert info['chs'][374]['kind'] == FIFF.FIFFV_EOG_CH - assert info['chs'][374]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][374]['coil_type'] == FIFF.FIFFV_COIL_NONE - assert info['chs'][375]['ch_name'] == 'EOG 061' - assert info['chs'][375]['kind'] == FIFF.FIFFV_SEEG_CH - assert info['chs'][375]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][375]['coil_type'] == FIFF.FIFFV_COIL_EEG - for idx in pick_channels(raw.ch_names, ['MEG 2441', 'MEG 2443'], - ordered=False): - assert info['chs'][idx]['kind'] == FIFF.FIFFV_EEG_CH - assert info['chs'][idx]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_EEG - idx = pick_channels(raw.ch_names, ['MEG 2442'])[0] - assert info['chs'][idx]['kind'] == FIFF.FIFFV_FNIRS_CH - assert info['chs'][idx]['unit'] == FIFF.FIFF_UNIT_MOL - assert info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_FNIRS_HBO + assert info["chs"][371]["ch_name"] == "EEG 057" + assert info["chs"][371]["kind"] == FIFF.FIFFV_DBS_CH + assert info["chs"][371]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][371]["coil_type"] == FIFF.FIFFV_COIL_EEG + assert info["chs"][372]["ch_name"] == "EEG 058" + assert info["chs"][372]["kind"] == FIFF.FIFFV_ECOG_CH + assert info["chs"][372]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][372]["coil_type"] == FIFF.FIFFV_COIL_EEG + assert info["chs"][373]["ch_name"] == "EEG 059" + assert info["chs"][373]["kind"] == 
FIFF.FIFFV_ECG_CH + assert info["chs"][373]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][373]["coil_type"] == FIFF.FIFFV_COIL_NONE + assert info["chs"][374]["ch_name"] == "EEG 060" + assert info["chs"][374]["kind"] == FIFF.FIFFV_EOG_CH + assert info["chs"][374]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][374]["coil_type"] == FIFF.FIFFV_COIL_NONE + assert info["chs"][375]["ch_name"] == "EOG 061" + assert info["chs"][375]["kind"] == FIFF.FIFFV_SEEG_CH + assert info["chs"][375]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][375]["coil_type"] == FIFF.FIFFV_COIL_EEG + for idx in pick_channels(raw.ch_names, ["MEG 2441", "MEG 2443"], ordered=False): + assert info["chs"][idx]["kind"] == FIFF.FIFFV_EEG_CH + assert info["chs"][idx]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][idx]["coil_type"] == FIFF.FIFFV_COIL_EEG + idx = pick_channels(raw.ch_names, ["MEG 2442"])[0] + assert info["chs"][idx]["kind"] == FIFF.FIFFV_FNIRS_CH + assert info["chs"][idx]["unit"] == FIFF.FIFF_UNIT_MOL + assert info["chs"][idx]["coil_type"] == FIFF.FIFFV_COIL_FNIRS_HBO # resp channel type - idx = pick_channels(raw.ch_names, ['EEG 001'])[0] - assert info['chs'][idx]['kind'] == FIFF.FIFFV_RESP_CH - assert info['chs'][idx]['unit'] == FIFF.FIFF_UNIT_V - assert info['chs'][idx]['coil_type'] == FIFF.FIFFV_COIL_NONE + idx = pick_channels(raw.ch_names, ["EEG 001"])[0] + assert info["chs"][idx]["kind"] == FIFF.FIFFV_RESP_CH + assert info["chs"][idx]["unit"] == FIFF.FIFF_UNIT_V + assert info["chs"][idx]["coil_type"] == FIFF.FIFFV_COIL_NONE # Test meaningful error when setting channel type with unknown unit - raw.info['chs'][0]['unit'] = 0. - ch_types = {raw.ch_names[0]: 'misc'} + raw.info["chs"][0]["unit"] = 0.0 + ch_types = {raw.ch_names[0]: "misc"} pytest.raises(ValueError, raw.set_channel_types, ch_types) @@ -208,17 +238,21 @@ def test_get_builtin_ch_adjacencies(): def test_read_ch_adjacency(tmp_path): """Test reading channel adjacency templates.""" - a = partial(np.array, dtype=' ps # are channels in the correct selection? @@ -405,9 +443,9 @@ def test_1020_selection(): def test_find_ch_adjacency(): """Test computing the adjacency matrix.""" raw = read_raw_fif(raw_fname, preload=True) - sizes = {'mag': 828, 'grad': 1700, 'eeg': 384} - nchans = {'mag': 102, 'grad': 204, 'eeg': 60} - for ch_type in ['mag', 'grad', 'eeg']: + sizes = {"mag": 828, "grad": 1700, "eeg": 384} + nchans = {"mag": 102, "grad": 204, "eeg": 60} + for ch_type in ["mag", "grad", "eeg"]: conn, ch_names = find_ch_adjacency(raw.info, ch_type) # Silly test for checking the number of neighbors. assert_equal(conn.getnnz(), sizes[ch_type]) @@ -415,30 +453,30 @@ def test_find_ch_adjacency(): pytest.raises(ValueError, find_ch_adjacency, raw.info, None) # Test computing the conn matrix with gradiometers. - conn, ch_names = _compute_ch_adjacency(raw.info, 'grad') + conn, ch_names = _compute_ch_adjacency(raw.info, "grad") assert_equal(conn.getnnz(), 2680) # Test ch_type=None. 
- raw.pick_types(meg='mag') + raw.pick_types(meg="mag") find_ch_adjacency(raw.info, None) bti_fname = testing_path / "BTi" / "erm_HFH" / "c,rfDC" bti_config_name = testing_path / "BTi" / "erm_HFH" / "config" raw = read_raw_bti(bti_fname, bti_config_name, None) - _, ch_names = find_ch_adjacency(raw.info, 'mag') - assert 'A1' in ch_names + _, ch_names = find_ch_adjacency(raw.info, "mag") + assert "A1" in ch_names ctf_fname = testing_path / "CTF" / "testdata_ctf_short.ds" raw = read_raw_ctf(ctf_fname) - _, ch_names = find_ch_adjacency(raw.info, 'mag') - assert 'MLC11' in ch_names + _, ch_names = find_ch_adjacency(raw.info, "mag") + assert "MLC11" in ch_names - pytest.raises(ValueError, find_ch_adjacency, raw.info, 'eog') + pytest.raises(ValueError, find_ch_adjacency, raw.info, "eog") raw_kit = read_raw_kit(fname_kit_157) - neighb, ch_names = find_ch_adjacency(raw_kit.info, 'mag') + neighb, ch_names = find_ch_adjacency(raw_kit.info, "mag") assert neighb.data.size == 1329 - assert ch_names[0] == 'MEG 001' + assert ch_names[0] == "MEG 001" @testing.requires_testing_data @@ -446,7 +484,7 @@ def test_neuromag122_adjacency(): """Test computing the adjacency matrix of Neuromag122-Data.""" nm122_fname = testing_path / "misc" / "neuromag122_test_file-raw.fif" raw = read_raw_fif(nm122_fname, preload=True) - conn, ch_names = find_ch_adjacency(raw.info, 'grad') + conn, ch_names = find_ch_adjacency(raw.info, "grad") assert conn.getnnz() == 1564 assert len(ch_names) == 122 assert conn.shape == (122, 122) @@ -463,13 +501,13 @@ def test_drop_channels(): # by default, drop channels raises a ValueError if a channel can't be found m_chs = ["MEG 0111", "MEG blahblah"] - with pytest.raises(ValueError, match='not found, nothing dropped'): + with pytest.raises(ValueError, match="not found, nothing dropped"): raw.drop_channels(m_chs) # ...but this can be turned to a warning - with pytest.warns(RuntimeWarning, match='not found, nothing dropped'): - raw.drop_channels(m_chs, on_missing='warn') + with pytest.warns(RuntimeWarning, match="not found, nothing dropped"): + raw.drop_channels(m_chs, on_missing="warn") # ...or ignored altogether - raw.drop_channels(m_chs, on_missing='ignore') + raw.drop_channels(m_chs, on_missing="ignore") def test_pick_channels(): @@ -477,17 +515,17 @@ def test_pick_channels(): raw = read_raw_fif(raw_fname, preload=True).crop(0, 0.1) # selected correctly 3 channels - raw.pick(['MEG 0113', 'MEG 0112', 'MEG 0111']) + raw.pick(["MEG 0113", "MEG 0112", "MEG 0111"]) assert len(raw.ch_names) == 3 # selected correctly 3 channels and ignored 'meg', and emit warning - with pytest.raises(ValueError, match='not present in the info'): - raw.pick(['MEG 0113', "meg", 'MEG 0112', 'MEG 0111']) + with pytest.raises(ValueError, match="not present in the info"): + raw.pick(["MEG 0113", "meg", "MEG 0112", "MEG 0111"]) names_len = len(raw.ch_names) - raw.pick(['all']) # selected correctly all channels + raw.pick(["all"]) # selected correctly all channels assert len(raw.ch_names) == names_len - raw.pick('all') # selected correctly all channels + raw.pick("all") # selected correctly all channels assert len(raw.ch_names) == names_len @@ -502,16 +540,16 @@ def test_add_reference_channels(): n_evoked_original_channels = len(evoked.ch_names) # Raw object - raw.add_reference_channels(['REF 123']) + raw.add_reference_channels(["REF 123"]) assert len(raw.ch_names) == n_raw_original_channels + 1 assert np.all(raw.get_data()[-1] == 0) # Epochs object - epochs.add_reference_channels(['REF 123']) + 
epochs.add_reference_channels(["REF 123"])
     assert epochs._data.shape[1] == epochs_original_shape + 1
 
     # Evoked object
-    evoked.add_reference_channels(['REF 123'])
+    evoked.add_reference_channels(["REF 123"])
     assert len(evoked.ch_names) == n_evoked_original_channels + 1
     assert np.all(evoked._data[-1] == 0)
 
@@ -521,30 +559,35 @@ def test_equalize_channels():
     # This function only tests the generic functionality of equalize_channels.
     # Additional tests for each instance type are included in the accompanying
     # test suite for each type.
-    pytest.raises(TypeError, equalize_channels, ['foo', 'bar'],
-                  match='Instances to be modified must be an instance of')
+    # ``match`` is only honored by the context-manager form of pytest.raises;
+    # in the call form it would be forwarded to equalize_channels itself.
+    with pytest.raises(
+        TypeError, match="Instances to be modified must be an instance of"
+    ):
+        equalize_channels(["foo", "bar"])
 
-    raw = RawArray([[1.], [2.], [3.], [4.]],
-                   create_info(['CH1', 'CH2', 'CH3', 'CH4'], sfreq=1.))
-    epochs = EpochsArray([[[1.], [2.], [3.]]],
-                         create_info(['CH5', 'CH2', 'CH1'], sfreq=1.))
-    cov = make_ad_hoc_cov(create_info(['CH2', 'CH1', 'CH8'], sfreq=1.,
-                                      ch_types='eeg'))
-    cov['bads'] = ['CH1']
-    ave = EvokedArray([[1.], [2.]], create_info(['CH1', 'CH2'], sfreq=1.))
+    raw = RawArray(
+        [[1.0], [2.0], [3.0], [4.0]],
+        create_info(["CH1", "CH2", "CH3", "CH4"], sfreq=1.0),
+    )
+    epochs = EpochsArray(
+        [[[1.0], [2.0], [3.0]]], create_info(["CH5", "CH2", "CH1"], sfreq=1.0)
+    )
+    cov = make_ad_hoc_cov(create_info(["CH2", "CH1", "CH8"], sfreq=1.0, ch_types="eeg"))
+    cov["bads"] = ["CH1"]
+    ave = EvokedArray([[1.0], [2.0]], create_info(["CH1", "CH2"], sfreq=1.0))
 
-    raw2, epochs2, cov2, ave2 = equalize_channels([raw, epochs, cov, ave],
-                                                  copy=True)
+    raw2, epochs2, cov2, ave2 = equalize_channels([raw, epochs, cov, ave], copy=True)
 
     # The Raw object was the first in the list, so should have been used as
     # template for the ordering of the channels. No bad channels should have
     # been dropped. 
- assert raw2.ch_names == ['CH1', 'CH2'] - assert_array_equal(raw2.get_data(), [[1.], [2.]]) - assert epochs2.ch_names == ['CH1', 'CH2'] - assert_array_equal(epochs2.get_data(), [[[3.], [2.]]]) - assert cov2.ch_names == ['CH1', 'CH2'] - assert cov2['bads'] == cov['bads'] + assert raw2.ch_names == ["CH1", "CH2"] + assert_array_equal(raw2.get_data(), [[1.0], [2.0]]) + assert epochs2.ch_names == ["CH1", "CH2"] + assert_array_equal(epochs2.get_data(), [[[3.0], [2.0]]]) + assert cov2.ch_names == ["CH1", "CH2"] + assert cov2["bads"] == cov["bads"] assert ave2.ch_names == ave.ch_names assert_array_equal(ave2.data, ave.data) @@ -565,7 +608,7 @@ def test_combine_channels(): """Test channel combination on Raw, Epochs, and Evoked.""" raw = read_raw_fif(raw_fname, preload=True) raw_ch_bad = read_raw_fif(raw_fname, preload=True) - raw_ch_bad.info['bads'] = ['MEG 0113', 'MEG 0112'] + raw_ch_bad.info["bads"] = ["MEG 0113", "MEG 0112"] epochs = Epochs(raw, read_events(eve_fname)) evoked = epochs.average() good = dict(foo=[0, 1, 3, 4], bar=[5, 2]) # good grad and mag @@ -583,35 +626,32 @@ def test_combine_channels(): # Test with stimulus channels combine_stim = combine_channels(raw, good, keep_stim=True) target_nchan = len(good) + len(pick_types(raw.info, meg=False, stim=True)) - assert combine_stim.info['nchan'] == target_nchan + assert combine_stim.info["nchan"] == target_nchan # Test results with one ROI good_single = dict(foo=[0, 1, 3, 4]) # good grad - combined_mean = combine_channels(raw, good_single, method='mean') - combined_median = combine_channels(raw, good_single, method='median') - combined_std = combine_channels(raw, good_single, method='std') - foo_mean = np.mean(raw.get_data()[good_single['foo']], axis=0) - foo_median = np.median(raw.get_data()[good_single['foo']], axis=0) - foo_std = np.std(raw.get_data()[good_single['foo']], axis=0) - assert_array_equal(combined_mean.get_data(), - np.expand_dims(foo_mean, axis=0)) - assert_array_equal(combined_median.get_data(), - np.expand_dims(foo_median, axis=0)) - assert_array_equal(combined_std.get_data(), - np.expand_dims(foo_std, axis=0)) + combined_mean = combine_channels(raw, good_single, method="mean") + combined_median = combine_channels(raw, good_single, method="median") + combined_std = combine_channels(raw, good_single, method="std") + foo_mean = np.mean(raw.get_data()[good_single["foo"]], axis=0) + foo_median = np.median(raw.get_data()[good_single["foo"]], axis=0) + foo_std = np.std(raw.get_data()[good_single["foo"]], axis=0) + assert_array_equal(combined_mean.get_data(), np.expand_dims(foo_mean, axis=0)) + assert_array_equal(combined_median.get_data(), np.expand_dims(foo_median, axis=0)) + assert_array_equal(combined_std.get_data(), np.expand_dims(foo_std, axis=0)) # Test bad cases bad1 = dict(foo=[0, 376], bar=[5, 2]) # out of bounds bad2 = dict(foo=[0, 2], bar=[5, 2]) # type mix in same group with pytest.raises(ValueError, match='"method" must be a callable, or'): - combine_channels(raw, good, method='bad_method') + combine_channels(raw, good, method="bad_method") with pytest.raises(TypeError, match='"keep_stim" must be of type bool'): - combine_channels(raw, good, keep_stim='bad_type') + combine_channels(raw, good, keep_stim="bad_type") with pytest.raises(TypeError, match='"drop_bad" must be of type bool'): - combine_channels(raw, good, drop_bad='bad_type') - with pytest.raises(ValueError, match='Some channel indices are out of'): + combine_channels(raw, good, drop_bad="bad_type") + with pytest.raises(ValueError, match="Some channel 
indices are out of"): combine_channels(raw, bad1) - with pytest.raises(ValueError, match='Cannot combine sensors of diff'): + with pytest.raises(ValueError, match="Cannot combine sensors of diff"): combine_channels(raw, bad2) # Test warnings @@ -620,9 +660,9 @@ def test_combine_channels(): warn1 = dict(foo=[375, 375], bar=[5, 2]) # same channel in same group warn2 = dict(foo=[375], bar=[5, 2]) # one channel (last channel) warn3 = dict(foo=[0, 4], bar=[5, 2]) # one good channel left - with pytest.warns(RuntimeWarning, match='Could not find stimulus'): + with pytest.warns(RuntimeWarning, match="Could not find stimulus"): combine_channels(raw_no_stim, good, keep_stim=True) - with pytest.warns(RuntimeWarning, match='Less than 2 channels') as record: + with pytest.warns(RuntimeWarning, match="Less than 2 channels") as record: combine_channels(raw, warn1) combine_channels(raw, warn2) combine_channels(raw_ch_bad, warn3, drop_bad=True) @@ -637,8 +677,7 @@ def test_combine_channels_metadata(): raw = read_raw_fif(raw_fname, preload=True) epochs = Epochs(raw, read_events(eve_fname), preload=True) - metadata = pd.DataFrame({"A": np.arange(len(epochs)), - "B": np.ones(len(epochs))}) + metadata = pd.DataFrame({"A": np.arange(len(epochs)), "B": np.ones(len(epochs))}) epochs.metadata = metadata good = dict(foo=[0, 1, 3, 4], bar=[5, 2]) # good grad and mag diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index f6c71d1ff00..2425db488eb 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -8,8 +8,11 @@ from mne import io, pick_types, pick_channels, read_events, Epochs from mne.channels.interpolation import _make_interpolation_matrix from mne.datasets import testing -from mne.preprocessing.nirs import (optical_density, scalp_coupling_index, - beer_lambert_law) +from mne.preprocessing.nirs import ( + optical_density, + scalp_coupling_index, + beer_lambert_law, +) from mne.io import read_raw_nirx from mne.io.proj import _has_eeg_average_ref_proj from mne.utils import _record_warnings, requires_version @@ -30,95 +33,118 @@ def _load_data(kind): raw = io.read_raw_fif(raw_fname) events = read_events(event_name) # subselect channels for speed - if kind == 'eeg': + if kind == "eeg": picks = pick_types(raw.info, meg=False, eeg=True, exclude=[])[:15] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - preload=True, reject=dict(eeg=80e-6)) + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + preload=True, + reject=dict(eeg=80e-6), + ) else: picks = pick_types(raw.info, meg=True, eeg=False, exclude=[])[1:200:2] - assert kind == 'meg' - with pytest.warns(RuntimeWarning, match='projection'): - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - preload=True, - reject=dict(grad=1000e-12, mag=4e-12)) + assert kind == "meg" + with pytest.warns(RuntimeWarning, match="projection"): + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + preload=True, + reject=dict(grad=1000e-12, mag=4e-12), + ) return raw, epochs -@pytest.mark.parametrize('offset', (0., 0.1)) -@pytest.mark.parametrize('avg_proj, ctol', [ - (True, (0.86, 0.93)), - (False, (0.97, 0.99)), -]) -@pytest.mark.parametrize('method, atol', [ - pytest.param(None, 3e-6, marks=pytest.mark.slowtest), # slow on Azure - (dict(eeg='MNE'), 4e-6), -]) -@pytest.mark.filterwarnings('ignore:.*than 20 mm from head frame origin.*') +@pytest.mark.parametrize("offset", (0.0, 0.1)) 
+@pytest.mark.parametrize( + "avg_proj, ctol", + [ + (True, (0.86, 0.93)), + (False, (0.97, 0.99)), + ], +) +@pytest.mark.parametrize( + "method, atol", + [ + pytest.param(None, 3e-6, marks=pytest.mark.slowtest), # slow on Azure + (dict(eeg="MNE"), 4e-6), + ], +) +@pytest.mark.filterwarnings("ignore:.*than 20 mm from head frame origin.*") def test_interpolation_eeg(offset, avg_proj, ctol, atol, method): """Test interpolation of EEG channels.""" - raw, epochs_eeg = _load_data('eeg') + raw, epochs_eeg = _load_data("eeg") epochs_eeg = epochs_eeg.copy() assert not _has_eeg_average_ref_proj(epochs_eeg.info) # Offsetting the coordinate frame should have no effect on the output for inst in (raw, epochs_eeg): - for ch in inst.info['chs']: - if ch['kind'] == io.constants.FIFF.FIFFV_EEG_CH: - ch['loc'][:3] += offset - ch['loc'][3:6] += offset - for d in inst.info['dig']: - d['r'] += offset + for ch in inst.info["chs"]: + if ch["kind"] == io.constants.FIFF.FIFFV_EEG_CH: + ch["loc"][:3] += offset + ch["loc"][3:6] += offset + for d in inst.info["dig"]: + d["r"] += offset # check that interpolation does nothing if no bads are marked - epochs_eeg.info['bads'] = [] + epochs_eeg.info["bads"] = [] evoked_eeg = epochs_eeg.average() kw = dict(method=method) - with pytest.warns(RuntimeWarning, match='Doing nothing'): + with pytest.warns(RuntimeWarning, match="Doing nothing"): evoked_eeg.interpolate_bads(**kw) # create good and bad channels for EEG - epochs_eeg.info['bads'] = [] + epochs_eeg.info["bads"] = [] goods_idx = np.ones(len(epochs_eeg.ch_names), dtype=bool) - goods_idx[epochs_eeg.ch_names.index('EEG 012')] = False + goods_idx[epochs_eeg.ch_names.index("EEG 012")] = False bads_idx = ~goods_idx pos = epochs_eeg._get_channel_positions() evoked_eeg = epochs_eeg.average() if avg_proj: evoked_eeg.set_eeg_reference(projection=True).apply_proj() - assert_allclose(evoked_eeg.data.mean(0), 0., atol=1e-20) + assert_allclose(evoked_eeg.data.mean(0), 0.0, atol=1e-20) ave_before = evoked_eeg.data[bads_idx] # interpolate bad channels for EEG - epochs_eeg.info['bads'] = ['EEG 012'] + epochs_eeg.info["bads"] = ["EEG 012"] evoked_eeg = epochs_eeg.average() if avg_proj: evoked_eeg.set_eeg_reference(projection=True).apply_proj() good_picks = pick_types(evoked_eeg.info, meg=False, eeg=True) - assert_allclose(evoked_eeg.data[good_picks].mean(0), 0., atol=1e-20) + assert_allclose(evoked_eeg.data[good_picks].mean(0), 0.0, atol=1e-20) evoked_eeg_bad = evoked_eeg.copy() bads_picks = pick_channels( - epochs_eeg.ch_names, include=epochs_eeg.info['bads'], ordered=True + epochs_eeg.ch_names, include=epochs_eeg.info["bads"], ordered=True ) evoked_eeg_bad.data[bads_picks, :] = 1e10 # Test first the exclude parameter evoked_eeg_2_bads = evoked_eeg_bad.copy() - evoked_eeg_2_bads.info['bads'] = ['EEG 004', 'EEG 012'] + evoked_eeg_2_bads.info["bads"] = ["EEG 004", "EEG 012"] evoked_eeg_2_bads.data[ - pick_channels(evoked_eeg_bad.ch_names, ['EEG 004', 'EEG 012']) + pick_channels(evoked_eeg_bad.ch_names, ["EEG 004", "EEG 012"]) ] = 1e10 evoked_eeg_interp = evoked_eeg_2_bads.interpolate_bads( - origin=(0., 0., 0.), exclude=['EEG 004'], **kw) - assert evoked_eeg_interp.info['bads'] == ['EEG 004'] - assert np.all(evoked_eeg_interp.get_data('EEG 004') == 1e10) - assert np.all(evoked_eeg_interp.get_data('EEG 012') != 1e10) + origin=(0.0, 0.0, 0.0), exclude=["EEG 004"], **kw + ) + assert evoked_eeg_interp.info["bads"] == ["EEG 004"] + assert np.all(evoked_eeg_interp.get_data("EEG 004") == 1e10) + assert np.all(evoked_eeg_interp.get_data("EEG 
012") != 1e10) # Now test without exclude parameter - evoked_eeg_bad.info['bads'] = ['EEG 012'] + evoked_eeg_bad.info["bads"] = ["EEG 012"] evoked_eeg_interp = evoked_eeg_bad.copy().interpolate_bads( - origin=(0., 0., 0.), **kw) + origin=(0.0, 0.0, 0.0), **kw + ) if avg_proj: - assert_allclose(evoked_eeg_interp.data.mean(0), 0., atol=1e-6) + assert_allclose(evoked_eeg_interp.data.mean(0), 0.0, atol=1e-6) interp_zero = evoked_eeg_interp.data[bads_idx] if method is None: # using pos_good = pos[goods_idx] @@ -136,7 +162,7 @@ def test_interpolation_eeg(offset, avg_proj, ctol, atol, method): # check that interpolation fails when preload is False epochs_eeg.preload = False - with pytest.raises(RuntimeError, match='requires epochs data to be loade'): + with pytest.raises(RuntimeError, match="requires epochs data to be loade"): epochs_eeg.interpolate_bads(**kw) epochs_eeg.preload = True @@ -148,10 +174,10 @@ def test_interpolation_eeg(offset, avg_proj, ctol, atol, method): # check that interpolation fails when preload is False for inst in [raw, epochs_eeg]: - assert hasattr(inst, 'preload') + assert hasattr(inst, "preload") inst.preload = False - inst.info['bads'] = [inst.ch_names[1]] - with pytest.raises(RuntimeError, match='requires.*data to be loaded'): + inst.info["bads"] = [inst.ch_names[1]] + with pytest.raises(RuntimeError, match="requires.*data to be loaded"): inst.interpolate_bads(**kw) # check that interpolation works with few channels @@ -159,11 +185,11 @@ def test_interpolation_eeg(offset, avg_proj, ctol, atol, method): raw_few.pick_channels(raw_few.ch_names[:1] + raw_few.ch_names[3:4]) assert len(raw_few.ch_names) == 2 raw_few.del_proj() - raw_few.info['bads'] = [raw_few.ch_names[-1]] + raw_few.info["bads"] = [raw_few.ch_names[-1]] orig_data = raw_few[1][0] with _record_warnings() as w: raw_few.interpolate_bads(reset_bads=False, **kw) - assert len([ww for ww in w if 'more than' not in str(ww.message)]) == 0 + assert len([ww for ww in w if "more than" not in str(ww.message)]) == 0 new_data = raw_few[1][0] assert (new_data == 0).mean() < 0.5 assert np.corrcoef(new_data, orig_data)[0, 1] > 0.2 @@ -176,82 +202,80 @@ def test_interpolation_meg(): # correlation drops thresh = 0.68 - raw, epochs_meg = _load_data('meg') + raw, epochs_meg = _load_data("meg") # check that interpolation works when non M/EEG channels are present # before MEG channels raw.crop(0, 0.1).load_data().pick_channels(epochs_meg.ch_names) raw.info.normalize_proj() - raw.set_channel_types({raw.ch_names[0]: 'stim'}, on_unit_change='ignore') - raw.info['bads'] = [raw.ch_names[1]] + raw.set_channel_types({raw.ch_names[0]: "stim"}, on_unit_change="ignore") + raw.info["bads"] = [raw.ch_names[1]] raw.load_data() - raw.interpolate_bads(mode='fast') + raw.interpolate_bads(mode="fast") del raw # check that interpolation works for MEG - epochs_meg.info['bads'] = ['MEG 0141'] + epochs_meg.info["bads"] = ["MEG 0141"] evoked = epochs_meg.average() - pick = pick_channels(epochs_meg.info['ch_names'], epochs_meg.info['bads']) + pick = pick_channels(epochs_meg.info["ch_names"], epochs_meg.info["bads"]) # MEG -- raw raw_meg = io.RawArray(data=epochs_meg._data[0], info=epochs_meg.info) - raw_meg.info['bads'] = ['MEG 0141'] + raw_meg.info["bads"] = ["MEG 0141"] data1 = raw_meg[pick, :][0][0] raw_meg.info.normalize_proj() - data2 = raw_meg.interpolate_bads(reset_bads=False, - mode='fast')[pick, :][0][0] + data2 = raw_meg.interpolate_bads(reset_bads=False, mode="fast")[pick, :][0][0] assert np.corrcoef(data1, data2)[0, 1] > thresh # the same 
number of bads as before
-    assert len(raw_meg.info['bads']) == len(raw_meg.info['bads'])
+    assert raw_meg.info["bads"] == ["MEG 0141"]
 
     # MEG -- epochs
     data1 = epochs_meg.get_data()[:, pick, :].ravel()
     epochs_meg.info.normalize_proj()
-    epochs_meg.interpolate_bads(mode='fast')
+    epochs_meg.interpolate_bads(mode="fast")
     data2 = epochs_meg.get_data()[:, pick, :].ravel()
     assert np.corrcoef(data1, data2)[0, 1] > thresh
-    assert len(epochs_meg.info['bads']) == 0
+    assert len(epochs_meg.info["bads"]) == 0
 
     # MEG -- evoked (plus auto origin)
     data1 = evoked.data[pick]
     evoked.info.normalize_proj()
-    data2 = evoked.interpolate_bads(origin='auto').data[pick]
+    data2 = evoked.interpolate_bads(origin="auto").data[pick]
     assert np.corrcoef(data1, data2)[0, 1] > thresh
 
     # MEG -- with exclude
-    evoked.info['bads'] = ['MEG 0141', 'MEG 0121']
-    pick = pick_channels(evoked.ch_names, evoked.info['bads'], ordered=True)
+    evoked.info["bads"] = ["MEG 0141", "MEG 0121"]
+    pick = pick_channels(evoked.ch_names, evoked.info["bads"], ordered=True)
     evoked.data[pick[-1]] = 1e10
     data1 = evoked.data[pick]
     evoked.info.normalize_proj()
-    data2 = evoked.interpolate_bads(
-        origin='auto', exclude=['MEG 0121']
-    ).data[pick]
+    data2 = evoked.interpolate_bads(origin="auto", exclude=["MEG 0121"]).data[pick]
     assert np.corrcoef(data1[0], data2[0])[0, 1] > thresh
     assert np.all(data2[1] == 1e10)
 
 
 def _this_interpol(inst, ref_meg=False):
     from mne.channels.interpolation import _interpolate_bads_meg
-    _interpolate_bads_meg(inst, ref_meg=ref_meg, mode='fast')
+
+    _interpolate_bads_meg(inst, ref_meg=ref_meg, mode="fast")
     return inst
 
 
 @pytest.mark.slowtest
 def test_interpolate_meg_ctf():
     """Test interpolation of MEG channels from CTF system."""
-    thresh = .85
-    tol = .05  # assert the new interpol correlates at least .05 "better"
-    bad = 'MLC22-2622'  # select a good channel to test the interpolation
+    thresh = 0.85
+    tol = 0.05  # assert the new interpol correlates at least .05 "better"
+    bad = "MLC22-2622"  # select a good channel to test the interpolation
     raw = io.read_raw_fif(raw_fname_ctf).crop(0, 1.0).load_data()  # 3 secs
     raw.apply_gradient_compensation(3)
 
     # Show that we have to exclude ref_meg for interpolating CTF MEG-channels
     # (fixed in #5965):
-    raw.info['bads'] = [bad]
-    pick_bad = pick_channels(raw.info['ch_names'], raw.info['bads'])
+    raw.info["bads"] = [bad]
+    pick_bad = pick_channels(raw.info["ch_names"], raw.info["bads"])
     data_orig = raw[pick_bad, :][0]
     # mimic old behavior (the ref_meg-arg in _interpolate_bads_meg only serves
     # this purpose):
@@ -260,12 +284,12 @@ def test_interpolate_meg_ctf():
     data_interp_no_refmeg = _this_interpol(raw, ref_meg=False)[pick_bad, :][0]
 
     R = dict()
-    R['no_refmeg'] = np.corrcoef(data_orig, data_interp_no_refmeg)[0, 1]
-    R['with_refmeg'] = np.corrcoef(data_orig, data_interp_refmeg)[0, 1]
+    R["no_refmeg"] = np.corrcoef(data_orig, data_interp_no_refmeg)[0, 1]
+    R["with_refmeg"] = np.corrcoef(data_orig, data_interp_refmeg)[0, 1]
 
-    print('Corrcoef of interpolated with original channel: ', R)
-    assert R['no_refmeg'] > R['with_refmeg'] + tol
-    assert R['no_refmeg'] > thresh
+    print("Corrcoef of interpolated with original channel: ", R)
+    assert R["no_refmeg"] > R["with_refmeg"] + tol
+    assert R["no_refmeg"] > thresh
 
 
 @testing.requires_testing_data
@@ -273,33 +297,30 @@ def test_interpolation_ctf_comp():
     """Test interpolation with compensated CTF data."""
     raw_fname = testing_path / "CTF" / "somMDYO-18av.ds"
     raw = io.read_raw_ctf(raw_fname, preload=True)
-    raw.info['bads'] = 
[raw.ch_names[5], raw.ch_names[-5]] - raw.interpolate_bads(mode='fast', origin=(0., 0., 0.04)) - assert raw.info['bads'] == [] + raw.info["bads"] = [raw.ch_names[5], raw.ch_names[-5]] + raw.interpolate_bads(mode="fast", origin=(0.0, 0.0, 0.04)) + assert raw.info["bads"] == [] -@requires_version('pymatreader') +@requires_version("pymatreader") @testing.requires_testing_data def test_interpolation_nirs(): """Test interpolating bad nirs channels.""" - fname = ( - testing_path / "NIRx" / "nirscout" / "nirx_15_2_recording_w_overlap" - ) + fname = testing_path / "NIRx" / "nirscout" / "nirx_15_2_recording_w_overlap" raw_intensity = read_raw_nirx(fname, preload=False) raw_od = optical_density(raw_intensity) sci = scalp_coupling_index(raw_od) - raw_od.info['bads'] = list(compress(raw_od.ch_names, sci < 0.5)) - bad_0 = np.where([name == raw_od.info['bads'][0] for - name in raw_od.ch_names])[0][0] + raw_od.info["bads"] = list(compress(raw_od.ch_names, sci < 0.5)) + bad_0 = np.where([name == raw_od.info["bads"][0] for name in raw_od.ch_names])[0][0] bad_0_std_pre_interp = np.std(raw_od._data[bad_0]) - bads_init = list(raw_od.info['bads']) + bads_init = list(raw_od.info["bads"]) raw_od.interpolate_bads(exclude=bads_init[:2]) - assert raw_od.info['bads'] == bads_init[:2] + assert raw_od.info["bads"] == bads_init[:2] raw_od.interpolate_bads() - assert raw_od.info['bads'] == [] + assert raw_od.info["bads"] == [] assert bad_0_std_pre_interp > np.std(raw_od._data[bad_0]) raw_haemo = beer_lambert_law(raw_od, ppf=6) - raw_haemo.info['bads'] = raw_haemo.ch_names[2:4] - assert raw_haemo.info['bads'] == ['S1_D2 hbo', 'S1_D2 hbr'] + raw_haemo.info["bads"] = raw_haemo.ch_names[2:4] + assert raw_haemo.info["bads"] == ["S1_D2 hbo", "S1_D2 hbr"] raw_haemo.interpolate_bads() - assert raw_haemo.info['bads'] == [] + assert raw_haemo.info["bads"] == [] diff --git a/mne/channels/tests/test_layout.py b/mne/channels/tests/test_layout.py index e17f90cafaf..2362fb2e23b 100644 --- a/mne/channels/tests/test_layout.py +++ b/mne/channels/tests/test_layout.py @@ -9,15 +9,23 @@ from pathlib import Path import numpy as np -from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_allclose, assert_equal) +from numpy.testing import ( + assert_array_almost_equal, + assert_array_equal, + assert_allclose, + assert_equal, +) import pytest import matplotlib.pyplot as plt -from mne.channels import (make_eeg_layout, make_grid_layout, read_layout, - find_layout, HEAD_SIZE_DEFAULT) -from mne.channels.layout import (_box_size, _find_topomap_coords, - generate_2d_layout) +from mne.channels import ( + make_eeg_layout, + make_grid_layout, + read_layout, + find_layout, + HEAD_SIZE_DEFAULT, +) +from mne.channels.layout import _box_size, _find_topomap_coords, generate_2d_layout from mne import pick_types, pick_info from mne.io import read_raw_kit, _empty_info, read_info from mne.io.constants import FIFF @@ -34,18 +42,50 @@ def _get_test_info(): """Make test info.""" test_info = _empty_info(1000) - loc = np.array([0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.], - dtype=np.float32) - test_info['chs'] = [ - {'cal': 1, 'ch_name': 'ICA 001', 'coil_type': 0, 'coord_frame': 0, - 'kind': 502, 'loc': loc.copy(), 'logno': 1, 'range': 1.0, 'scanno': 1, - 'unit': -1, 'unit_mul': 0}, - {'cal': 1, 'ch_name': 'ICA 002', 'coil_type': 0, 'coord_frame': 0, - 'kind': 502, 'loc': loc.copy(), 'logno': 2, 'range': 1.0, 'scanno': 2, - 'unit': -1, 'unit_mul': 0}, - {'cal': 0.002142000012099743, 'ch_name': 'EOG 061', 'coil_type': 1, - 'coord_frame': 
0, 'kind': 202, 'loc': loc.copy(), 'logno': 61, - 'range': 1.0, 'scanno': 376, 'unit': 107, 'unit_mul': 0}] + loc = np.array( + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], dtype=np.float32 + ) + test_info["chs"] = [ + { + "cal": 1, + "ch_name": "ICA 001", + "coil_type": 0, + "coord_frame": 0, + "kind": 502, + "loc": loc.copy(), + "logno": 1, + "range": 1.0, + "scanno": 1, + "unit": -1, + "unit_mul": 0, + }, + { + "cal": 1, + "ch_name": "ICA 002", + "coil_type": 0, + "coord_frame": 0, + "kind": 502, + "loc": loc.copy(), + "logno": 2, + "range": 1.0, + "scanno": 2, + "unit": -1, + "unit_mul": 0, + }, + { + "cal": 0.002142000012099743, + "ch_name": "EOG 061", + "coil_type": 1, + "coord_frame": 0, + "kind": 202, + "loc": loc.copy(), + "logno": 61, + "range": 1.0, + "scanno": 376, + "unit": 107, + "unit_mul": 0, + }, + ] test_info._unlocked = False test_info._update_redundant() test_info._check_consistency() @@ -57,7 +97,8 @@ def test_io_layout_lout(tmp_path): layout = read_layout(fname="Vectorview-all", scale=False) layout.save(tmp_path / "foobar.lout", overwrite=True) layout_read = read_layout( - fname=tmp_path / "foobar.lout", scale=False, + fname=tmp_path / "foobar.lout", + scale=False, ) assert_array_almost_equal(layout.pos, layout_read.pos, decimal=2) assert layout.names == layout_read.names @@ -66,15 +107,17 @@ def test_io_layout_lout(tmp_path): # deprecation with pytest.warns(DeprecationWarning, match="should not be provided"): layout_read = read_layout( - fname=tmp_path / "foobar.lout", kind="Vectorview-all", scale=False, + fname=tmp_path / "foobar.lout", + kind="Vectorview-all", + scale=False, ) with pytest.warns(DeprecationWarning, match="should not be provided"): layout_read = read_layout( - fname=tmp_path / "foobar.lout", path=None, scale=False, + fname=tmp_path / "foobar.lout", + path=None, + scale=False, ) - with pytest.warns( - DeprecationWarning, match="'kind' and 'path' are deprecated" - ): + with pytest.warns(DeprecationWarning, match="'kind' and 'path' are deprecated"): layout_read = read_layout(kind="Vectorview-all", scale=False) @@ -82,9 +125,7 @@ def test_io_layout_lay(tmp_path): """Test IO with .lay files.""" layout = read_layout(fname="CTF151", scale=False) layout.save(str(tmp_path / "foobar.lay")) - layout_read = read_layout( - fname=tmp_path / "foobar.lay", scale=False - ) + layout_read = read_layout(fname=tmp_path / "foobar.lay", scale=False) assert_array_almost_equal(layout.pos, layout_read.pos, decimal=2) assert layout.names == layout_read.names @@ -96,60 +137,57 @@ def test_find_topomap_coords(): # Remove extra digitization point, so EEG digitization points match up # with the EEG channels - del info['dig'][85] + del info["dig"][85] # Use channel locations - kwargs = dict(ignore_overlap=False, to_sphere=True, - sphere=HEAD_SIZE_DEFAULT) + kwargs = dict(ignore_overlap=False, to_sphere=True, sphere=HEAD_SIZE_DEFAULT) l0 = _find_topomap_coords(info, picks, **kwargs) # Remove electrode position information, use digitization points from now # on. 
- for ch in info['chs']: - ch['loc'].fill(np.nan) + for ch in info["chs"]: + ch["loc"].fill(np.nan) l1 = _find_topomap_coords(info, picks, **kwargs) assert_allclose(l1, l0, atol=1e-3) - for z_pt in ((HEAD_SIZE_DEFAULT, 0., 0.), - (0., HEAD_SIZE_DEFAULT, 0.)): - info['dig'][-1]['r'] = np.array(z_pt) + for z_pt in ((HEAD_SIZE_DEFAULT, 0.0, 0.0), (0.0, HEAD_SIZE_DEFAULT, 0.0)): + info["dig"][-1]["r"] = np.array(z_pt) l1 = _find_topomap_coords(info, picks, **kwargs) - assert_allclose(l1[-1], z_pt[:2], err_msg='Z=0 point moved', atol=1e-6) + assert_allclose(l1[-1], z_pt[:2], err_msg="Z=0 point moved", atol=1e-6) # Test plotting mag topomap without channel locations: it should fail - mag_picks = pick_types(info, meg='mag') - with pytest.raises(ValueError, match='Cannot determine location'): + mag_picks = pick_types(info, meg="mag") + with pytest.raises(ValueError, match="Cannot determine location"): _find_topomap_coords(info, mag_picks, **kwargs) # Test function with too many EEG digitization points: it should fail - info['dig'].append({'r': [1, 2, 3], 'kind': FIFF.FIFFV_POINT_EEG}) - with pytest.raises(ValueError, match='Number of EEG digitization points'): + info["dig"].append({"r": [1, 2, 3], "kind": FIFF.FIFFV_POINT_EEG}) + with pytest.raises(ValueError, match="Number of EEG digitization points"): _find_topomap_coords(info, picks, **kwargs) # Test function with too little EEG digitization points: it should fail info._unlocked = True - info['dig'] = info['dig'][:-2] - with pytest.raises(ValueError, match='Number of EEG digitization points'): + info["dig"] = info["dig"][:-2] + with pytest.raises(ValueError, match="Number of EEG digitization points"): _find_topomap_coords(info, picks, **kwargs) # Electrode positions must be unique - info['dig'].append(info['dig'][-1]) - with pytest.raises(ValueError, match='overlapping positions'): + info["dig"].append(info["dig"][-1]) + with pytest.raises(ValueError, match="overlapping positions"): _find_topomap_coords(info, picks, **kwargs) # Test function without EEG digitization points: it should fail - info['dig'] = [d for d in info['dig'] - if d['kind'] != FIFF.FIFFV_POINT_EEG] - with pytest.raises(RuntimeError, match='Did not find any digitization'): + info["dig"] = [d for d in info["dig"] if d["kind"] != FIFF.FIFFV_POINT_EEG] + with pytest.raises(RuntimeError, match="Did not find any digitization"): _find_topomap_coords(info, picks, **kwargs) # Test function without any digitization points, it should fail - info['dig'] = None - with pytest.raises(RuntimeError, match='No digitization points found'): + info["dig"] = None + with pytest.raises(RuntimeError, match="No digitization points found"): _find_topomap_coords(info, picks, **kwargs) - info['dig'] = [] - with pytest.raises(RuntimeError, match='No digitization points found'): + info["dig"] = [] + with pytest.raises(RuntimeError, match="No digitization points found"): _find_topomap_coords(info, picks, **kwargs) @@ -200,92 +238,92 @@ def test_make_grid_layout(tmp_path): def test_find_layout(): """Test finding layout.""" - pytest.raises(ValueError, find_layout, _get_test_info(), ch_type='meep') + pytest.raises(ValueError, find_layout, _get_test_info(), ch_type="meep") sample_info = read_info(fif_fname) - grads = pick_types(sample_info, meg='grad') + grads = pick_types(sample_info, meg="grad") sample_info2 = pick_info(sample_info, grads) - mags = pick_types(sample_info, meg='mag') + mags = pick_types(sample_info, meg="mag") sample_info3 = pick_info(sample_info, mags) # mock new convention sample_info4 = 
copy.deepcopy(sample_info) - for ii, name in enumerate(sample_info4['ch_names']): - new = name.replace(' ', '') - sample_info4['chs'][ii]['ch_name'] = new + for ii, name in enumerate(sample_info4["ch_names"]): + new = name.replace(" ", "") + sample_info4["chs"][ii]["ch_name"] = new eegs = pick_types(sample_info, meg=False, eeg=True) sample_info5 = pick_info(sample_info, eegs) lout = find_layout(sample_info, ch_type=None) - assert lout.kind == 'Vectorview-all' - assert all(' ' in k for k in lout.names) + assert lout.kind == "Vectorview-all" + assert all(" " in k for k in lout.names) - lout = find_layout(sample_info2, ch_type='meg') - assert_equal(lout.kind, 'Vectorview-all') + lout = find_layout(sample_info2, ch_type="meg") + assert_equal(lout.kind, "Vectorview-all") # test new vector-view lout = find_layout(sample_info4, ch_type=None) - assert_equal(lout.kind, 'Vectorview-all') - assert all(' ' not in k for k in lout.names) + assert_equal(lout.kind, "Vectorview-all") + assert all(" " not in k for k in lout.names) - lout = find_layout(sample_info, ch_type='grad') - assert_equal(lout.kind, 'Vectorview-grad') + lout = find_layout(sample_info, ch_type="grad") + assert_equal(lout.kind, "Vectorview-grad") lout = find_layout(sample_info2) - assert_equal(lout.kind, 'Vectorview-grad') - lout = find_layout(sample_info2, ch_type='grad') - assert_equal(lout.kind, 'Vectorview-grad') - lout = find_layout(sample_info2, ch_type='meg') - assert_equal(lout.kind, 'Vectorview-all') - - lout = find_layout(sample_info, ch_type='mag') - assert_equal(lout.kind, 'Vectorview-mag') + assert_equal(lout.kind, "Vectorview-grad") + lout = find_layout(sample_info2, ch_type="grad") + assert_equal(lout.kind, "Vectorview-grad") + lout = find_layout(sample_info2, ch_type="meg") + assert_equal(lout.kind, "Vectorview-all") + + lout = find_layout(sample_info, ch_type="mag") + assert_equal(lout.kind, "Vectorview-mag") lout = find_layout(sample_info3) - assert_equal(lout.kind, 'Vectorview-mag') - lout = find_layout(sample_info3, ch_type='mag') - assert_equal(lout.kind, 'Vectorview-mag') - lout = find_layout(sample_info3, ch_type='meg') - assert_equal(lout.kind, 'Vectorview-all') - - lout = find_layout(sample_info, ch_type='eeg') - assert_equal(lout.kind, 'EEG') + assert_equal(lout.kind, "Vectorview-mag") + lout = find_layout(sample_info3, ch_type="mag") + assert_equal(lout.kind, "Vectorview-mag") + lout = find_layout(sample_info3, ch_type="meg") + assert_equal(lout.kind, "Vectorview-all") + + lout = find_layout(sample_info, ch_type="eeg") + assert_equal(lout.kind, "EEG") lout = find_layout(sample_info5) - assert_equal(lout.kind, 'EEG') - lout = find_layout(sample_info5, ch_type='eeg') - assert_equal(lout.kind, 'EEG') + assert_equal(lout.kind, "EEG") + lout = find_layout(sample_info5, ch_type="eeg") + assert_equal(lout.kind, "EEG") # no common layout, 'meg' option not supported lout = find_layout(read_info(fname_ctf_raw)) - assert_equal(lout.kind, 'CTF-275') + assert_equal(lout.kind, "CTF-275") fname_bti_raw = bti_dir / "exported4D_linux_raw.fif" lout = find_layout(read_info(fname_bti_raw)) - assert_equal(lout.kind, 'magnesWH3600') + assert_equal(lout.kind, "magnesWH3600") raw_kit = read_raw_kit(fname_kit_157) lout = find_layout(raw_kit.info) - assert_equal(lout.kind, 'KIT-157') + assert_equal(lout.kind, "KIT-157") - raw_kit.info['bads'] = ['MEG 013', 'MEG 014', 'MEG 015', 'MEG 016'] + raw_kit.info["bads"] = ["MEG 013", "MEG 014", "MEG 015", "MEG 016"] raw_kit.info._check_consistency() lout = find_layout(raw_kit.info) - 
assert_equal(lout.kind, 'KIT-157') + assert_equal(lout.kind, "KIT-157") # fallback for missing IDs for val in (35, 52, 54, 1001): with raw_kit.info._unlock(): - raw_kit.info['kit_system_id'] = val + raw_kit.info["kit_system_id"] = val lout = find_layout(raw_kit.info) - assert lout.kind == 'custom' + assert lout.kind == "custom" raw_umd = read_raw_kit(fname_kit_umd) lout = find_layout(raw_umd.info) - assert_equal(lout.kind, 'KIT-UMD-3') + assert_equal(lout.kind, "KIT-UMD-3") # Test plotting lout.plot() lout.plot(picks=np.arange(10)) - plt.close('all') + plt.close("all") def test_box_size(): @@ -357,7 +395,7 @@ def test_generate_2d_layout(): sbg = 15 side = range(snobg) bg_image = np.random.RandomState(42).randn(sbg, sbg) - w, h = [.2, .5] + w, h = [0.2, 0.5] # Generate fake data xy = np.array([(i, j) for i in side for j in side]) @@ -367,9 +405,10 @@ def test_generate_2d_layout(): comp_1, comp_2 = [(5, 0), (7, 0)] assert lt.pos[:, :2].max() == 1 assert lt.pos[:, :2].min() == 0 - with np.errstate(invalid='ignore'): # divide by zero - assert_allclose(xy[comp_2] / float(xy[comp_1]), - lt.pos[comp_2] / float(lt.pos[comp_1])) + with np.errstate(invalid="ignore"): # divide by zero + assert_allclose( + xy[comp_2] / float(xy[comp_1]), lt.pos[comp_2] / float(lt.pos[comp_1]) + ) assert_allclose(lt.pos[0, [2, 3]], [w, h]) # Correct number elements diff --git a/mne/channels/tests/test_montage.py b/mne/channels/tests/test_montage.py index f78e6bb3f2d..ca7347b21ad 100644 --- a/mne/channels/tests/test_montage.py +++ b/mne/channels/tests/test_montage.py @@ -14,37 +14,60 @@ from functools import partial from string import ascii_lowercase -from numpy.testing import (assert_array_equal, assert_array_less, - assert_allclose, assert_equal) +from numpy.testing import ( + assert_array_equal, + assert_array_less, + assert_allclose, + assert_equal, +) import matplotlib.pyplot as plt from mne import __file__ as _mne_file, create_info, read_evokeds, pick_types from mne.source_space import get_mni_fiducials from mne.utils._testing import assert_object_equal -from mne.channels import (get_builtin_montages, DigMontage, read_dig_dat, - read_dig_egi, read_dig_captrak, read_dig_fif, - make_standard_montage, read_custom_montage, - compute_dev_head_t, make_dig_montage, - read_dig_polhemus_isotrak, compute_native_head_t, - read_polhemus_fastscan, read_dig_localite, - read_dig_hpts) +from mne.channels import ( + get_builtin_montages, + DigMontage, + read_dig_dat, + read_dig_egi, + read_dig_captrak, + read_dig_fif, + make_standard_montage, + read_custom_montage, + compute_dev_head_t, + make_dig_montage, + read_dig_polhemus_isotrak, + compute_native_head_t, + read_polhemus_fastscan, + read_dig_localite, + read_dig_hpts, +) from mne.channels.montage import ( - transform_to_head, _check_get_coord_frame, _BUILTIN_STANDARD_MONTAGES + transform_to_head, + _check_get_coord_frame, + _BUILTIN_STANDARD_MONTAGES, ) from mne.preprocessing import compute_current_source_density from mne.utils import assert_dig_allclose, _record_warnings from mne.bem import _fit_sphere from mne.io.constants import FIFF -from mne.io._digitization import (_format_dig_points, - _get_fid_coords, _get_dig_eeg, - _count_points_by_type) -from mne.transforms import (_ensure_trans, apply_trans, invert_transform, - _get_trans) +from mne.io._digitization import ( + _format_dig_points, + _get_fid_coords, + _get_dig_eeg, + _count_points_by_type, +) +from mne.transforms import _ensure_trans, apply_trans, invert_transform, _get_trans from mne.viz._3d import _fiducial_coords 
from mne.io.kit import read_mrk -from mne.io import (read_raw_brainvision, read_raw_egi, read_raw_fif, - read_fiducials, read_raw_nirx) +from mne.io import ( + read_raw_brainvision, + read_raw_egi, + read_raw_fif, + read_fiducials, + read_raw_nirx, +) from mne.io import RawArray from mne.datasets import testing @@ -88,41 +111,43 @@ def _make_toy_raw(n_channels): return RawArray( data=np.empty([n_channels, 1]), info=create_info( - ch_names=list(ascii_lowercase[:n_channels]), - sfreq=1, ch_types='eeg' - ) + ch_names=list(ascii_lowercase[:n_channels]), sfreq=1, ch_types="eeg" + ), ) def _make_toy_dig_montage(n_channels, **kwargs): return make_dig_montage( - ch_pos=dict(zip( - list(ascii_lowercase[:n_channels]), - np.arange(n_channels * 3).reshape(n_channels, 3), - )), - **kwargs + ch_pos=dict( + zip( + list(ascii_lowercase[:n_channels]), + np.arange(n_channels * 3).reshape(n_channels, 3), + ) + ), + **kwargs, ) def _get_dig_montage_pos(montage): - return np.array([d['r'] for d in _get_dig_eeg(montage.dig)]) + return np.array([d["r"] for d in _get_dig_eeg(montage.dig)]) def test_dig_montage_trans(tmp_path): """Test getting a trans from and applying a trans to a montage.""" nasion, lpa, rpa, *ch_pos = np.random.RandomState(0).randn(10, 3) - ch_pos = {f'EEG{ii:3d}': pos for ii, pos in enumerate(ch_pos, 1)} - montage = make_dig_montage(ch_pos, nasion=nasion, lpa=lpa, rpa=rpa, - coord_frame='mri') + ch_pos = {f"EEG{ii:3d}": pos for ii, pos in enumerate(ch_pos, 1)} + montage = make_dig_montage( + ch_pos, nasion=nasion, lpa=lpa, rpa=rpa, coord_frame="mri" + ) trans = compute_native_head_t(montage) _ensure_trans(trans) # ensure that we can save and load it, too - fname = tmp_path / 'temp-mon.fif' - _check_roundtrip(montage, fname, 'mri') + fname = tmp_path / "temp-mon.fif" + _check_roundtrip(montage, fname, "mri") # test applying a trans position1 = montage.get_positions() montage.apply_trans(trans) - assert montage.get_positions()['coord_frame'] == 'head' + assert montage.get_positions()["coord_frame"] == "head" montage.apply_trans(invert_transform(trans)) position2 = montage.get_positions() assert str(position1) == str(position2) # exactly equal @@ -137,300 +162,356 @@ def test_fiducials(): points = _fiducial_coords(fids, coord_frame) assert points.shape == (3, 3) # Fids - assert_allclose(points[:, 2], 0., atol=1e-6) - assert_allclose(points[::2, 1], 0., atol=1e-6) + assert_allclose(points[:, 2], 0.0, atol=1e-6) + assert_allclose(points[::2, 1], 0.0, atol=1e-6) assert points[2, 0] > 0 # RPA assert points[0, 0] < 0 # LPA # Nasion - assert_allclose(points[1, 0], 0., atol=1e-6) + assert_allclose(points[1, 0], 0.0, atol=1e-6) assert points[1, 1] > 0 def test_documented(): """Test that standard montages are documented.""" - montage_dir = Path(_mne_file).parent / 'channels' / 'data' / 'montages' - montage_files = Path(montage_dir).glob('*') + montage_dir = Path(_mne_file).parent / "channels" / "data" / "montages" + montage_files = Path(montage_dir).glob("*") montage_names = [f.stem for f in montage_files] assert len(montage_names) == len(_BUILTIN_STANDARD_MONTAGES) - assert set(montage_names) == set( - [m.name for m in _BUILTIN_STANDARD_MONTAGES] - ) - - -@pytest.mark.parametrize('reader, file_content, expected_dig, ext, warning', [ - pytest.param( - partial(read_custom_montage, head_size=None), - ('FidNz 0 9.071585155 -2.359754454\n' - 'FidT9 -6.711765 0.040402876 -3.251600355\n' - 'very_very_very_long_name -5.831241498 -4.494821698 4.955347697\n' - 'Cz 0 0 1\n' - 'Cz 0 0 8.899186843'), - 
make_dig_montage( - ch_pos={ - 'very_very_very_long_name': [-5.8312416, -4.4948215, 4.9553475], # noqa - 'Cz': [0., 0., 8.899187], - }, - nasion=[0., 9.071585, -2.3597546], - lpa=[-6.711765, 0.04040287, -3.2516003], - rpa=None, + assert set(montage_names) == set([m.name for m in _BUILTIN_STANDARD_MONTAGES]) + + +@pytest.mark.parametrize( + "reader, file_content, expected_dig, ext, warning", + [ + pytest.param( + partial(read_custom_montage, head_size=None), + ( + "FidNz 0 9.071585155 -2.359754454\n" + "FidT9 -6.711765 0.040402876 -3.251600355\n" + "very_very_very_long_name -5.831241498 -4.494821698 4.955347697\n" + "Cz 0 0 1\n" + "Cz 0 0 8.899186843" + ), + make_dig_montage( + ch_pos={ + "very_very_very_long_name": [ + -5.8312416, + -4.4948215, + 4.9553475, + ], # noqa + "Cz": [0.0, 0.0, 8.899187], + }, + nasion=[0.0, 9.071585, -2.3597546], + lpa=[-6.711765, 0.04040287, -3.2516003], + rpa=None, + ), + "sfp", + (RuntimeWarning, r"Duplicate.*last will be used for Cz \(2\)"), + id="sfp_duplicate", ), - 'sfp', - (RuntimeWarning, r'Duplicate.*last will be used for Cz \(2\)'), - id='sfp_duplicate'), - - pytest.param( - partial(read_custom_montage, head_size=None), - ('FidNz 0 9.071585155 -2.359754454\n' - 'FidT9 -6.711765 0.040402876 -3.251600355\n' - 'headshape 1 2 3\n' - 'headshape 4 5 6\n' - 'Cz 0 0 8.899186843'), - make_dig_montage( - hsp=[ - [1, 2, 3], - [4, 5, 6], - ], - ch_pos={ - 'Cz': [0., 0., 8.899187], - }, - nasion=[0., 9.071585, -2.3597546], - lpa=[-6.711765, 0.04040287, -3.2516003], - rpa=None, + pytest.param( + partial(read_custom_montage, head_size=None), + ( + "FidNz 0 9.071585155 -2.359754454\n" + "FidT9 -6.711765 0.040402876 -3.251600355\n" + "headshape 1 2 3\n" + "headshape 4 5 6\n" + "Cz 0 0 8.899186843" + ), + make_dig_montage( + hsp=[ + [1, 2, 3], + [4, 5, 6], + ], + ch_pos={ + "Cz": [0.0, 0.0, 8.899187], + }, + nasion=[0.0, 9.071585, -2.3597546], + lpa=[-6.711765, 0.04040287, -3.2516003], + rpa=None, + ), + "sfp", + None, + id="sfp_headshape", ), - 'sfp', - None, - id='sfp_headshape'), - - pytest.param( - partial(read_custom_montage, head_size=1), - ('1 0 0.50669 FPz\n' - '2 23 0.71 EOG1\n' - '3 -39.947 0.34459 F3\n' - '4 0 0.25338 Fz\n'), - make_dig_montage( - ch_pos={ - 'EOG1': [0.30873816, 0.72734152, -0.61290705], - 'F3': [-0.56705965, 0.67706631, 0.46906776], - 'FPz': [0., 0.99977915, -0.02101571], - 'Fz': [0., 0.71457525, 0.69955859], - }, - nasion=None, lpa=None, rpa=None, coord_frame='head', + pytest.param( + partial(read_custom_montage, head_size=1), + ( + "1 0 0.50669 FPz\n" + "2 23 0.71 EOG1\n" + "3 -39.947 0.34459 F3\n" + "4 0 0.25338 Fz\n" + ), + make_dig_montage( + ch_pos={ + "EOG1": [0.30873816, 0.72734152, -0.61290705], + "F3": [-0.56705965, 0.67706631, 0.46906776], + "FPz": [0.0, 0.99977915, -0.02101571], + "Fz": [0.0, 0.71457525, 0.69955859], + }, + nasion=None, + lpa=None, + rpa=None, + coord_frame="head", + ), + "loc", + None, + id="EEGLAB", ), - 'loc', - None, - id='EEGLAB'), - - pytest.param( - partial(read_custom_montage, head_size=None, coord_frame='mri'), - "// MatLab Sphere coordinates [degrees] Cartesian coordinates\n" # noqa: E501 - "// Label Theta Phi Radius X Y Z off sphere surface\n" # noqa: E501 - "E1 37.700 -14.000 1.000 0.7677 0.5934 -0.2419 -0.00000000000000011\n" # noqa: E501 - "E3 51.700 11.000 1.000 0.6084 0.7704 0.1908 0.00000000000000000\n" # noqa: E501 - "E31 90.000 -11.000 1.000 0.0000 0.9816 -0.1908 0.00000000000000000\n" # noqa: E501 - "E61 158.000 -17.200 1.000 -0.8857 0.3579 -0.2957 -0.00000000000000022", # noqa: E501 - 
make_dig_montage( - ch_pos={ - 'E1': [0.7677, 0.5934, -0.2419], - 'E3': [0.6084, 0.7704, 0.1908], - 'E31': [0., 0.9816, -0.1908], - 'E61': [-0.8857, 0.3579, -0.2957], - }, - nasion=None, lpa=None, rpa=None, coord_frame='mri', + pytest.param( + partial(read_custom_montage, head_size=None, coord_frame="mri"), + "// MatLab Sphere coordinates [degrees] Cartesian coordinates\n" # noqa: E501 + "// Label Theta Phi Radius X Y Z off sphere surface\n" # noqa: E501 + "E1 37.700 -14.000 1.000 0.7677 0.5934 -0.2419 -0.00000000000000011\n" # noqa: E501 + "E3 51.700 11.000 1.000 0.6084 0.7704 0.1908 0.00000000000000000\n" # noqa: E501 + "E31 90.000 -11.000 1.000 0.0000 0.9816 -0.1908 0.00000000000000000\n" # noqa: E501 + "E61 158.000 -17.200 1.000 -0.8857 0.3579 -0.2957 -0.00000000000000022", # noqa: E501 + make_dig_montage( + ch_pos={ + "E1": [0.7677, 0.5934, -0.2419], + "E3": [0.6084, 0.7704, 0.1908], + "E31": [0.0, 0.9816, -0.1908], + "E61": [-0.8857, 0.3579, -0.2957], + }, + nasion=None, + lpa=None, + rpa=None, + coord_frame="mri", + ), + "csd", + None, + id="matlab", ), - 'csd', - None, - id='matlab'), - - pytest.param( - partial(read_custom_montage, head_size=None), - ('# ASA electrode file\nReferenceLabel avg\nUnitPosition mm\n' - 'NumberPositions= 68\n' - 'Positions\n' - '-86.0761 -19.9897 -47.9860\n' - '85.7939 -20.0093 -48.0310\n' - '0.0083 86.8110 -39.9830\n' - '-86.0761 -24.9897 -67.9860\n' - 'Labels\nLPA\nRPA\nNz\nDummy\n'), - make_dig_montage( - ch_pos={ - 'Dummy': [-0.0860761, -0.0249897, -0.067986], - }, - nasion=[8.3000e-06, 8.6811e-02, -3.9983e-02], - lpa=[-0.0860761, -0.0199897, -0.047986], - rpa=[0.0857939, -0.0200093, -0.048031], + pytest.param( + partial(read_custom_montage, head_size=None), + ( + "# ASA electrode file\nReferenceLabel avg\nUnitPosition mm\n" + "NumberPositions= 68\n" + "Positions\n" + "-86.0761 -19.9897 -47.9860\n" + "85.7939 -20.0093 -48.0310\n" + "0.0083 86.8110 -39.9830\n" + "-86.0761 -24.9897 -67.9860\n" + "Labels\nLPA\nRPA\nNz\nDummy\n" + ), + make_dig_montage( + ch_pos={ + "Dummy": [-0.0860761, -0.0249897, -0.067986], + }, + nasion=[8.3000e-06, 8.6811e-02, -3.9983e-02], + lpa=[-0.0860761, -0.0199897, -0.047986], + rpa=[0.0857939, -0.0200093, -0.048031], + ), + "elc", + None, + id="ASA electrode", ), - 'elc', - None, - id='ASA electrode'), - - pytest.param( - partial(read_custom_montage, head_size=1), - ('Site Theta Phi\n' - 'Fp1 -92 -72\n' - 'Fp2 92 72\n' - 'very_very_very_long_name -92 72\n' - 'O2 92 -90\n'), - make_dig_montage( - ch_pos={ - 'Fp1': [-0.30882875, 0.95047716, -0.0348995], - 'Fp2': [0.30882875, 0.95047716, -0.0348995], - 'very_very_very_long_name': [-0.30882875, -0.95047716, -0.0348995], # noqa - 'O2': [6.11950389e-17, -9.99390827e-01, -3.48994967e-02] - }, - nasion=None, lpa=None, rpa=None, + pytest.param( + partial(read_custom_montage, head_size=1), + ( + "Site Theta Phi\n" + "Fp1 -92 -72\n" + "Fp2 92 72\n" + "very_very_very_long_name -92 72\n" + "O2 92 -90\n" + ), + make_dig_montage( + ch_pos={ + "Fp1": [-0.30882875, 0.95047716, -0.0348995], + "Fp2": [0.30882875, 0.95047716, -0.0348995], + "very_very_very_long_name": [ + -0.30882875, + -0.95047716, + -0.0348995, + ], # noqa + "O2": [6.11950389e-17, -9.99390827e-01, -3.48994967e-02], + }, + nasion=None, + lpa=None, + rpa=None, + ), + "txt", + None, + id="generic theta-phi (txt)", ), - 'txt', - None, - id='generic theta-phi (txt)'), - - pytest.param( - partial(read_custom_montage, head_size=None), - ('FID\t LPA\t -120.03\t 0\t 85\n' - 'FID\t RPA\t 120.03\t 0\t 85\n' - 'FID\t Nz\t 114.03\t 90\t 
85\n' - 'EEG\t F3\t -62.027\t -50.053\t 85\n' - 'EEG\t Fz\t 45.608\t 90\t 85\n' - 'EEG\t F4\t 62.01\t 50.103\t 85\n' - 'EEG\t FCz\t 68.01\t 58.103\t 85\n'), - make_dig_montage( - ch_pos={ - 'F3': [-0.48200427, 0.57551063, 0.39869712], - 'Fz': [3.71915931e-17, 6.07384809e-01, 5.94629038e-01], - 'F4': [0.48142596, 0.57584026, 0.39891983], - 'FCz': [0.41645989, 0.66914889, 0.31827805], - }, - nasion=[4.75366562e-17, 7.76332511e-01, -3.46132681e-01], - lpa=[-7.35898963e-01, 9.01216309e-17, -4.25385374e-01], - rpa=[0.73589896, 0., -0.42538537], + pytest.param( + partial(read_custom_montage, head_size=None), + ( + "FID\t LPA\t -120.03\t 0\t 85\n" + "FID\t RPA\t 120.03\t 0\t 85\n" + "FID\t Nz\t 114.03\t 90\t 85\n" + "EEG\t F3\t -62.027\t -50.053\t 85\n" + "EEG\t Fz\t 45.608\t 90\t 85\n" + "EEG\t F4\t 62.01\t 50.103\t 85\n" + "EEG\t FCz\t 68.01\t 58.103\t 85\n" + ), + make_dig_montage( + ch_pos={ + "F3": [-0.48200427, 0.57551063, 0.39869712], + "Fz": [3.71915931e-17, 6.07384809e-01, 5.94629038e-01], + "F4": [0.48142596, 0.57584026, 0.39891983], + "FCz": [0.41645989, 0.66914889, 0.31827805], + }, + nasion=[4.75366562e-17, 7.76332511e-01, -3.46132681e-01], + lpa=[-7.35898963e-01, 9.01216309e-17, -4.25385374e-01], + rpa=[0.73589896, 0.0, -0.42538537], + ), + "elp", + None, + id="BESA spherical model", ), - 'elp', - None, - id='BESA spherical model'), - - pytest.param( - partial(read_dig_hpts, unit='m'), - ('eeg Fp1 -95.0 -3. -3.\n' - 'eeg AF7 -1 -1 -3\n' - 'eeg A3 -2 -2 2\n' - 'eeg A 0 0 0'), - make_dig_montage( - ch_pos={ - 'A': [0., 0., 0.], 'A3': [-2., -2., 2.], - 'AF7': [-1., -1., -3.], 'Fp1': [-95., -3., -3.], - }, - nasion=None, lpa=None, rpa=None, + pytest.param( + partial(read_dig_hpts, unit="m"), + ( + "eeg Fp1 -95.0 -3. -3.\n" + "eeg AF7 -1 -1 -3\n" + "eeg A3 -2 -2 2\n" + "eeg A 0 0 0" + ), + make_dig_montage( + ch_pos={ + "A": [0.0, 0.0, 0.0], + "A3": [-2.0, -2.0, 2.0], + "AF7": [-1.0, -1.0, -3.0], + "Fp1": [-95.0, -3.0, -3.0], + }, + nasion=None, + lpa=None, + rpa=None, + ), + "hpts", + None, + id="legacy mne-c", ), - 'hpts', - None, - id='legacy mne-c'), - - pytest.param( - read_custom_montage, - ('ch_name, x, y, z\n' - 'Fp1, -95.0, -3., -3.\n' - 'AF7, -1, -1, -3\n' - 'A3, -2, -2, 2\n' - 'A, 0, 0, 0'), - make_dig_montage( - ch_pos={ - 'A': [0., 0., 0.], 'A3': [-2., -2., 2.], - 'AF7': [-1., -1., -3.], 'Fp1': [-95., -3., -3.], - }, - nasion=None, lpa=None, rpa=None, + pytest.param( + read_custom_montage, + ( + "ch_name, x, y, z\n" + "Fp1, -95.0, -3., -3.\n" + "AF7, -1, -1, -3\n" + "A3, -2, -2, 2\n" + "A, 0, 0, 0" + ), + make_dig_montage( + ch_pos={ + "A": [0.0, 0.0, 0.0], + "A3": [-2.0, -2.0, 2.0], + "AF7": [-1.0, -1.0, -3.0], + "Fp1": [-95.0, -3.0, -3.0], + }, + nasion=None, + lpa=None, + rpa=None, + ), + "csv", + None, + id="CSV file", ), - 'csv', - None, - id='CSV file'), - - pytest.param( - read_custom_montage, - ('1\t-95.0\t-3.\t-3.\tFp1\n' - '2\t-1\t-1\t-3\tAF7\n' - '3\t-2\t-2\t2\tA3\n' - '4\t0\t0\t0\tA'), - make_dig_montage( - ch_pos={ - 'A': [0., 0., 0.], 'A3': [-2., -2., 2.], - 'AF7': [-1., -1., -3.], 'Fp1': [-95., -3., -3.], - }, - nasion=None, lpa=None, rpa=None, + pytest.param( + read_custom_montage, + ( + "1\t-95.0\t-3.\t-3.\tFp1\n" + "2\t-1\t-1\t-3\tAF7\n" + "3\t-2\t-2\t2\tA3\n" + "4\t0\t0\t0\tA" + ), + make_dig_montage( + ch_pos={ + "A": [0.0, 0.0, 0.0], + "A3": [-2.0, -2.0, 2.0], + "AF7": [-1.0, -1.0, -3.0], + "Fp1": [-95.0, -3.0, -3.0], + }, + nasion=None, + lpa=None, + rpa=None, + ), + "xyz", + None, + id="XYZ file", ), - 'xyz', - None, - id='XYZ file'), - - pytest.param( 
-        read_custom_montage,
-        ('ch_name\tx\ty\tz\n'
-         'Fp1\t-95.0\t-3.\t-3.\n'
-         'AF7\t-1\t-1\t-3\n'
-         'A3\t-2\t-2\t2\n'
-         'A\t0\t0\t0'),
-        make_dig_montage(
-            ch_pos={
-                'A': [0., 0., 0.], 'A3': [-2., -2., 2.],
-                'AF7': [-1., -1., -3.], 'Fp1': [-95., -3., -3.],
-            },
-            nasion=None, lpa=None, rpa=None,
+        pytest.param(
+            read_custom_montage,
+            (
+                "ch_name\tx\ty\tz\n"
+                "Fp1\t-95.0\t-3.\t-3.\n"
+                "AF7\t-1\t-1\t-3\n"
+                "A3\t-2\t-2\t2\n"
+                "A\t0\t0\t0"
+            ),
+            make_dig_montage(
+                ch_pos={
+                    "A": [0.0, 0.0, 0.0],
+                    "A3": [-2.0, -2.0, 2.0],
+                    "AF7": [-1.0, -1.0, -3.0],
+                    "Fp1": [-95.0, -3.0, -3.0],
+                },
+                nasion=None,
+                lpa=None,
+                rpa=None,
+            ),
+            "tsv",
+            None,
+            id="TSV file",
         ),
-        'tsv',
-        None,
-        id='TSV file'),
-
-    pytest.param(
-        partial(read_custom_montage, head_size=None),
-        ('<?xml version="1.0" encoding="UTF-8" standalone="yes"?>\n'
-         '<!-- Generated by EasyCap Configurator 19.05.2014 -->\n'
-         '<Electrodes defaults="false">\n'
-         '    <Electrode>\n'
-         '        <Name>Fp1</Name>\n'
-         '        <Theta>-90</Theta>\n'
-         '        <Phi>-72</Phi>\n'
-         '        <Radius>1</Radius>\n'
-         '        <Number>1</Number>\n'
-         '    </Electrode>\n'
-         '    <Electrode>\n'
-         '        <Name>Fz</Name>\n'
-         '        <Theta>45</Theta>\n'
-         '        <Phi>90</Phi>\n'
-         '        <Radius>1</Radius>\n'
-         '        <Number>2</Number>\n'
-         '    </Electrode>\n'
-         '    <Electrode>\n'
-         '        <Name>F3</Name>\n'
-         '        <Theta>-60</Theta>\n'
-         '        <Phi>-51</Phi>\n'
-         '        <Radius>1</Radius>\n'
-         '        <Number>3</Number>\n'
-         '    </Electrode>\n'
-         '    <Electrode>\n'
-         '        <Name>F7</Name>\n'
-         '        <Theta>-90</Theta>\n'
-         '        <Phi>-36</Phi>\n'
-         '        <Radius>1</Radius>\n'
-         '        <Number>4</Number>\n'
-         '    </Electrode>\n'
-         '</Electrodes>'),
-        make_dig_montage(
-            ch_pos={
-                'Fp1': [-3.09016994e-01, 9.51056516e-01, 6.12323400e-17],
-                'Fz': [4.32978028e-17, 7.07106781e-01, 7.07106781e-01],
-                'F3': [-0.54500745, 0.67302815, 0.5],
-                'F7': [-8.09016994e-01, 5.87785252e-01, 6.12323400e-17],
-            },
-            nasion=None, lpa=None, rpa=None,
+        pytest.param(
+            partial(read_custom_montage, head_size=None),
+            (
+                '<?xml version="1.0" encoding="UTF-8" standalone="yes"?>\n'
+                "<!-- Generated by EasyCap Configurator 19.05.2014 -->\n"
+                '<Electrodes defaults="false">\n'
+                "    <Electrode>\n"
+                "        <Name>Fp1</Name>\n"
+                "        <Theta>-90</Theta>\n"
+                "        <Phi>-72</Phi>\n"
+                "        <Radius>1</Radius>\n"
+                "        <Number>1</Number>\n"
+                "    </Electrode>\n"
+                "    <Electrode>\n"
+                "        <Name>Fz</Name>\n"
+                "        <Theta>45</Theta>\n"
+                "        <Phi>90</Phi>\n"
+                "        <Radius>1</Radius>\n"
+                "        <Number>2</Number>\n"
+                "    </Electrode>\n"
+                "    <Electrode>\n"
+                "        <Name>F3</Name>\n"
+                "        <Theta>-60</Theta>\n"
+                "        <Phi>-51</Phi>\n"
+                "        <Radius>1</Radius>\n"
+                "        <Number>3</Number>\n"
+                "    </Electrode>\n"
+                "    <Electrode>\n"
+                "        <Name>F7</Name>\n"
+                "        <Theta>-90</Theta>\n"
+                "        <Phi>-36</Phi>\n"
+                "        <Radius>1</Radius>\n"
+                "        <Number>4</Number>\n"
+                "    </Electrode>\n"
+                "</Electrodes>"
+            ),
+            make_dig_montage(
+                ch_pos={
+                    "Fp1": [-3.09016994e-01, 9.51056516e-01, 6.12323400e-17],
+                    "Fz": [4.32978028e-17, 7.07106781e-01, 7.07106781e-01],
+                    "F3": [-0.54500745, 0.67302815, 0.5],
+                    "F7": [-8.09016994e-01, 5.87785252e-01, 6.12323400e-17],
+                },
+                nasion=None,
+                lpa=None,
+                rpa=None,
+            ),
+            "bvef",
+            None,
+            id="brainvision",
        ),
-        'bvef',
-        None,
-        id='brainvision'),
-])
-def test_montage_readers(
-    reader, file_content, expected_dig, ext, warning, tmp_path
-):
+    ],
+)
+def test_montage_readers(reader, file_content, expected_dig, ext, warning, tmp_path):
     """Test that we have an equivalent of read_montage for all file formats."""
     fname = tmp_path / f"test.{ext}"
-    with open(fname, 'w') as fid:
+    with open(fname, "w") as fid:
         fid.write(file_content)
 
     if warning is None:
@@ -447,15 +528,15 @@ def test_montage_readers(
         assert_allclose(actual_ch_pos[kk], expected_ch_pos[kk], atol=1e-5)
     assert len(dig_montage.dig) == len(expected_dig.dig)
     for d1, d2 in zip(dig_montage.dig, expected_dig.dig):
-        assert d1['coord_frame'] == d2['coord_frame']
-        for key in ('coord_frame', 'ident', 'kind'):
+        assert d1["coord_frame"] == d2["coord_frame"]
+        for key in ("coord_frame", "ident", "kind"):
             assert isinstance(d1[key], int)
             assert isinstance(d2[key], int)
     with _record_warnings() as w:
         xform = compute_native_head_t(dig_montage)
-    assert xform['to'] == FIFF.FIFFV_COORD_HEAD
-    assert xform['from'] == FIFF.FIFFV_COORD_UNKNOWN
-    n = int(np.allclose(xform['trans'], np.eye(4)))
+    assert xform["to"] == FIFF.FIFFV_COORD_HEAD
+    assert xform["from"] == FIFF.FIFFV_COORD_UNKNOWN
+    n = int(np.allclose(xform["trans"], np.eye(4)))
     assert len(w) == n
 
 
@@ -465,32 +546,34 @@ def test_read_locs():
     """Test reading EEGLAB locs file."""
     data = read_custom_montage(locs_montage_fname)._get_ch_pos()
     assert_allclose(
         actual=np.stack(
-            [data[kk] for kk in ('FPz', 'EOG1', 'F3', 
'Fz')] # 4 random chs + [data[kk] for kk in ("FPz", "EOG1", "F3", "Fz")] # 4 random chs ), - desired=[[0., 0.094979, -0.001996], - [0.02933, 0.069097, -0.058226], - [-0.053871, 0.064321, 0.044561], - [0., 0.067885, 0.066458]], - atol=1e-6 + desired=[ + [0.0, 0.094979, -0.001996], + [0.02933, 0.069097, -0.058226], + [-0.053871, 0.064321, 0.044561], + [0.0, 0.067885, 0.066458], + ], + atol=1e-6, ) def test_read_dig_dat(tmp_path): """Test reading *.dat electrode locations.""" rows = [ - ['Nasion', 78, 0.00, 1.00, 0.00], - ['Left', 76, -1.00, 0.00, 0.00], - ['Right', 82, 1.00, -0.00, 0.00], - ['O2', 69, -0.50, -0.90, 0.05], - ['O2', 68, 0.00, 0.01, 0.02], - ['Centroid', 67, 0.00, 0.00, 0.00], + ["Nasion", 78, 0.00, 1.00, 0.00], + ["Left", 76, -1.00, 0.00, 0.00], + ["Right", 82, 1.00, -0.00, 0.00], + ["O2", 69, -0.50, -0.90, 0.05], + ["O2", 68, 0.00, 0.01, 0.02], + ["Centroid", 67, 0.00, 0.00, 0.00], ] # write mock test.dat file fname_temp = tmp_path / "test.dat" - with open(fname_temp, 'w') as fid: + with open(fname_temp, "w") as fid: for row in rows: name = row[0].rjust(10) - data = '\t'.join(map(str, row[1:])) + data = "\t".join(map(str, row[1:])) fid.write("%s\t%s\n" % (name, data)) # construct expected value idents = { @@ -507,15 +590,21 @@ def test_read_dig_dat(tmp_path): 69: FIFF.FIFFV_POINT_EEG, 68: FIFF.FIFFV_POINT_EEG, } - target = {row[0]: {'r': row[2:], 'ident': idents[row[1]], - 'kind': kinds[row[1]], 'coord_frame': 0} - for row in rows[:-1]} - assert_allclose(target['O2']['r'], [0, 0.01, 0.02]) + target = { + row[0]: { + "r": row[2:], + "ident": idents[row[1]], + "kind": kinds[row[1]], + "coord_frame": 0, + } + for row in rows[:-1] + } + assert_allclose(target["O2"]["r"], [0, 0.01, 0.02]) # read it - with pytest.warns(RuntimeWarning, match=r'Duplic.*for O2 \(2\)'): + with pytest.warns(RuntimeWarning, match=r"Duplic.*for O2 \(2\)"): dig = read_dig_dat(fname_temp) - assert set(dig.ch_names) == {'O2'} - keys = chain(['Left', 'Nasion', 'Right'], dig.ch_names) + assert set(dig.ch_names) == {"O2"} + keys = chain(["Left", "Nasion", "Right"], dig.ch_names) target = [target[k] for k in keys] assert dig.dig == target @@ -526,32 +615,29 @@ def test_read_dig_montage_using_polhemus_fastscan(): my_electrode_positions = read_polhemus_fastscan(kit_dir / "test_elp.txt") montage = make_dig_montage( # EEG_CH - ch_pos=dict(zip(ascii_lowercase[:N_EEG_CH], - np.random.RandomState(0).rand(N_EEG_CH, 3))), + ch_pos=dict( + zip(ascii_lowercase[:N_EEG_CH], np.random.RandomState(0).rand(N_EEG_CH, 3)) + ), # NO NAMED points nasion=my_electrode_positions[0], lpa=my_electrode_positions[1], rpa=my_electrode_positions[2], hpi=my_electrode_positions[3:], hsp=read_polhemus_fastscan(kit_dir / "test_hsp.txt"), - # Other defaults - coord_frame='unknown' + coord_frame="unknown", ) assert repr(montage) == ( - '' + "" ) - assert set([d['coord_frame'] for d in montage.dig]) == { - FIFF.FIFFV_COORD_UNKNOWN - } + assert set([d["coord_frame"] for d in montage.dig]) == {FIFF.FIFFV_COORD_UNKNOWN} EXPECTED_FID_IN_POLHEMUS = { - 'nasion': [0.001393, 0.0131613, -0.0046967], - 'lpa': [-0.0624997, -0.0737271, 0.07996], - 'rpa': [-0.0748957, 0.0873785, 0.0811943], + "nasion": [0.001393, 0.0131613, -0.0046967], + "lpa": [-0.0624997, -0.0737271, 0.07996], + "rpa": [-0.0748957, 0.0873785, 0.0811943], } fiducials, fid_coordframe = _get_fid_coords(montage.dig) assert fid_coordframe == FIFF.FIFFV_COORD_UNKNOWN @@ -562,17 +648,17 @@ def test_read_dig_montage_using_polhemus_fastscan(): def 
test_read_dig_montage_using_polhemus_fastscan_error_handling(tmp_path): """Test reading Polhemus FastSCAN errors.""" with open(kit_dir / "test_elp.txt") as fid: - content = fid.read().replace('FastSCAN', 'XxxxXXXX') + content = fid.read().replace("FastSCAN", "XxxxXXXX") - fname = tmp_path / 'faulty_FastSCAN.txt' - with open(fname, 'w') as fid: + fname = tmp_path / "faulty_FastSCAN.txt" + with open(fname, "w") as fid: fid.write(content) - with pytest.raises(ValueError, match='not contain.*Polhemus FastSCAN'): + with pytest.raises(ValueError, match="not contain.*Polhemus FastSCAN"): _ = read_polhemus_fastscan(fname) - fname = tmp_path / 'faulty_FastSCAN.bar' - with open(fname, 'w') as fid: + fname = tmp_path / "faulty_FastSCAN.bar" + with open(fname, "w") as fid: fid.write(content) EXPECTED_ERR_MSG = "allowed value is '.txt', but got '.bar' instead" with pytest.raises(ValueError, match=EXPECTED_ERR_MSG): @@ -582,16 +668,13 @@ def test_read_dig_montage_using_polhemus_fastscan_error_handling(tmp_path): def test_read_dig_polhemus_isotrak_hsp(): """Test reading Polhemus IsoTrak HSP file.""" EXPECTED_FID_IN_POLHEMUS = { - 'nasion': np.array([1.1056e-01, -5.4210e-19, 0]), - 'lpa': np.array([-2.1075e-04, 8.0793e-02, -7.5894e-19]), - 'rpa': np.array([2.1075e-04, -8.0793e-02, -2.8731e-18]), + "nasion": np.array([1.1056e-01, -5.4210e-19, 0]), + "lpa": np.array([-2.1075e-04, 8.0793e-02, -7.5894e-19]), + "rpa": np.array([2.1075e-04, -8.0793e-02, -2.8731e-18]), } - montage = read_dig_polhemus_isotrak( - fname=kit_dir / "test.hsp", ch_names=None - ) + montage = read_dig_polhemus_isotrak(fname=kit_dir / "test.hsp", ch_names=None) assert repr(montage) == ( - '' + "" ) fiducials, fid_coordframe = _get_fid_coords(montage.dig) @@ -604,16 +687,13 @@ def test_read_dig_polhemus_isotrak_hsp(): def test_read_dig_polhemus_isotrak_elp(): """Test reading Polhemus IsoTrak ELP file.""" EXPECTED_FID_IN_POLHEMUS = { - 'nasion': np.array([1.1056e-01, -5.4210e-19, 0]), - 'lpa': np.array([-2.1075e-04, 8.0793e-02, -7.5894e-19]), - 'rpa': np.array([2.1075e-04, -8.0793e-02, -2.8731e-18]), + "nasion": np.array([1.1056e-01, -5.4210e-19, 0]), + "lpa": np.array([-2.1075e-04, 8.0793e-02, -7.5894e-19]), + "rpa": np.array([2.1075e-04, -8.0793e-02, -2.8731e-18]), } - montage = read_dig_polhemus_isotrak( - fname=kit_dir / "test.elp", ch_names=None - ) + montage = read_dig_polhemus_isotrak(fname=kit_dir / "test.elp", ch_names=None) assert repr(montage) == ( - '' + "" ) fiducials, fid_coordframe = _get_fid_coords(montage.dig) @@ -622,35 +702,39 @@ def test_read_dig_polhemus_isotrak_elp(): assert_array_equal(val, EXPECTED_FID_IN_POLHEMUS[kk]) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def isotrak_eeg(tmp_path_factory): """Mock isotrak file with EEG positions.""" _SEED = 42 N_ROWS, N_COLS = 5, 3 content = np.random.RandomState(_SEED).randn(N_ROWS, N_COLS) - fname = tmp_path_factory.mktemp('data') / 'test.eeg' - with open(str(fname), 'w') as fid: - fid.write(( - '3 200\n' - '//Shape file\n' - '//Minor revision number\n' - '2\n' - '//Subject Name\n' - '%N Name \n' - '////Shape code, number of digitized points\n' - )) - fid.write('0 {rows:d}\n'.format(rows=N_ROWS)) - fid.write(( - '//Position of fiducials X+, Y+, Y- on the subject\n' - '%F 0.11056 -5.421e-19 0 \n' - '%F -0.00021075 0.080793 -7.5894e-19 \n' - '%F 0.00021075 -0.080793 -2.8731e-18 \n' - '//No of rows, no of columns; position of digitized points\n' - )) - fid.write('{rows:d} {cols:d}\n'.format(rows=N_ROWS, cols=N_COLS)) + fname = 
tmp_path_factory.mktemp("data") / "test.eeg" + with open(str(fname), "w") as fid: + fid.write( + ( + "3 200\n" + "//Shape file\n" + "//Minor revision number\n" + "2\n" + "//Subject Name\n" + "%N Name \n" + "////Shape code, number of digitized points\n" + ) + ) + fid.write("0 {rows:d}\n".format(rows=N_ROWS)) + fid.write( + ( + "//Position of fiducials X+, Y+, Y- on the subject\n" + "%F 0.11056 -5.421e-19 0 \n" + "%F -0.00021075 0.080793 -7.5894e-19 \n" + "%F 0.00021075 -0.080793 -2.8731e-18 \n" + "//No of rows, no of columns; position of digitized points\n" + ) + ) + fid.write("{rows:d} {cols:d}\n".format(rows=N_ROWS, cols=N_COLS)) for row in content: - fid.write('\t'.join('%0.18e' % cell for cell in row) + '\n') + fid.write("\t".join("%0.18e" % cell for cell in row) + "\n") return str(fname) @@ -660,18 +744,18 @@ def test_read_dig_polhemus_isotrak_eeg(isotrak_eeg): N_CHANNELS = 5 _SEED = 42 EXPECTED_FID_IN_POLHEMUS = { - 'nasion': np.array([1.1056e-01, -5.4210e-19, 0]), - 'lpa': np.array([-2.1075e-04, 8.0793e-02, -7.5894e-19]), - 'rpa': np.array([2.1075e-04, -8.0793e-02, -2.8731e-18]), + "nasion": np.array([1.1056e-01, -5.4210e-19, 0]), + "lpa": np.array([-2.1075e-04, 8.0793e-02, -7.5894e-19]), + "rpa": np.array([2.1075e-04, -8.0793e-02, -2.8731e-18]), } - ch_names = ['eeg {:01d}'.format(ii) for ii in range(N_CHANNELS)] - EXPECTED_CH_POS = dict(zip( - ch_names, np.random.RandomState(_SEED).randn(N_CHANNELS, 3))) + ch_names = ["eeg {:01d}".format(ii) for ii in range(N_CHANNELS)] + EXPECTED_CH_POS = dict( + zip(ch_names, np.random.RandomState(_SEED).randn(N_CHANNELS, 3)) + ) montage = read_dig_polhemus_isotrak(fname=isotrak_eeg, ch_names=ch_names) assert repr(montage) == ( - '' + "" ) fiducials, fid_coordframe = _get_fid_coords(montage.dig) @@ -681,8 +765,8 @@ def test_read_dig_polhemus_isotrak_eeg(isotrak_eeg): assert_array_equal(val, EXPECTED_FID_IN_POLHEMUS[kk]) for kk, dig_point in zip(montage.ch_names, _get_dig_eeg(montage.dig)): - assert_array_equal(dig_point['r'], EXPECTED_CH_POS[kk]) - assert dig_point['coord_frame'] == FIFF.FIFFV_COORD_UNKNOWN + assert_array_equal(dig_point["r"], EXPECTED_CH_POS[kk]) + assert dig_point["coord_frame"] == FIFF.FIFFV_COORD_UNKNOWN def test_read_dig_polhemus_isotrak_error_handling(isotrak_eeg, tmp_path): @@ -697,7 +781,7 @@ def test_read_dig_polhemus_isotrak_error_handling(isotrak_eeg, tmp_path): with pytest.raises(ValueError, match=EXPECTED_ERR_MSG): _ = read_dig_polhemus_isotrak( fname=isotrak_eeg, - ch_names=['eeg {:01d}'.format(ii) for ii in range(N_CHANNELS + 42)] + ch_names=["eeg {:01d}".format(ii) for ii in range(N_CHANNELS + 42)], ) # Check fname extensions @@ -706,7 +790,7 @@ def test_read_dig_polhemus_isotrak_error_handling(isotrak_eeg, tmp_path): with pytest.raises( ValueError, - match="Allowed val.*'.hsp', '.elp', and '.eeg', but got '.bar' instead" + match="Allowed val.*'.hsp', '.elp', and '.eeg', but got '.bar' instead", ): _ = read_dig_polhemus_isotrak(fname=fname, ch_names=None) @@ -714,52 +798,64 @@ def test_read_dig_polhemus_isotrak_error_handling(isotrak_eeg, tmp_path): def test_combining_digmontage_objects(): """Test combining different DigMontage objects.""" rng = np.random.RandomState(0) - fiducials = dict(zip(('nasion', 'lpa', 'rpa'), rng.rand(3, 3))) + fiducials = dict(zip(("nasion", "lpa", "rpa"), rng.rand(3, 3))) # hsp positions are [1X, 1X, 1X] - hsp1 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 11.)) - hsp2 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 12.)) - hsp3 = make_dig_montage(**fiducials, hsp=np.full((2, 
3), 13.))
+    hsp1 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 11.0))
+    hsp2 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 12.0))
+    hsp3 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 13.0))
     # hpi positions are [2X, 2X, 2X]
-    hpi1 = make_dig_montage(**fiducials, hpi=np.full((2, 3), 21.))
-    hpi2 = make_dig_montage(**fiducials, hpi=np.full((2, 3), 22.))
-    hpi3 = make_dig_montage(**fiducials, hpi=np.full((2, 3), 23.))
+    hpi1 = make_dig_montage(**fiducials, hpi=np.full((2, 3), 21.0))
+    hpi2 = make_dig_montage(**fiducials, hpi=np.full((2, 3), 22.0))
+    hpi3 = make_dig_montage(**fiducials, hpi=np.full((2, 3), 23.0))
     # channels have positions at 40s, 50s, and 60s.
     ch_pos1 = make_dig_montage(
-        **fiducials,
-        ch_pos={'h': [41, 41, 41], 'b': [42, 42, 42], 'g': [43, 43, 43]}
+        **fiducials, ch_pos={"h": [41, 41, 41], "b": [42, 42, 42], "g": [43, 43, 43]}
     )
     ch_pos2 = make_dig_montage(
-        **fiducials,
-        ch_pos={'n': [51, 51, 51], 'y': [52, 52, 52], 'p': [53, 53, 53]}
+        **fiducials, ch_pos={"n": [51, 51, 51], "y": [52, 52, 52], "p": [53, 53, 53]}
    )
     ch_pos3 = make_dig_montage(
-        **fiducials,
-        ch_pos={'v': [61, 61, 61], 'a': [62, 62, 62], 'l': [63, 63, 63]}
+        **fiducials, ch_pos={"v": [61, 61, 61], "a": [62, 62, 62], "l": [63, 63, 63]}
     )
 
     montage = (
-        DigMontage() + hsp1 + hsp2 + hsp3 + hpi1 + hpi2 + hpi3 + ch_pos1 +
-        ch_pos2 + ch_pos3
+        DigMontage()
+        + hsp1
+        + hsp2
+        + hsp3
+        + hpi1
+        + hpi2
+        + hpi3
+        + ch_pos1
+        + ch_pos2
+        + ch_pos3
     )
     assert repr(montage) == (
-        '<DigMontage | 6 extras (headshape), 6 HPIs, 3 fiducials, 9 channels>'
+        "<DigMontage | 6 extras (headshape), 6 HPIs, 3 fiducials, 9 channels>"
     )
 
     EXPECTED_MONTAGE = make_dig_montage(
         **fiducials,
-        hsp=np.concatenate([np.full((2, 3), 11.), np.full((2, 3), 12.),
-                            np.full((2, 3), 13.)]),
-        hpi=np.concatenate([np.full((2, 3), 21.), np.full((2, 3), 22.),
-                            np.full((2, 3), 23.)]),
+        hsp=np.concatenate(
+            [np.full((2, 3), 11.0), np.full((2, 3), 12.0), np.full((2, 3), 13.0)]
+        ),
+        hpi=np.concatenate(
+            [np.full((2, 3), 21.0), np.full((2, 3), 22.0), np.full((2, 3), 23.0)]
+        ),
         ch_pos={
-            'h': [41, 41, 41], 'b': [42, 42, 42], 'g': [43, 43, 43],
-            'n': [51, 51, 51], 'y': [52, 52, 52], 'p': [53, 53, 53],
-            'v': [61, 61, 61], 'a': [62, 62, 62], 'l': [63, 63, 63],
-        }
+            "h": [41, 41, 41],
+            "b": [42, 42, 42],
+            "g": [43, 43, 43],
+            "n": [51, 51, 51],
+            "y": [52, 52, 52],
+            "p": [53, 53, 53],
+            "v": [61, 61, 61],
+            "a": [62, 62, 62],
+            "l": [63, 63, 63],
+        },
     )
 
     # Do some checks to ensure they are the same DigMontage
@@ -773,33 +869,33 @@ def test_combining_digmontage_forbiden_behaviors():
     """Test combining different DigMontage objects with repeated names."""
     rng = np.random.RandomState(0)
-    fiducials = dict(zip(('nasion', 'lpa', 'rpa'), rng.rand(3, 3)))
+    fiducials = dict(zip(("nasion", "lpa", "rpa"), rng.rand(3, 3)))
     dig1 = make_dig_montage(
         **fiducials,
-        ch_pos=dict(zip(list('abc'), rng.rand(3, 3))),
+        ch_pos=dict(zip(list("abc"), rng.rand(3, 3))),
     )
     dig2 = make_dig_montage(
         **fiducials,
-        ch_pos=dict(zip(list('bcd'), rng.rand(3, 3))),
+        ch_pos=dict(zip(list("bcd"), rng.rand(3, 3))),
     )
     dig2_wrong_fid = make_dig_montage(
-        nasion=rng.rand(3), lpa=rng.rand(3), rpa=rng.rand(3),
-        ch_pos=dict(zip(list('ghi'), rng.rand(3, 3))),
+        nasion=rng.rand(3),
+        lpa=rng.rand(3),
+        rpa=rng.rand(3),
+        ch_pos=dict(zip(list("ghi"), rng.rand(3, 3))),
     )
     dig2_wrong_coordframe = make_dig_montage(
-        **fiducials,
-        ch_pos=dict(zip(list('ghi'), rng.rand(3, 3))),
-        coord_frame='meg'
+        **fiducials, ch_pos=dict(zip(list("ghi"), rng.rand(3, 3))), coord_frame="meg"
     )
 
-    EXPECTED_ERR_MSG = "Cannot.*duplicated channel.*found: \'b\', \'c\'."
+    EXPECTED_ERR_MSG = "Cannot.*duplicated channel.*found: 'b', 'c'."
     with pytest.raises(RuntimeError, match=EXPECTED_ERR_MSG):
         _ = dig1 + dig2
 
-    with pytest.raises(RuntimeError, match='fiducial locations do not match'):
+    with pytest.raises(RuntimeError, match="fiducial locations do not match"):
         _ = dig1 + dig2_wrong_fid
 
-    with pytest.raises(RuntimeError, match='not in the same coordinate '):
+    with pytest.raises(RuntimeError, match="not in the same coordinate "):
         _ = dig1 + dig2_wrong_coordframe
 
 
@@ -807,45 +903,57 @@ def test_set_dig_montage():
     """Test setting DigMontage with toy understandable points."""
     N_CHANNELS, N_HSP, N_HPI = 3, 2, 1
     ch_names = list(ascii_lowercase[:N_CHANNELS])
-    ch_pos = dict(zip(
-        ch_names,
-        np.arange(N_CHANNELS * 3).reshape(N_CHANNELS, 3),
-    ))
+    ch_pos = dict(
+        zip(
+            ch_names,
+            np.arange(N_CHANNELS * 3).reshape(N_CHANNELS, 3),
+        )
+    )
 
-    montage_ch_only = make_dig_montage(ch_pos=ch_pos, coord_frame='head')
+    montage_ch_only = make_dig_montage(ch_pos=ch_pos, coord_frame="head")
     assert repr(montage_ch_only) == (
-        '<DigMontage | 0 extras (headshape), 0 HPIs, 0 fiducials, 3 channels>'
+        "<DigMontage | 0 extras (headshape), 0 HPIs, 0 fiducials, 3 channels>"
     )
-    info = create_info(ch_names, sfreq=1, ch_types='eeg')
+    info = create_info(ch_names, sfreq=1, ch_types="eeg")
     info.set_montage(montage_ch_only)
-    assert len(info['dig']) == len(montage_ch_only.dig) + 3  # added fiducials
+    assert len(info["dig"]) == len(montage_ch_only.dig) + 3  # added fiducials
 
-    assert_allclose(actual=np.array([ch['loc'][:6] for ch in info['chs']]),
-                    desired=[[0., 1., 2., 0., 0., 0.],
-                             [3., 4., 5., 0., 0., 0.],
-                             [6., 7., 8., 0., 0., 0.]])
+    assert_allclose(
+        actual=np.array([ch["loc"][:6] for ch in info["chs"]]),
+        desired=[
+            [0.0, 1.0, 2.0, 0.0, 0.0, 0.0],
+            [3.0, 4.0, 5.0, 0.0, 0.0, 0.0],
+            [6.0, 7.0, 8.0, 0.0, 0.0, 0.0],
+        ],
+    )
 
     montage_full = make_dig_montage(
         ch_pos=dict(**ch_pos, EEG000=np.full(3, 42)),  # 4 = 3 egg + 1 eeg_ref
-        nasion=[1, 1, 1], lpa=[2, 2, 2], rpa=[3, 3, 3],
+        nasion=[1, 1, 1],
+        lpa=[2, 2, 2],
+        rpa=[3, 3, 3],
         hsp=np.full((N_HSP, 3), 4),
         hpi=np.full((N_HPI, 3), 4),
-        coord_frame='head'
+        coord_frame="head",
     )
     assert repr(montage_full) == (
-        '<DigMontage | 2 extras (headshape), 1 HPIs, 3 fiducials, 4 channels>'
+        "<DigMontage | 2 extras (headshape), 1 HPIs, 3 fiducials, 4 channels>"
    )
 
-    info = create_info(ch_names, sfreq=1, ch_types='eeg')
+    info = create_info(ch_names, sfreq=1, ch_types="eeg")
     info.set_montage(montage_full)
-    EXPECTED_LEN = sum({'hsp': 2, 'hpi': 1, 'fid': 3, 'eeg': 4}.values())
-    assert len(info['dig']) == EXPECTED_LEN
-    assert_allclose(actual=np.array([ch['loc'][:6] for ch in info['chs']]),
-                    desired=[[0., 1., 2., 42., 42., 42.],
-                             [3., 4., 5., 42., 42., 42.],
-                             [6., 7., 8., 42., 42., 42.]])
+    EXPECTED_LEN = sum({"hsp": 2, "hpi": 1, "fid": 3, "eeg": 4}.values())
+    assert len(info["dig"]) == EXPECTED_LEN
+    assert_allclose(
+        actual=np.array([ch["loc"][:6] for ch in info["chs"]]),
+        desired=[
+            [0.0, 1.0, 2.0, 42.0, 42.0, 42.0],
+            [3.0, 4.0, 5.0, 42.0, 42.0, 42.0],
+            [6.0, 7.0, 8.0, 42.0, 42.0, 42.0],
+        ],
+    )
 
 
 def test_set_dig_montage_with_nan_positions():
@@ -854,10 +962,11 @@ def test_set_dig_montage_with_nan_positions():
     Test that setting a montage with some NaN positions does not produce
     NaN fiducials.
""" + def _ensure_fid_not_nan(info, ch_pos): - montage_kwargs = dict(ch_pos=dict(), coord_frame='head') + montage_kwargs = dict(ch_pos=dict(), coord_frame="head") for ch_idx, ch in enumerate(info.ch_names): - montage_kwargs['ch_pos'][ch] = ch_pos[ch_idx] + montage_kwargs["ch_pos"][ch] = ch_pos[ch_idx] new_montage = make_dig_montage(**montage_kwargs) info = info.copy() @@ -865,7 +974,8 @@ def _ensure_fid_not_nan(info, ch_pos): recovered_montage = info.get_montage() fid_coords, coord_frame = _get_fid_coords( - recovered_montage.dig, raise_error=False) + recovered_montage.dig, raise_error=False + ) for fid_coord in fid_coords.values(): if fid_coord is not None: @@ -873,21 +983,20 @@ def _ensure_fid_not_nan(info, ch_pos): return fid_coords, coord_frame - channels = list('ABCDEF') - info = create_info(channels, 1000, ch_types='seeg') + channels = list("ABCDEF") + info = create_info(channels, 1000, ch_types="seeg") # if all positions are NaN, the fiducials should not be NaN, but None - ch_pos = [info['chs'][ch_idx]['loc'][:3] - for ch_idx in range(len(channels))] + ch_pos = [info["chs"][ch_idx]["loc"][:3] for ch_idx in range(len(channels))] fid_coords, coord_frame = _ensure_fid_not_nan(info, ch_pos) for fid_coord in fid_coords.values(): assert fid_coord is None assert coord_frame is None # if some positions are not NaN, the fiducials should be a non-NaN array - ch_pos[0] = np.array([1., 1.5, 1.]) - ch_pos[1] = np.array([2., 1.5, 1.5]) - ch_pos[2] = np.array([1.25, 1., 1.25]) + ch_pos[0] = np.array([1.0, 1.5, 1.0]) + ch_pos[1] = np.array([2.0, 1.5, 1.5]) + ch_pos[2] = np.array([1.25, 1.0, 1.25]) fid_coords, coord_frame = _ensure_fid_not_nan(info, ch_pos) for fid_coord in fid_coords.values(): assert isinstance(fid_coord, np.ndarray) @@ -908,14 +1017,14 @@ def test_fif_dig_montage(tmp_path): raw_bv_2 = raw_bv.copy() mapping = dict() for ii, ch_name in enumerate(raw_bv.ch_names): - mapping[ch_name] = 'EEG%03d' % (ii + 1,) + mapping[ch_name] = "EEG%03d" % (ii + 1,) raw_bv.rename_channels(mapping) for ii, ch_name in enumerate(raw_bv_2.ch_names): - mapping[ch_name] = 'EEG%03d' % (ii + 33,) + mapping[ch_name] = "EEG%03d" % (ii + 33,) raw_bv_2.rename_channels(mapping) raw_bv.add_channels([raw_bv_2]) - for ch in raw_bv.info['chs']: - ch['kind'] = FIFF.FIFFV_EEG_CH + for ch in raw_bv.info["chs"]: + ch["kind"] = FIFF.FIFFV_EEG_CH # Set the montage raw_bv.set_montage(dig_montage) @@ -925,33 +1034,30 @@ def test_fif_dig_montage(tmp_path): # check info[chs] matches assert_equal(len(raw_bv.ch_names), len(evoked.ch_names) - 1) - for ch_py, ch_c in zip(raw_bv.info['chs'], evoked.info['chs'][:-1]): - assert_equal(ch_py['ch_name'], - ch_c['ch_name'].replace('EEG ', 'EEG')) + for ch_py, ch_c in zip(raw_bv.info["chs"], evoked.info["chs"][:-1]): + assert_equal(ch_py["ch_name"], ch_c["ch_name"].replace("EEG ", "EEG")) # C actually says it's unknown, but it's not (?): # assert_equal(ch_py['coord_frame'], ch_c['coord_frame']) - assert_equal(ch_py['coord_frame'], FIFF.FIFFV_COORD_HEAD) - c_loc = ch_c['loc'].copy() + assert_equal(ch_py["coord_frame"], FIFF.FIFFV_COORD_HEAD) + c_loc = ch_c["loc"].copy() c_loc[c_loc == 0] = np.nan - assert_allclose(ch_py['loc'], c_loc, atol=1e-7) + assert_allclose(ch_py["loc"], c_loc, atol=1e-7) # check info[dig] assert_dig_allclose(raw_bv.info, evoked.info) # Roundtrip of non-FIF start - montage = make_dig_montage(hsp=read_polhemus_fastscan(hsp), - hpi=read_mrk(hpi)) + montage = make_dig_montage(hsp=read_polhemus_fastscan(hsp), hpi=read_mrk(hpi)) elp_points = read_polhemus_fastscan(elp) 
ch_pos = {"EEG%03d" % (k + 1): pos for k, pos in enumerate(elp_points[8:])} - montage += make_dig_montage(nasion=elp_points[0], - lpa=elp_points[1], - rpa=elp_points[2], - ch_pos=ch_pos) - _check_roundtrip(montage, fname_temp, 'unknown') + montage += make_dig_montage( + nasion=elp_points[0], lpa=elp_points[1], rpa=elp_points[2], ch_pos=ch_pos + ) + _check_roundtrip(montage, fname_temp, "unknown") montage = transform_to_head(montage) _check_roundtrip(montage, fname_temp) - montage.dig[0]['coord_frame'] = FIFF.FIFFV_COORD_UNKNOWN - with pytest.raises(RuntimeError, match='Only a single coordinate'): + montage.dig[0]["coord_frame"] = FIFF.FIFFV_COORD_UNKNOWN + with pytest.raises(RuntimeError, match="Only a single coordinate"): montage.save(fname_temp) @@ -963,24 +1069,25 @@ def test_egi_dig_montage(tmp_path): assert coord == FIFF.FIFFV_COORD_UNKNOWN assert_allclose( - actual=np.array([fid[key] for key in ['nasion', 'lpa', 'rpa']]), - desired=[[ 0. , 10.564, -2.051], # noqa - [-8.592, 0.498, -4.128], # noqa - [ 8.592, 0.498, -4.128]], # noqa + actual=np.array([fid[key] for key in ["nasion", "lpa", "rpa"]]), + desired=[ + [0.0, 10.564, -2.051], # noqa + [-8.592, 0.498, -4.128], # noqa + [8.592, 0.498, -4.128], + ], # noqa ) # Test accuracy and embedding within raw object - raw_egi = read_raw_egi(egi_raw_fname, channel_naming='EEG %03d') + raw_egi = read_raw_egi(egi_raw_fname, channel_naming="EEG %03d") raw_egi.set_montage(dig_montage) test_raw_egi = read_raw_fif(egi_fif_fname) assert_equal(len(raw_egi.ch_names), len(test_raw_egi.ch_names)) - for ch_raw, ch_test_raw in zip(raw_egi.info['chs'], - test_raw_egi.info['chs']): - assert_equal(ch_raw['ch_name'], ch_test_raw['ch_name']) - assert_equal(ch_raw['coord_frame'], FIFF.FIFFV_COORD_HEAD) - assert_allclose(ch_raw['loc'], ch_test_raw['loc'], atol=1e-7) + for ch_raw, ch_test_raw in zip(raw_egi.info["chs"], test_raw_egi.info["chs"]): + assert_equal(ch_raw["ch_name"], ch_test_raw["ch_name"]) + assert_equal(ch_raw["coord_frame"], FIFF.FIFFV_COORD_HEAD) + assert_allclose(ch_raw["loc"], ch_test_raw["loc"], atol=1e-7) assert_dig_allclose(raw_egi.info, test_raw_egi.info) @@ -988,14 +1095,14 @@ def test_egi_dig_montage(tmp_path): fid, coord = _get_fid_coords(dig_montage_in_head.dig) assert coord == FIFF.FIFFV_COORD_HEAD assert_allclose( - actual=np.array([fid[key] for key in ['nasion', 'lpa', 'rpa']]), - desired=[[0., 10.278, 0.], [-8.592, 0., 0.], [8.592, 0., 0.]], + actual=np.array([fid[key] for key in ["nasion", "lpa", "rpa"]]), + desired=[[0.0, 10.278, 0.0], [-8.592, 0.0, 0.0], [8.592, 0.0, 0.0]], atol=1e-4, ) # test round-trip IO - fname_temp = tmp_path / 'egi_test.fif' - _check_roundtrip(dig_montage, fname_temp, 'unknown') + fname_temp = tmp_path / "egi_test.fif" + _check_roundtrip(dig_montage, fname_temp, "unknown") _check_roundtrip(dig_montage_in_head, fname_temp) @@ -1007,44 +1114,158 @@ def _pop_montage(dig_montage, ch_name): del dig_montage.dig[dig_idx] del dig_montage.ch_names[name_idx] for k in range(dig_idx, len(dig_montage.dig)): - dig_montage.dig[k]['ident'] -= 1 + dig_montage.dig[k]["ident"] -= 1 @testing.requires_testing_data def test_read_dig_captrak(tmp_path): """Test reading a captrak montage file.""" EXPECTED_CH_NAMES_OLD = [ - 'AF3', 'AF4', 'AF7', 'AF8', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CP1', - 'CP2', 'CP3', 'CP4', 'CP5', 'CP6', 'CPz', 'Cz', 'F1', 'F2', 'F3', 'F4', - 'F5', 'F6', 'F7', 'F8', 'FC1', 'FC2', 'FC3', 'FC4', 'FC5', 'FC6', - 'FT10', 'FT7', 'FT8', 'FT9', 'Fp1', 'Fp2', 'Fz', 'GND', 'O1', 'O2', - 'Oz', 'P1', 'P2', 'P3', 
'P4', 'P5', 'P6', 'P7', 'P8', 'PO10', 'PO3',
-        'PO4', 'PO7', 'PO8', 'PO9', 'POz', 'Pz', 'REF', 'T7', 'T8', 'TP10',
-        'TP7', 'TP8', 'TP9'
+        "AF3",
+        "AF4",
+        "AF7",
+        "AF8",
+        "C1",
+        "C2",
+        "C3",
+        "C4",
+        "C5",
+        "C6",
+        "CP1",
+        "CP2",
+        "CP3",
+        "CP4",
+        "CP5",
+        "CP6",
+        "CPz",
+        "Cz",
+        "F1",
+        "F2",
+        "F3",
+        "F4",
+        "F5",
+        "F6",
+        "F7",
+        "F8",
+        "FC1",
+        "FC2",
+        "FC3",
+        "FC4",
+        "FC5",
+        "FC6",
+        "FT10",
+        "FT7",
+        "FT8",
+        "FT9",
+        "Fp1",
+        "Fp2",
+        "Fz",
+        "GND",
+        "O1",
+        "O2",
+        "Oz",
+        "P1",
+        "P2",
+        "P3",
+        "P4",
+        "P5",
+        "P6",
+        "P7",
+        "P8",
+        "PO10",
+        "PO3",
+        "PO4",
+        "PO7",
+        "PO8",
+        "PO9",
+        "POz",
+        "Pz",
+        "REF",
+        "T7",
+        "T8",
+        "TP10",
+        "TP7",
+        "TP8",
+        "TP9",
     ]
     EXPECTED_CH_NAMES = [
-        'T7', 'FC5', 'F7', 'C5', 'FT7', 'FT9', 'TP7', 'TP9', 'P7', 'CP5',
-        'PO7', 'C3', 'CP3', 'P5', 'P3', 'PO3', 'PO9', 'O1', 'Oz', 'POz', 'O2',
-        'PO4', 'P1', 'Pz', 'P2', 'CP2', 'CP1', 'CPz', 'Cz', 'C1', 'FC1', 'FC3',
-        'REF', 'F3', 'F1', 'Fz', 'F5', 'AF7', 'AF3', 'Fp1', 'GND', 'F2', 'AF4',
-        'Fp2', 'F4', 'F8', 'F6', 'AF8', 'FC2', 'FC6', 'FC4', 'C2', 'C4', 'P4',
-        'CP4', 'PO8', 'P8', 'P6', 'CP6', 'PO10', 'TP10', 'TP8', 'FT10', 'T8',
-        'C6', 'FT8'
+        "T7",
+        "FC5",
+        "F7",
+        "C5",
+        "FT7",
+        "FT9",
+        "TP7",
+        "TP9",
+        "P7",
+        "CP5",
+        "PO7",
+        "C3",
+        "CP3",
+        "P5",
+        "P3",
+        "PO3",
+        "PO9",
+        "O1",
+        "Oz",
+        "POz",
+        "O2",
+        "PO4",
+        "P1",
+        "Pz",
+        "P2",
+        "CP2",
+        "CP1",
+        "CPz",
+        "Cz",
+        "C1",
+        "FC1",
+        "FC3",
+        "REF",
+        "F3",
+        "F1",
+        "Fz",
+        "F5",
+        "AF7",
+        "AF3",
+        "Fp1",
+        "GND",
+        "F2",
+        "AF4",
+        "Fp2",
+        "F4",
+        "F8",
+        "F6",
+        "AF8",
+        "FC2",
+        "FC6",
+        "FC4",
+        "C2",
+        "C4",
+        "P4",
+        "CP4",
+        "PO8",
+        "P8",
+        "P6",
+        "CP6",
+        "PO10",
+        "TP10",
+        "TP8",
+        "FT10",
+        "T8",
+        "C6",
+        "FT8",
     ]
     assert set(EXPECTED_CH_NAMES) == set(EXPECTED_CH_NAMES_OLD)
-    montage = read_dig_captrak(
-        fname=data_path / "montage" / "captrak_coords.bvct"
-    )
+    montage = read_dig_captrak(fname=data_path / "montage" / "captrak_coords.bvct")
     assert montage.ch_names == EXPECTED_CH_NAMES
     assert repr(montage) == (
-        '<DigMontage | 0 extras (headshape), 0 HPIs, 3 fiducials, 66 channels>'
+        "<DigMontage | 0 extras (headshape), 0 HPIs, 3 fiducials, 66 channels>"
     )
 
     montage = transform_to_head(montage)  # transform_to_head has to be tested
-    _check_roundtrip(montage=montage,
-                     fname=str(tmp_path / 'bvct_test.fif'))
+    _check_roundtrip(montage=montage, fname=str(tmp_path / "bvct_test.fif"))
 
     fid, _ = _get_fid_coords(montage.dig)
     assert_allclose(
@@ -1054,64 +1275,65 @@
     )
 
     raw_bv = read_raw_brainvision(bv_raw_fname)
-    raw_bv.set_channel_types({"HEOG": 'eog', "VEOG": 'eog', "ECG": 'ecg'})
+    raw_bv.set_channel_types({"HEOG": "eog", "VEOG": "eog", "ECG": "ecg"})
     raw_bv.set_montage(montage)
 
     test_raw_bv = read_raw_fif(bv_fif_fname)
 
     # compare after set_montage using chs loc.
- for actual, expected in zip(raw_bv.info['chs'], test_raw_bv.info['chs']): - assert_allclose(actual['loc'][:3], expected['loc'][:3]) - if actual['kind'] == FIFF.FIFFV_EEG_CH: - assert_allclose(actual['loc'][3:6], - [-0.005103, 0.05395, 0.144622], rtol=1e-04) + for actual, expected in zip(raw_bv.info["chs"], test_raw_bv.info["chs"]): + assert_allclose(actual["loc"][:3], expected["loc"][:3]) + if actual["kind"] == FIFF.FIFFV_EEG_CH: + assert_allclose( + actual["loc"][3:6], [-0.005103, 0.05395, 0.144622], rtol=1e-04 + ) # https://gist.github.com/larsoner/2264fb5895070d29a8c9aa7c0dc0e8a6 _MGH60 = ( - 'Fp1 Fpz Fp2 ' - 'AF7 AF3 AF4 AF8 ' - 'F7 F5 F3 F1 Fz F2 F4 F6 F8 ' - 'FT9 FT7 FC5 FC1 FC2 FC6 FT8 FT10 ' - 'T9 T7 C5 C3 C1 Cz C2 C4 C6 T8 T10 ' - 'TP9 TP7 CP3 CP1 CP2 CP4 TP8 TP10 ' - 'P7 P5 P3 P1 Pz P2 P4 P6 P8 ' - 'PO7 PO3 PO4 PO8 ' - 'O1 Oz O2 ' - 'Iz' + "Fp1 Fpz Fp2 " + "AF7 AF3 AF4 AF8 " + "F7 F5 F3 F1 Fz F2 F4 F6 F8 " + "FT9 FT7 FC5 FC1 FC2 FC6 FT8 FT10 " + "T9 T7 C5 C3 C1 Cz C2 C4 C6 T8 T10 " + "TP9 TP7 CP3 CP1 CP2 CP4 TP8 TP10 " + "P7 P5 P3 P1 Pz P2 P4 P6 P8 " + "PO7 PO3 PO4 PO8 " + "O1 Oz O2 " + "Iz" ).split() -@pytest.mark.parametrize('rename', ('raw', 'montage', 'custom')) +@pytest.mark.parametrize("rename", ("raw", "montage", "custom")) def test_set_montage_mgh(rename): """Test setting 'mgh60' montage to old fif.""" raw = read_raw_fif(fif_fname) eeg_picks = pick_types(raw.info, meg=False, eeg=True, exclude=()) - assert list(eeg_picks) == [ii for ii, name in enumerate(raw.ch_names) - if name.startswith('EEG')] - orig_pos = np.array([raw.info['chs'][pick]['loc'][:3] - for pick in eeg_picks]) + assert list(eeg_picks) == [ + ii for ii, name in enumerate(raw.ch_names) if name.startswith("EEG") + ] + orig_pos = np.array([raw.info["chs"][pick]["loc"][:3] for pick in eeg_picks]) atol = 1e-6 mon = None - if rename == 'raw': - raw.rename_channels(lambda x: x.replace('EEG ', 'EEG')) - raw.set_montage('mgh60') # test loading with string argument - elif rename == 'montage': - mon = make_standard_montage('mgh60') - mon.rename_channels(lambda x: x.replace('EEG', 'EEG ')) + if rename == "raw": + raw.rename_channels(lambda x: x.replace("EEG ", "EEG")) + raw.set_montage("mgh60") # test loading with string argument + elif rename == "montage": + mon = make_standard_montage("mgh60") + mon.rename_channels(lambda x: x.replace("EEG", "EEG ")) assert [raw.ch_names[pick] for pick in eeg_picks] == mon.ch_names raw.set_montage(mon) else: atol = 3e-3 # different subsets of channel locations - assert rename == 'custom' + assert rename == "custom" assert len(_MGH60) == 60 - mon = make_standard_montage('standard_1020') + mon = make_standard_montage("standard_1020") assert len(mon._get_ch_pos()) == 94 def renamer(x): try: - return 'EEG %03d' % (_MGH60.index(x) + 1,) + return "EEG %03d" % (_MGH60.index(x) + 1,) except ValueError: return x @@ -1122,47 +1344,56 @@ def renamer(x): # first two are 'Fp1' and 'Fz', take them from standard_1020.elc -- # they should not be changed on load! 
want_pos = [[-29.4367, 83.9171, -6.9900], [0.1123, 88.2470, -1.7130]] - got_pos = [mon.get_positions()['ch_pos'][f'EEG {x:03d}'] * 1000 - for x in range(1, 3)] + got_pos = [ + mon.get_positions()["ch_pos"][f"EEG {x:03d}"] * 1000 for x in range(1, 3) + ] assert_allclose(want_pos, got_pos) - assert mon.dig[0]['coord_frame'] == FIFF.FIFFV_COORD_MRI + assert mon.dig[0]["coord_frame"] == FIFF.FIFFV_COORD_MRI trans = compute_native_head_t(mon) - trans_2 = _get_trans('fsaverage', 'mri', 'head')[0] - assert trans['to'] == trans_2['to'] - assert trans['from'] == trans_2['from'] - assert_allclose(trans['trans'], trans_2['trans'], atol=1e-6) + trans_2 = _get_trans("fsaverage", "mri", "head")[0] + assert trans["to"] == trans_2["to"] + assert trans["from"] == trans_2["from"] + assert_allclose(trans["trans"], trans_2["trans"], atol=1e-6) - new_pos = np.array([ch['loc'][:3] for ch in raw.info['chs'] - if ch['ch_name'].startswith('EEG')]) - assert ((orig_pos != new_pos).all()) + new_pos = np.array( + [ch["loc"][:3] for ch in raw.info["chs"] if ch["ch_name"].startswith("EEG")] + ) + assert (orig_pos != new_pos).all() r0 = _fit_sphere(new_pos)[1] assert_allclose(r0, [-0.001021, 0.014554, 0.041404], atol=1e-4) # spot check: Fp1 and Fpz - assert_allclose(new_pos[:2], [[-0.030903, 0.114585, 0.027867], - [-0.001337, 0.119102, 0.03289]], atol=atol) + assert_allclose( + new_pos[:2], + [[-0.030903, 0.114585, 0.027867], [-0.001337, 0.119102, 0.03289]], + atol=atol, + ) -@pytest.mark.parametrize('fname, montage, n_eeg, n_good, bads', [ - (fif_fname, 'mgh60', 60, 59, ['EEG 053']), - pytest.param(mgh70_fname, 'mgh70', 70, 64, None, - marks=[testing._pytest_mark()]), -]) +@pytest.mark.parametrize( + "fname, montage, n_eeg, n_good, bads", + [ + (fif_fname, "mgh60", 60, 59, ["EEG 053"]), + pytest.param( + mgh70_fname, "mgh70", 70, 64, None, marks=[testing._pytest_mark()] + ), + ], +) def test_montage_positions_similar(fname, montage, n_eeg, n_good, bads): """Test that montages give spatially similar positions.""" # 1. Prepare data: load, set bads (if missing), and filter raw = read_raw_fif(fname).pick_types(eeg=True, exclude=()) if bads is not None: - assert raw.info['bads'] == [] - raw.info['bads'] = bads + assert raw.info["bads"] == [] + raw.info["bads"] = bads assert len(raw.ch_names) == n_eeg - raw.pick_types(eeg=True, exclude='bads').load_data() + raw.pick_types(eeg=True, exclude="bads").load_data() raw.apply_function(lambda x: x - x.mean()) # remove DC raw.filter(None, 40) # remove line noise assert len(raw.ch_names) == n_good - if montage == 'mgh60': + if montage == "mgh60": montage = make_standard_montage(montage) - montage.rename_channels(lambda n: f'EEG {n[-3:]}') + montage.rename_channels(lambda n: f"EEG {n[-3:]}") raw_mon = raw.copy().set_montage(montage) # 2. First test: CSDs should be similar (CSD uses 3D positions) csd = compute_current_source_density(raw).get_data() @@ -1174,8 +1405,8 @@ def test_montage_positions_similar(fname, montage, n_eeg, n_good, bads): bads = [raw.ch_names[idx] for idx in bad_picks] orig_data = raw.get_data(bad_picks) assert_allclose(orig_data, raw_mon.get_data(bad_picks)) - raw.info['bads'] = bads - raw_mon.info['bads'] = bads + raw.info["bads"] = bads + raw_mon.info["bads"] = bads raw.interpolate_bads() raw_mon.interpolate_bads() orig_data = orig_data.ravel() @@ -1185,21 +1416,22 @@ def test_montage_positions_similar(fname, montage, n_eeg, n_good, bads): assert 0.95 < corr < 0.99, corr # 4. 
Third test: project each to a sphere, check cosine angles are small poss = dict() - for kind, this_raw in (('orig', raw), ('mon', raw_mon)): + for kind, this_raw in (("orig", raw), ("mon", raw_mon)): pos = np.array( - list(this_raw.get_montage().get_positions()['ch_pos'].values()), - float) + list(this_raw.get_montage().get_positions()["ch_pos"].values()), float + ) pos -= np.mean(pos, axis=0) pos /= np.linalg.norm(pos, axis=1, keepdims=True) poss[kind] = pos ang = np.rad2deg( # arccos is in [0, pi] - np.arccos(np.minimum(np.sum(poss['orig'] * poss['mon'], axis=1), 1))) + np.arccos(np.minimum(np.sum(poss["orig"] * poss["mon"], axis=1), 1)) + ) assert_array_less(ang, 20) # less than 20 deg assert_array_less(0, ang) # but not equal # XXX: this does not check ch_names + it cannot work because of write_dig -def _check_roundtrip(montage, fname, coord_frame='head'): +def _check_roundtrip(montage, fname, coord_frame="head"): """Check roundtrip writing.""" montage.save(fname, overwrite=True) montage_read = read_dig_fif(fname=fname) @@ -1211,68 +1443,74 @@ def _check_roundtrip(montage, fname, coord_frame='head'): def _fake_montage(ch_names): pos = np.random.RandomState(42).randn(len(ch_names), 3) - return make_dig_montage(ch_pos=dict(zip(ch_names, pos)), - coord_frame='head') + return make_dig_montage(ch_pos=dict(zip(ch_names, pos)), coord_frame="head") cnt_ignore_warns = [ pytest.mark.filterwarnings( - 'ignore:.*Could not parse meas date from the header. Setting to None.' + "ignore:.*Could not parse meas date from the header. Setting to None." ), - pytest.mark.filterwarnings(( - 'ignore:.*Could not define the number of bytes automatically.' - ' Defaulting to 2.') + pytest.mark.filterwarnings( + ( + "ignore:.*Could not define the number of bytes automatically." + " Defaulting to 2." 
+ ) ), ] def test_digmontage_constructor_errors(): """Test proper error messaging.""" - with pytest.raises(ValueError, match='does not match the number'): - _ = DigMontage(ch_names=['foo', 'bar'], dig=list()) + with pytest.raises(ValueError, match="does not match the number"): + _ = DigMontage(ch_names=["foo", "bar"], dig=list()) def test_transform_to_head_and_compute_dev_head_t(): """Test transform_to_head and compute_dev_head_t.""" - EXPECTED_DEV_HEAD_T = \ - [[-3.72201691e-02, -9.98212167e-01, -4.67667497e-02, -7.31583414e-04], - [8.98064989e-01, -5.39382685e-02, 4.36543170e-01, 1.60134431e-02], - [-4.38285221e-01, -2.57513699e-02, 8.98466990e-01, 6.13035748e-02], - [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00]] + EXPECTED_DEV_HEAD_T = [ + [-3.72201691e-02, -9.98212167e-01, -4.67667497e-02, -7.31583414e-04], + [8.98064989e-01, -5.39382685e-02, 4.36543170e-01, 1.60134431e-02], + [-4.38285221e-01, -2.57513699e-02, 8.98466990e-01, 6.13035748e-02], + [0.00000000e00, 0.00000000e00, 0.00000000e00, 1.00000000e00], + ] EXPECTED_FID_IN_POLHEMUS = { - 'nasion': np.array([0.001393, 0.0131613, -0.0046967]), - 'lpa': np.array([-0.0624997, -0.0737271, 0.07996]), - 'rpa': np.array([-0.0748957, 0.0873785, 0.0811943]), + "nasion": np.array([0.001393, 0.0131613, -0.0046967]), + "lpa": np.array([-0.0624997, -0.0737271, 0.07996]), + "rpa": np.array([-0.0748957, 0.0873785, 0.0811943]), } EXPECTED_FID_IN_HEAD = { - 'nasion': np.array([-8.94466792e-18, 1.10559624e-01, -3.85185989e-34]), - 'lpa': np.array([-8.10816716e-02, 6.56321671e-18, 0]), - 'rpa': np.array([8.05048781e-02, -6.47441364e-18, 0]), + "nasion": np.array([-8.94466792e-18, 1.10559624e-01, -3.85185989e-34]), + "lpa": np.array([-8.10816716e-02, 6.56321671e-18, 0]), + "rpa": np.array([8.05048781e-02, -6.47441364e-18, 0]), } hpi_dev = np.array( - [[ 2.13951493e-02, 8.47444056e-02, -5.65431188e-02], # noqa - [ 2.10299433e-02, -8.03141101e-02, -6.34420259e-02], # noqa - [ 1.05916829e-01, 8.18485672e-05, 1.19928083e-02], # noqa - [ 9.26595105e-02, 4.64804385e-02, 8.45141253e-03], # noqa - [ 9.42554419e-02, -4.35206589e-02, 8.78999363e-03]] # noqa + [ + [2.13951493e-02, 8.47444056e-02, -5.65431188e-02], # noqa + [2.10299433e-02, -8.03141101e-02, -6.34420259e-02], # noqa + [1.05916829e-01, 8.18485672e-05, 1.19928083e-02], # noqa + [9.26595105e-02, 4.64804385e-02, 8.45141253e-03], # noqa + [9.42554419e-02, -4.35206589e-02, 8.78999363e-03], + ] # noqa ) hpi_polhemus = np.array( - [[-0.0595004, -0.0704836, 0.075893 ], # noqa - [-0.0646373, 0.0838228, 0.0762123], # noqa - [-0.0135035, 0.0072522, -0.0268405], # noqa - [-0.0202967, -0.0351498, -0.0129305], # noqa - [-0.0277519, 0.0452628, -0.0222407]] # noqa + [ + [-0.0595004, -0.0704836, 0.075893], # noqa + [-0.0646373, 0.0838228, 0.0762123], # noqa + [-0.0135035, 0.0072522, -0.0268405], # noqa + [-0.0202967, -0.0351498, -0.0129305], # noqa + [-0.0277519, 0.0452628, -0.0222407], + ] # noqa ) montage_polhemus = make_dig_montage( - **EXPECTED_FID_IN_POLHEMUS, hpi=hpi_polhemus, coord_frame='unknown' + **EXPECTED_FID_IN_POLHEMUS, hpi=hpi_polhemus, coord_frame="unknown" ) - montage_meg = make_dig_montage(hpi=hpi_dev, coord_frame='meg') + montage_meg = make_dig_montage(hpi=hpi_dev, coord_frame="meg") # Test regular workflow to get dev_head_t montage = montage_polhemus + montage_meg @@ -1280,7 +1518,7 @@ def test_transform_to_head_and_compute_dev_head_t(): for kk in fids: assert_allclose(fids[kk], EXPECTED_FID_IN_POLHEMUS[kk], atol=1e-5) - with pytest.raises(ValueError, match='set to head 
coordinate system'): + with pytest.raises(ValueError, match="set to head coordinate system"): _ = compute_dev_head_t(montage) montage = transform_to_head(montage) @@ -1290,39 +1528,43 @@ def test_transform_to_head_and_compute_dev_head_t(): assert_allclose(fids[kk], EXPECTED_FID_IN_HEAD[kk], atol=1e-5) dev_head_t = compute_dev_head_t(montage) - assert_allclose(dev_head_t['trans'], EXPECTED_DEV_HEAD_T, atol=5e-7) + assert_allclose(dev_head_t["trans"], EXPECTED_DEV_HEAD_T, atol=5e-7) # Test errors when number of HPI points do not match - EXPECTED_ERR_MSG = 'Device-to-Head .*Got 0 .*device and 5 points in head' + EXPECTED_ERR_MSG = "Device-to-Head .*Got 0 .*device and 5 points in head" with pytest.raises(ValueError, match=EXPECTED_ERR_MSG): _ = compute_dev_head_t(transform_to_head(montage_polhemus)) - EXPECTED_ERR_MSG = 'Device-to-Head .*Got 5 .*device and 0 points in head' + EXPECTED_ERR_MSG = "Device-to-Head .*Got 5 .*device and 0 points in head" with pytest.raises(ValueError, match=EXPECTED_ERR_MSG): - _ = compute_dev_head_t(transform_to_head( - montage_meg + make_dig_montage(**EXPECTED_FID_IN_POLHEMUS) - )) + _ = compute_dev_head_t( + transform_to_head( + montage_meg + make_dig_montage(**EXPECTED_FID_IN_POLHEMUS) + ) + ) - EXPECTED_ERR_MSG = 'Device-to-Head .*Got 3 .*device and 5 points in head' + EXPECTED_ERR_MSG = "Device-to-Head .*Got 3 .*device and 5 points in head" with pytest.raises(ValueError, match=EXPECTED_ERR_MSG): - _ = compute_dev_head_t(transform_to_head( - DigMontage(dig=_format_dig_points(montage_meg.dig[:3])) + - montage_polhemus - )) + _ = compute_dev_head_t( + transform_to_head( + DigMontage(dig=_format_dig_points(montage_meg.dig[:3])) + + montage_polhemus + ) + ) def test_set_montage_with_mismatching_ch_names(): """Test setting a DigMontage with mismatching ch_names.""" raw = read_raw_fif(fif_fname) - montage = make_standard_montage('mgh60') + montage = make_standard_montage("mgh60") # 'EEG 001' and 'EEG001' won't match - missing_err = '60 channel positions not present' + missing_err = "60 channel positions not present" with pytest.raises(ValueError, match=missing_err): raw.set_montage(montage) montage.ch_names = [ # modify the names in place - name.replace('EEG', 'EEG ') for name in montage.ch_names + name.replace("EEG", "EEG ") for name in montage.ch_names ] raw.set_montage(montage) # does not raise @@ -1333,90 +1575,97 @@ def test_set_montage_with_mismatching_ch_names(): # should work raw.set_montage(montage, match_case=False) raw.rename_channels(lambda x: x.upper()) # restore - assert 'EEG 001' in raw.ch_names and 'eeg 001' not in raw.ch_names - raw.rename_channels({'EEG 002': 'eeg 001'}) - assert 'EEG 001' in raw.ch_names and 'eeg 001' in raw.ch_names - with pytest.warns(RuntimeWarning, match='changed from V to NA'): - raw.set_channel_types({'eeg 001': 'misc'}) + assert "EEG 001" in raw.ch_names and "eeg 001" not in raw.ch_names + raw.rename_channels({"EEG 002": "eeg 001"}) + assert "EEG 001" in raw.ch_names and "eeg 001" in raw.ch_names + with pytest.warns(RuntimeWarning, match="changed from V to NA"): + raw.set_channel_types({"eeg 001": "misc"}) raw.set_montage(montage) - with pytest.warns(RuntimeWarning, match='changed from NA to V'): - raw.set_channel_types({'eeg 001': 'eeg'}) - with pytest.raises(ValueError, match='1 channel position not present'): + with pytest.warns(RuntimeWarning, match="changed from NA to V"): + raw.set_channel_types({"eeg 001": "eeg"}) + with pytest.raises(ValueError, match="1 channel position not present"): raw.set_montage(montage) - with 
pytest.raises(ValueError, match='match_case=False as 1 channel name'): + with pytest.raises(ValueError, match="match_case=False as 1 channel name"): raw.set_montage(montage, match_case=False) - info = create_info(['EEG 001'], 1000., 'eeg') - mon = make_dig_montage({'EEG 001': np.zeros(3), 'eeg 001': np.zeros(3)}, - nasion=[0, 1., 0], rpa=[1., 0, 0], lpa=[-1., 0, 0]) + info = create_info(["EEG 001"], 1000.0, "eeg") + mon = make_dig_montage( + {"EEG 001": np.zeros(3), "eeg 001": np.zeros(3)}, + nasion=[0, 1.0, 0], + rpa=[1.0, 0, 0], + lpa=[-1.0, 0, 0], + ) info.set_montage(mon) - with pytest.raises(ValueError, match='match_case=False as 1 montage name'): + with pytest.raises(ValueError, match="match_case=False as 1 montage name"): info.set_montage(mon, match_case=False) def test_set_montage_with_sub_super_set_of_ch_names(): """Test info and montage ch_names matching criteria.""" - N_CHANNELS = len('abcdef') - montage = _make_toy_dig_montage(N_CHANNELS, coord_frame='head') + N_CHANNELS = len("abcdef") + montage = _make_toy_dig_montage(N_CHANNELS, coord_frame="head") # montage and info match - info = create_info(ch_names=list('abcdef'), sfreq=1, ch_types='eeg') + info = create_info(ch_names=list("abcdef"), sfreq=1, ch_types="eeg") info.set_montage(montage) # montage is a SUPERset of info - info = create_info(list('abc'), sfreq=1, ch_types='eeg') + info = create_info(list("abc"), sfreq=1, ch_types="eeg") info.set_montage(montage) - assert len(info['dig']) == len(list('abc')) + 3 # 3 fiducials + assert len(info["dig"]) == len(list("abc")) + 3 # 3 fiducials # montage is a SUBset of info - _MSG = 'subset of info. There are 2 .* not present in the DigMontage' - info = create_info(ch_names=list('abcdfgh'), sfreq=1, ch_types='eeg') + _MSG = "subset of info. There are 2 .* not present in the DigMontage" + info = create_info(ch_names=list("abcdfgh"), sfreq=1, ch_types="eeg") with pytest.raises(ValueError, match=_MSG) as exc: info.set_montage(montage) # plus suggestions - assert exc.match('set_channel_types') - assert exc.match('on_missing') + assert exc.match("set_channel_types") + assert exc.match("on_missing") def test_set_montage_with_known_aliases(): """Test matching unrecognized channel locations to known aliases.""" # montage and info match - mock_montage_ch_names = ['POO7', 'POO8'] + mock_montage_ch_names = ["POO7", "POO8"] n_channels = len(mock_montage_ch_names) - montage = make_dig_montage(ch_pos=dict( - zip( - mock_montage_ch_names, - np.arange(n_channels * 3).reshape(n_channels, 3), - )), - coord_frame='head') + montage = make_dig_montage( + ch_pos=dict( + zip( + mock_montage_ch_names, + np.arange(n_channels * 3).reshape(n_channels, 3), + ) + ), + coord_frame="head", + ) - mock_info_ch_names = ['Cb1', 'Cb2'] - info = create_info(ch_names=mock_info_ch_names, sfreq=1, ch_types='eeg') + mock_info_ch_names = ["Cb1", "Cb2"] + info = create_info(ch_names=mock_info_ch_names, sfreq=1, ch_types="eeg") info.set_montage(montage, match_alias=True) # work with match_case - mock_info_ch_names = ['cb1', 'cb2'] - info = create_info(ch_names=mock_info_ch_names, sfreq=1, ch_types='eeg') + mock_info_ch_names = ["cb1", "cb2"] + info = create_info(ch_names=mock_info_ch_names, sfreq=1, ch_types="eeg") info.set_montage(montage, match_case=False, match_alias=True) # should warn user T1 instead of its alias T9 - mock_info_ch_names = ['Cb1', 'T1'] - info = create_info(ch_names=mock_info_ch_names, sfreq=1, ch_types='eeg') - with pytest.raises(ValueError, match='T1'): + mock_info_ch_names = ["Cb1", "T1"] + info = 
create_info(ch_names=mock_info_ch_names, sfreq=1, ch_types="eeg") + with pytest.raises(ValueError, match="T1"): info.set_montage(montage, match_case=False, match_alias=True) def test_heterogeneous_ch_type(): """Test ch_names matching criteria with heterogeneous ch_type.""" - VALID_MONTAGE_NAMED_CHS = ('eeg', 'ecog', 'seeg', 'dbs') + VALID_MONTAGE_NAMED_CHS = ("eeg", "ecog", "seeg", "dbs") montage = _make_toy_dig_montage( n_channels=len(VALID_MONTAGE_NAMED_CHS), - coord_frame='head', + coord_frame="head", ) # Montage and info match - info = create_info(montage.ch_names, 1., list(VALID_MONTAGE_NAMED_CHS)) + info = create_info(montage.ch_names, 1.0, list(VALID_MONTAGE_NAMED_CHS)) RawArray(np.zeros((4, 1)), info, copy=None).set_montage(montage) @@ -1425,45 +1674,46 @@ def test_set_montage_coord_frame_in_head_vs_unknown(): N_CHANNELS, NaN = 3, np.nan raw = _make_toy_raw(N_CHANNELS) - montage_in_head = _make_toy_dig_montage(N_CHANNELS, coord_frame='head') - montage_in_unknown = _make_toy_dig_montage( - N_CHANNELS, coord_frame='unknown' - ) + montage_in_head = _make_toy_dig_montage(N_CHANNELS, coord_frame="head") + montage_in_unknown = _make_toy_dig_montage(N_CHANNELS, coord_frame="unknown") montage_in_unknown_with_fid = _make_toy_dig_montage( - N_CHANNELS, coord_frame='unknown', - nasion=[0, 1, 0], lpa=[1, 0, 0], rpa=[-1, 0, 0], + N_CHANNELS, + coord_frame="unknown", + nasion=[0, 1, 0], + lpa=[1, 0, 0], + rpa=[-1, 0, 0], ) assert_allclose( - actual=np.array([ch['loc'] for ch in raw.info['chs']]), - desired=np.full((N_CHANNELS, 12), np.nan) + actual=np.array([ch["loc"] for ch in raw.info["chs"]]), + desired=np.full((N_CHANNELS, 12), np.nan), ) raw.set_montage(montage_in_head) assert_allclose( - actual=np.array([ch['loc'] for ch in raw.info['chs']]), + actual=np.array([ch["loc"] for ch in raw.info["chs"]]), desired=[ - [0., 1., 2., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - [3., 4., 5., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - [6., 7., 8., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - ] + [0.0, 1.0, 2.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + [3.0, 4.0, 5.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + [6.0, 7.0, 8.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + ], ) - with pytest.warns(RuntimeWarning, match='assuming identity'): + with pytest.warns(RuntimeWarning, match="assuming identity"): raw.set_montage(montage_in_unknown) raw.set_montage(montage_in_unknown_with_fid) assert_allclose( - actual=np.array([ch['loc'] for ch in raw.info['chs']]), + actual=np.array([ch["loc"] for ch in raw.info["chs"]]), desired=[ - [-0., 1., -2., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - [-3., 4., -5., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - [-6., 7., -8., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - ] + [-0.0, 1.0, -2.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + [-3.0, 4.0, -5.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + [-6.0, 7.0, -8.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + ], ) # check no collateral effects from transforming montage - assert _check_get_coord_frame(montage_in_unknown_with_fid.dig) == 'unknown' + assert _check_get_coord_frame(montage_in_unknown_with_fid.dig) == "unknown" assert_array_equal( _get_dig_montage_pos(montage_in_unknown_with_fid), [[0, 1, 2], [3, 4, 5], [6, 7, 8]], @@ -1471,41 +1721,42 @@ def test_set_montage_coord_frame_in_head_vs_unknown(): @testing.requires_testing_data -@pytest.mark.parametrize('ch_type', ('eeg', 'ecog', 'seeg', 'dbs')) +@pytest.mark.parametrize("ch_type", ("eeg", "ecog", "seeg", "dbs")) def 
test_montage_head_frame(ch_type): """Test that head frame is set properly.""" # gh-9446 data = np.random.randn(2, 100) - info = create_info(['a', 'b'], 512, ch_type) - for ch in info['chs']: - assert ch['coord_frame'] == FIFF.FIFFV_COORD_HEAD + info = create_info(["a", "b"], 512, ch_type) + for ch in info["chs"]: + assert ch["coord_frame"] == FIFF.FIFFV_COORD_HEAD raw = RawArray(data, info) - ch_pos = dict(a=[-0.00250136, 0.04913788, 0.05047056], - b=[-0.00528394, 0.05066484, 0.05061559]) - lpa, nasion, rpa = get_mni_fiducials( - 'fsaverage', subjects_dir=subjects_dir) - lpa, nasion, rpa = lpa['r'], nasion['r'], rpa['r'] + ch_pos = dict( + a=[-0.00250136, 0.04913788, 0.05047056], b=[-0.00528394, 0.05066484, 0.05061559] + ) + lpa, nasion, rpa = get_mni_fiducials("fsaverage", subjects_dir=subjects_dir) + lpa, nasion, rpa = lpa["r"], nasion["r"], rpa["r"] montage = make_dig_montage( - ch_pos, coord_frame='mri', nasion=nasion, lpa=lpa, rpa=rpa) + ch_pos, coord_frame="mri", nasion=nasion, lpa=lpa, rpa=rpa + ) mri_head_t = compute_native_head_t(montage) raw.set_montage(montage) pos = apply_trans(mri_head_t, np.array(list(ch_pos.values()))) - for p, ch in zip(pos, raw.info['chs']): - assert ch['coord_frame'] == FIFF.FIFFV_COORD_HEAD - assert_allclose(p, ch['loc'][:3]) + for p, ch in zip(pos, raw.info["chs"]): + assert ch["coord_frame"] == FIFF.FIFFV_COORD_HEAD + assert_allclose(p, ch["loc"][:3]) # Also test that including channels in the montage that will not have their # positions set will emit a warning - with pytest.warns(RuntimeWarning, match='changed from V to NA'): - raw.set_channel_types(dict(a='misc')) - with pytest.warns(RuntimeWarning, match='Not setting .*of 1 misc channel'): + with pytest.warns(RuntimeWarning, match="changed from V to NA"): + raw.set_channel_types(dict(a="misc")) + with pytest.warns(RuntimeWarning, match="Not setting .*of 1 misc channel"): raw.set_montage(montage) # and with a bunch of bad types raw = read_raw_fif(fif_fname) ch_pos = {ch_name: np.zeros(3) for ch_name in raw.ch_names} - mon = make_dig_montage(ch_pos, coord_frame='head') - with pytest.warns(RuntimeWarning, match='316 eog/grad/mag/stim channels'): + mon = make_dig_montage(ch_pos, coord_frame="head") + with pytest.warns(RuntimeWarning, match="316 eog/grad/mag/stim channels"): raw.set_montage(mon) @@ -1514,39 +1765,44 @@ def test_set_montage_with_missing_coordinates(): N_CHANNELS, NaN = 3, np.nan raw = _make_toy_raw(N_CHANNELS) - raw.set_channel_types({ch: 'ecog' for ch in raw.ch_names}) + raw.set_channel_types({ch: "ecog" for ch in raw.ch_names}) # don't include all the channels ch_names = raw.ch_names[1:] n_channels = len(ch_names) ch_coords = np.arange(n_channels * 3).reshape(n_channels, 3) montage_in_mri = make_dig_montage( - ch_pos=dict(zip(ch_names, ch_coords,)), - coord_frame='unknown', - nasion=[0, 1, 0], lpa=[1, 0, 0], rpa=[-1, 0, 0], + ch_pos=dict( + zip( + ch_names, + ch_coords, + ) + ), + coord_frame="unknown", + nasion=[0, 1, 0], + lpa=[1, 0, 0], + rpa=[-1, 0, 0], ) - with pytest.raises(ValueError, match='DigMontage is ' - 'only a subset of info'): + with pytest.raises(ValueError, match="DigMontage is " "only a subset of info"): raw.set_montage(montage_in_mri) - with pytest.raises(ValueError, match='Invalid value'): - raw.set_montage(montage_in_mri, on_missing='foo') + with pytest.raises(ValueError, match="Invalid value"): + raw.set_montage(montage_in_mri, on_missing="foo") - with pytest.raises(TypeError, match='must be an instance'): + with pytest.raises(TypeError, match="must be an 
instance"): raw.set_montage(montage_in_mri, on_missing=True) - with pytest.warns(RuntimeWarning, match='DigMontage is ' - 'only a subset of info'): - raw.set_montage(montage_in_mri, on_missing='warn') + with pytest.warns(RuntimeWarning, match="DigMontage is " "only a subset of info"): + raw.set_montage(montage_in_mri, on_missing="warn") - raw.set_montage(montage_in_mri, on_missing='ignore') + raw.set_montage(montage_in_mri, on_missing="ignore") assert_allclose( - actual=np.array([ch['loc'] for ch in raw.info['chs']]), + actual=np.array([ch["loc"] for ch in raw.info["chs"]]), desired=[ [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], - [0., 1., -2., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - [-3., 4., -5., 0., 0., 0., NaN, NaN, NaN, NaN, NaN, NaN], - ] + [0.0, 1.0, -2.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + [-3.0, 4.0, -5.0, 0.0, 0.0, 0.0, NaN, NaN, NaN, NaN, NaN, NaN], + ], ) @@ -1559,16 +1815,16 @@ def test_get_montage(): # 1. read in testing data and assert montage roundtrip # for testing dataset: 'test_raw.fif' raw = read_raw_fif(fif_fname) - raw = raw.rename_channels(lambda name: name.replace('EEG ', 'EEG')) + raw = raw.rename_channels(lambda name: name.replace("EEG ", "EEG")) raw2 = raw.copy() # get montage and then set montage and # it should be the same montage = raw.get_montage() - raw.set_montage(montage, on_missing='raise') + raw.set_montage(montage, on_missing="raise") test_montage = raw.get_montage() - assert_object_equal(raw.info['chs'], raw2.info['chs']) + assert_object_equal(raw.info["chs"], raw2.info["chs"]) assert_dig_allclose(raw2.info, raw.info) - assert_object_equal(raw2.info['dig'], raw.info['dig']) + assert_object_equal(raw2.info["dig"], raw.info["dig"]) # the montage does not change assert_object_equal(montage.dig, test_montage.dig) @@ -1578,7 +1834,7 @@ def test_get_montage(): assert_object_equal(test2_montage.dig, test_montage.dig) # 2. now do a standard montage - montage = make_standard_montage('mgh60') + montage = make_standard_montage("mgh60") # set the montage; note renaming to make standard montage map raw.set_montage(montage) @@ -1586,20 +1842,20 @@ def test_get_montage(): # the channel locations should be the same raw2 = raw.copy() test_montage = raw.get_montage() - raw.set_montage(test_montage, on_missing='ignore') + raw.set_montage(test_montage, on_missing="ignore") # the montage should fulfill a roundtrip with make_dig_montage test2_montage = make_dig_montage(**test_montage.get_positions()) assert_object_equal(test2_montage.dig, test_montage.dig) # chs should not change - assert_object_equal(raw2.info['chs'], raw.info['chs']) + assert_object_equal(raw2.info["chs"], raw.info["chs"]) # dig order might be different after set_montage assert montage.ch_names == test_montage.ch_names # note that test_montage will have different coordinate frame # compared to standard montage assert_dig_allclose(raw2.info, raw.info) - assert_object_equal(raw2.info['dig'], raw.info['dig']) + assert_object_equal(raw2.info["dig"], raw.info["dig"]) # 3. 
if montage gets set to None
     raw.set_montage(None)
 
@@ -1618,14 +1874,14 @@ def test_get_montage():
     # of channels
     mapping = dict()
     for ii, ch_name in enumerate(raw_bv.ch_names):
-        mapping[ch_name] = 'EEG%03d' % (ii + 1,)
+        mapping[ch_name] = "EEG%03d" % (ii + 1,)
     raw_bv.rename_channels(mapping)
     for ii, ch_name in enumerate(raw_bv_2.ch_names):
-        mapping[ch_name] = 'EEG%03d' % (ii + 33,)
+        mapping[ch_name] = "EEG%03d" % (ii + 33,)
     raw_bv_2.rename_channels(mapping)
     raw_bv.add_channels([raw_bv_2])
-    for ch in raw_bv.info['chs']:
-        ch['kind'] = FIFF.FIFFV_EEG_CH
+    for ch in raw_bv.info["chs"]:
+        ch["kind"] = FIFF.FIFFV_EEG_CH
 
     # Set the montage and roundtrip
     raw_bv.set_montage(dig_montage)
@@ -1633,14 +1889,14 @@ def test_get_montage():
 
     # reset the montage
     test_montage = raw_bv.get_montage()
-    raw_bv.set_montage(test_montage, on_missing='ignore')
+    raw_bv.set_montage(test_montage, on_missing="ignore")
     # dig order might be different after set_montage
-    assert_object_equal(raw_bv2.info['dig'], raw_bv.info['dig'])
+    assert_object_equal(raw_bv2.info["dig"], raw_bv.info["dig"])
     assert_dig_allclose(raw_bv2.info, raw_bv.info)
 
     # if dig is not set in the info, then montage returns None
     with raw.info._unlock():
-        raw.info['dig'] = None
+        raw.info["dig"] = None
     assert raw.get_montage() is None
 
     # the montage should fulfill a roundtrip with make_dig_montage
@@ -1653,8 +1909,7 @@ def test_read_dig_hpts():
     fname = io_dir / "brainvision" / "tests" / "data" / "test.hpts"
     montage = read_dig_hpts(fname)
     assert repr(montage) == (
-        '<DigMontage | 0 extras (headshape), 5 HPIs, 3 fiducials, '
-        '34 channels>'
+        "<DigMontage | 0 extras (headshape), 5 HPIs, 3 fiducials, 34 channels>"
     )
 
 
@@ -1677,7 +1932,7 @@ def test_plot_montage():
     # gh-8025
     montage = read_dig_captrak(bvct_dig_montage_fname)
     montage.plot()
-    plt.close('all')
+    plt.close("all")
 
     f, ax = plt.subplots(1, 1)
     montage.plot(axes=ax)
@@ -1694,12 +1949,12 @@ def test_plot_montage():
 def test_montage_equality():
     """Test montage equality."""
     rng = np.random.RandomState(0)
-    fiducials = dict(zip(('nasion', 'lpa', 'rpa'), rng.rand(3, 3)))
+    fiducials = dict(zip(("nasion", "lpa", "rpa"), rng.rand(3, 3)))
 
     # hsp positions are [1X, 1X, 1X]
-    hsp1 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 11.))
-    hsp2 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 12.))
-    hsp2_identical = make_dig_montage(**fiducials, hsp=np.full((2, 3), 12.))
+    hsp1 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 11.0))
+    hsp2 = make_dig_montage(**fiducials, hsp=np.full((2, 3), 12.0))
+    hsp2_identical = make_dig_montage(**fiducials, hsp=np.full((2, 3), 12.0))
 
     assert hsp1 != hsp2
     assert hsp2 == hsp2_identical
@@ -1710,45 +1965,46 @@ def test_montage_add_fiducials():
     """Test montage can add estimated fiducials for rpa, lpa, nas."""
     # get the fiducials from test file
     subjects_dir = data_path / "subjects"
-    subject = 'sample'
+    subject = "sample"
     fid_fname = subjects_dir / subject / "bem" / "sample-fiducials.fif"
     test_fids, test_coord_frame = read_fiducials(fid_fname)
-    test_fids = np.array([f['r'] for f in test_fids])
+    test_fids = np.array([f["r"] for f in test_fids])
 
     # create test montage and add estimated fiducials
-    test_ch_pos = {'A1': [0, 0, 0]}
-    montage = make_dig_montage(ch_pos=test_ch_pos, coord_frame='mri')
+    test_ch_pos = {"A1": [0, 0, 0]}
+    montage = make_dig_montage(ch_pos=test_ch_pos, coord_frame="mri")
     montage.add_estimated_fiducials(subject=subject, subjects_dir=subjects_dir)
 
     # check that adding MNI fiducials fails because we're in MRI
-    with pytest.raises(RuntimeError, match='Montage should be in the '
-                       '"mni_tal" coordinate frame'):
+    with pytest.raises(
+        RuntimeError, match="Montage should be in the " '"mni_tal" coordinate frame'
+    ):
         montage.add_mni_fiducials(subjects_dir=subjects_dir)
 
     # check that these fiducials are close to the estimated fiducials
     ch_pos = montage.get_positions()
-    fids_est = [ch_pos['lpa'], ch_pos['nasion'], ch_pos['rpa']]
+    fids_est = [ch_pos["lpa"], ch_pos["nasion"], ch_pos["rpa"]]
 
-    dists = np.linalg.norm(test_fids - fids_est, axis=-1) * 1000.  # -> mm
+    dists = np.linalg.norm(test_fids - fids_est, axis=-1) * 1000.0  # -> mm
     assert (dists < 8).all(), dists
 
     # an error should be raised if the montage is not in `mri` coord_frame
     # which is the FreeSurfer RAS
-    montage = make_dig_montage(ch_pos=test_ch_pos, coord_frame='mni_tal')
-    with pytest.raises(RuntimeError, match='Montage should be in the '
-                       '"mri" coordinate frame'):
-        montage.add_estimated_fiducials(subject=subject,
-                                        subjects_dir=subjects_dir)
+    montage = make_dig_montage(ch_pos=test_ch_pos, coord_frame="mni_tal")
+    with pytest.raises(
+        RuntimeError, match="Montage should be in the " '"mri" coordinate frame'
+    ):
+        montage.add_estimated_fiducials(subject=subject, subjects_dir=subjects_dir)
 
     # test that adding MNI fiducials works
     montage.add_mni_fiducials(subjects_dir=subjects_dir)
-    test_fids = get_mni_fiducials('fsaverage', subjects_dir=subjects_dir)
+    test_fids = get_mni_fiducials("fsaverage", subjects_dir=subjects_dir)
     for fid, test_fid in zip(montage.dig[:3], test_fids):
-        assert_array_equal(fid['r'], test_fid['r'])
+        assert_array_equal(fid["r"], test_fid["r"])
 
     # test remove fiducials
     montage.remove_fiducials()
-    assert all([d['kind'] != FIFF.FIFFV_POINT_CARDINAL for d in montage.dig])
+    assert all([d["kind"] != FIFF.FIFFV_POINT_CARDINAL for d in montage.dig])
 
 
 def test_read_dig_localite(tmp_path):
     17,ch14,-61.16539571,-61.86866187,26.23986153
     18,ch15,-55.82855386,-34.77319103,25.8083942"""
 
-    fname = tmp_path / 'localite.csv'
-    with open(fname, 'w') as f:
-        for row in contents.split('\n'):
-            f.write(f'{row.lstrip()}\n')
+    fname = tmp_path / "localite.csv"
+    with open(fname, "w") as f:
+        for row in contents.split("\n"):
+            f.write(f"{row.lstrip()}\n")
 
     montage = read_dig_localite(fname, nasion="Nasion", lpa="LPA", rpa="RPA")
-    s = '<DigMontage | 0 extras (headshape), 0 HPIs, 3 fiducials, 15 channels>'
+    s = "<DigMontage | 0 extras (headshape), 0 HPIs, 3 fiducials, 15 channels>"
     assert repr(montage) == s
-    assert montage.ch_names == [f'ch{i:02}' for i in range(1, 16)]
+    assert montage.ch_names == [f"ch{i:02}" for i in range(1, 16)]
 
 
 def test_make_wrong_dig_montage():
     """Test that a montage with non numeric is not possible."""
-    make_dig_montage(ch_pos={'A1': ['0', '0', '0']})  # converted to floats
+    make_dig_montage(ch_pos={"A1": ["0", "0", "0"]})  # converted to floats
     with pytest.raises(ValueError, match="could not convert string to float"):
-        make_dig_montage(ch_pos={'A1': ['a', 'b', 'c']})
+        make_dig_montage(ch_pos={"A1": ["a", "b", "c"]})
     with pytest.raises(TypeError, match="instance of ndarray, list, or tuple"):
-        make_dig_montage(ch_pos={'A1': 5})
+        make_dig_montage(ch_pos={"A1": 5})
 
 
 @testing.requires_testing_data
@@ -1805,15 +2061,14 @@ def test_fnirs_montage():
     assert num_detectors == 13
 
     # Make a change to the montage before setting
-    raw.info['chs'][2]['loc'][:3] = [1., 2, 3]
+    raw.info["chs"][2]["loc"][:3] = [1.0, 2, 3]
 
     # Set montage back to original
     raw.set_montage(mtg)
     for ch in range(len(raw.ch_names)):
-        assert_array_equal(info_orig['chs'][ch]['loc'],
-                           raw.info['chs'][ch]['loc'])
+        assert_array_equal(info_orig["chs"][ch]["loc"], raw.info["chs"][ch]["loc"])
 
     # Mixed channel types not supported yet
-    raw.set_channel_types({ch_name: 'eeg' for ch_name in raw.ch_names[-2:]})
- with pytest.raises(ValueError, match='mix of fNIRS'): + raw.set_channel_types({ch_name: "eeg" for ch_name in raw.ch_names[-2:]}) + with pytest.raises(ValueError, match="mix of fNIRS"): raw.get_montage() diff --git a/mne/channels/tests/test_standard_montage.py b/mne/channels/tests/test_standard_montage.py index 49fffaa4ab3..a9cf8f2cf0a 100644 --- a/mne/channels/tests/test_standard_montage.py +++ b/mne/channels/tests/test_standard_montage.py @@ -8,8 +8,7 @@ import numpy as np -from numpy.testing import (assert_allclose, assert_array_almost_equal, - assert_raises) +from numpy.testing import assert_allclose, assert_array_almost_equal, assert_raises from mne import create_info from mne.channels import make_standard_montage, compute_native_head_t @@ -21,7 +20,7 @@ from mne.transforms import _get_trans, _angle_between_quats, rot_to_quat -@pytest.mark.parametrize('kind', get_builtin_montages()) +@pytest.mark.parametrize("kind", get_builtin_montages()) def test_standard_montages_have_fids(kind): """Test standard montage are all in unknown coord (have fids).""" montage = make_standard_montage(kind) @@ -29,44 +28,47 @@ def test_standard_montages_have_fids(kind): for k, v in fids.items(): assert v is not None, k for d in montage.dig: - if kind.startswith(('artinis', 'standard', 'mgh')): + if kind.startswith(("artinis", "standard", "mgh")): want = FIFF.FIFFV_COORD_MRI else: want = FIFF.FIFFV_COORD_UNKNOWN - assert d['coord_frame'] == want + assert d["coord_frame"] == want def test_standard_montage_errors(): """Test error handling for wrong keys.""" _msg = "Invalid value for the 'kind' parameter..*but got.*not-here" with pytest.raises(ValueError, match=_msg): - _ = make_standard_montage('not-here') - - -@pytest.mark.parametrize('head_size', (HEAD_SIZE_DEFAULT, 0.05)) -@pytest.mark.parametrize('kind, tol', [ - ['EGI_256', 1e-5], - ['easycap-M1', 1e-8], - ['easycap-M10', 1e-8], - ['biosemi128', 1e-8], - ['biosemi16', 1e-8], - ['biosemi160', 1e-8], - ['biosemi256', 1e-8], - ['biosemi32', 1e-8], - ['biosemi64', 1e-8], - ['brainproducts-RNP-BA-128', 1e-8] -]) + _ = make_standard_montage("not-here") + + +@pytest.mark.parametrize("head_size", (HEAD_SIZE_DEFAULT, 0.05)) +@pytest.mark.parametrize( + "kind, tol", + [ + ["EGI_256", 1e-5], + ["easycap-M1", 1e-8], + ["easycap-M10", 1e-8], + ["biosemi128", 1e-8], + ["biosemi16", 1e-8], + ["biosemi160", 1e-8], + ["biosemi256", 1e-8], + ["biosemi32", 1e-8], + ["biosemi64", 1e-8], + ["brainproducts-RNP-BA-128", 1e-8], + ], +) def test_standard_montages_on_sphere(kind, tol, head_size): """Test some standard montage are on sphere.""" kwargs = dict() if head_size != HEAD_SIZE_DEFAULT: - kwargs['head_size'] = head_size + kwargs["head_size"] = head_size montage = make_standard_montage(kind, **kwargs) - eeg_loc = np.array([ch['r'] for ch in _get_dig_eeg(montage.dig)]) + eeg_loc = np.array([ch["r"] for ch in _get_dig_eeg(montage.dig)]) assert_allclose( actual=np.linalg.norm(eeg_loc, axis=1), - desired=np.full((eeg_loc.shape[0], ), head_size), + desired=np.full((eeg_loc.shape[0],), head_size), atol=tol, ) @@ -74,14 +76,14 @@ def test_standard_montages_on_sphere(kind, tol, head_size): def test_standard_superset(): """Test some properties that should hold for superset montages.""" # new montages, tweaked to end up at the same size as the others - m_1005 = make_standard_montage('standard_1005', 0.0970) - m_1020 = make_standard_montage('standard_1020', 0.0991) + m_1005 = make_standard_montage("standard_1005", 0.0970) + m_1020 = make_standard_montage("standard_1020", 0.0991) 
assert len(set(m_1005.ch_names) - set(m_1020.ch_names)) > 0 # XXX weird that this is not a proper superset... - assert set(m_1020.ch_names) - set(m_1005.ch_names) == {'O10', 'O9'} + assert set(m_1020.ch_names) - set(m_1005.ch_names) == {"O10", "O9"} c_1005 = m_1005._get_ch_pos() for key, value in m_1020._get_ch_pos().items(): - if key not in ('O10', 'O9'): + if key not in ("O10", "O9"): assert_allclose(c_1005[key], value, atol=1e-4, err_msg=key) @@ -93,15 +95,29 @@ def _simulate_artinis_octamon(): """ np.random.seed(42) data = np.absolute(np.random.normal(size=(16, 100))) - ch_names = ['S1_D1 760', 'S1_D1 850', 'S2_D1 760', 'S2_D1 850', - 'S3_D1 760', 'S3_D1 850', 'S4_D1 760', 'S4_D1 850', - 'S5_D2 760', 'S5_D2 850', 'S6_D2 760', 'S6_D2 850', - 'S7_D2 760', 'S7_D2 850', 'S8_D2 760', 'S8_D2 850'] - ch_types = ['fnirs_cw_amplitude' for _ in ch_names] - sfreq = 10. # Hz + ch_names = [ + "S1_D1 760", + "S1_D1 850", + "S2_D1 760", + "S2_D1 850", + "S3_D1 760", + "S3_D1 850", + "S4_D1 760", + "S4_D1 850", + "S5_D2 760", + "S5_D2 850", + "S6_D2 760", + "S6_D2 850", + "S7_D2 760", + "S7_D2 850", + "S8_D2 760", + "S8_D2 850", + ] + ch_types = ["fnirs_cw_amplitude" for _ in ch_names] + sfreq = 10.0 # Hz info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) for i, ch_name in enumerate(ch_names): - info['chs'][i]['loc'][9] = int(ch_name.split(' ')[1]) + info["chs"][i]["loc"][9] = int(ch_name.split(" ")[1]) raw = RawArray(data, info) return raw @@ -115,47 +131,71 @@ def _simulate_artinis_brite23(): """ np.random.seed(0) data = np.random.normal(size=(46, 100)) - sd_names = ['S1_D1', 'S2_D1', 'S3_D1', 'S4_D1', 'S3_D2', 'S4_D2', 'S5_D2', - 'S4_D3', 'S5_D3', 'S6_D3', 'S5_D4', 'S6_D4', 'S7_D4', 'S6_D5', - 'S7_D5', 'S8_D5', 'S7_D6', 'S8_D6', 'S9_D6', 'S8_D7', 'S9_D7', - 'S10_D7', 'S11_D7'] + sd_names = [ + "S1_D1", + "S2_D1", + "S3_D1", + "S4_D1", + "S3_D2", + "S4_D2", + "S5_D2", + "S4_D3", + "S5_D3", + "S6_D3", + "S5_D4", + "S6_D4", + "S7_D4", + "S6_D5", + "S7_D5", + "S8_D5", + "S7_D6", + "S8_D6", + "S9_D6", + "S8_D7", + "S9_D7", + "S10_D7", + "S11_D7", + ] ch_names = [] ch_types = [] for name in sd_names: - ch_names.append(name + ' hbo') - ch_types.append('hbo') - ch_names.append(name + ' hbr') - ch_types.append('hbr') - sfreq = 10. 
# Hz + ch_names.append(name + " hbo") + ch_types.append("hbo") + ch_names.append(name + " hbr") + ch_types.append("hbr") + sfreq = 10.0 # Hz info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) raw = RawArray(data, info) return raw -@pytest.mark.parametrize('kind', ('octamon', 'brite23')) +@pytest.mark.parametrize("kind", ("octamon", "brite23")) def test_set_montage_artinis_fsaverage(kind): """Test that artinis montages match fsaverage's head<->MRI transform.""" # Compare OctaMon and Brite23 to fsaverage - trans_fs, _ = _get_trans('fsaverage') - montage = make_standard_montage(f'artinis-{kind}') + trans_fs, _ = _get_trans("fsaverage") + montage = make_standard_montage(f"artinis-{kind}") trans = compute_native_head_t(montage) - assert trans['to'] == trans_fs['to'] - assert trans['from'] == trans_fs['from'] - translation = 1000 * np.linalg.norm(trans['trans'][:3, 3] - - trans_fs['trans'][:3, 3]) + assert trans["to"] == trans_fs["to"] + assert trans["from"] == trans_fs["from"] + translation = 1000 * np.linalg.norm( + trans["trans"][:3, 3] - trans_fs["trans"][:3, 3] + ) assert 0 < translation < 1 # mm rotation = np.rad2deg( - _angle_between_quats(rot_to_quat(trans['trans'][:3, :3]), - rot_to_quat(trans_fs['trans'][:3, :3]))) + _angle_between_quats( + rot_to_quat(trans["trans"][:3, :3]), rot_to_quat(trans_fs["trans"][:3, :3]) + ) + ) assert 0 < rotation < 1 # degrees def test_set_montage_artinis_basic(): """Test that OctaMon and Brite23 montages are set properly.""" # Test OctaMon montage - montage_octamon = make_standard_montage('artinis-octamon') - montage_brite23 = make_standard_montage('artinis-brite23') + montage_octamon = make_standard_montage("artinis-octamon") + montage_brite23 = make_standard_montage("artinis-brite23") raw = _simulate_artinis_octamon() raw_od = optical_density(raw) old_info = raw.info.copy() @@ -164,82 +204,106 @@ def test_set_montage_artinis_basic(): raw_od.set_montage(montage_octamon) raw_hb = beer_lambert_law(raw_od, ppf=6) # montage needed for BLL # Check that the montage was actually modified - assert_raises(AssertionError, assert_array_almost_equal, - old_info['chs'][0]['loc'][:9], - raw.info['chs'][0]['loc'][:9]) - assert_raises(AssertionError, assert_array_almost_equal, - old_info_od['chs'][0]['loc'][:9], - raw_od.info['chs'][0]['loc'][:9]) + assert_raises( + AssertionError, + assert_array_almost_equal, + old_info["chs"][0]["loc"][:9], + raw.info["chs"][0]["loc"][:9], + ) + assert_raises( + AssertionError, + assert_array_almost_equal, + old_info_od["chs"][0]["loc"][:9], + raw_od.info["chs"][0]["loc"][:9], + ) # Check a known location - assert_array_almost_equal(raw.info['chs'][0]['loc'][:3], - [0.054243, 0.081884, 0.054544]) - assert_array_almost_equal(raw.info['chs'][8]['loc'][:3], - [-0.03013, 0.105097, 0.055894]) - assert_array_almost_equal(raw.info['chs'][12]['loc'][:3], - [-0.055681, 0.086566, 0.055858]) - assert_array_almost_equal(raw_od.info['chs'][12]['loc'][:3], - [-0.055681, 0.086566, 0.055858]) - assert_array_almost_equal(raw_hb.info['chs'][12]['loc'][:3], - [-0.055681, 0.086566, 0.055858]) + assert_array_almost_equal( + raw.info["chs"][0]["loc"][:3], [0.054243, 0.081884, 0.054544] + ) + assert_array_almost_equal( + raw.info["chs"][8]["loc"][:3], [-0.03013, 0.105097, 0.055894] + ) + assert_array_almost_equal( + raw.info["chs"][12]["loc"][:3], [-0.055681, 0.086566, 0.055858] + ) + assert_array_almost_equal( + raw_od.info["chs"][12]["loc"][:3], [-0.055681, 0.086566, 0.055858] + ) + assert_array_almost_equal( + 
raw_hb.info["chs"][12]["loc"][:3], [-0.055681, 0.086566, 0.055858] + ) # Check that locations are identical for a pair of channels (all elements # except the 10th which is the wavelength if not hbo and hbr type) - assert_array_almost_equal(raw.info['chs'][0]['loc'][:9], - raw.info['chs'][1]['loc'][:9]) - assert_array_almost_equal(raw_od.info['chs'][0]['loc'][:9], - raw_od.info['chs'][1]['loc'][:9]) - assert_array_almost_equal(raw_hb.info['chs'][0]['loc'][:9], - raw_hb.info['chs'][1]['loc'][:9]) + assert_array_almost_equal( + raw.info["chs"][0]["loc"][:9], raw.info["chs"][1]["loc"][:9] + ) + assert_array_almost_equal( + raw_od.info["chs"][0]["loc"][:9], raw_od.info["chs"][1]["loc"][:9] + ) + assert_array_almost_equal( + raw_hb.info["chs"][0]["loc"][:9], raw_hb.info["chs"][1]["loc"][:9] + ) # Test Brite23 montage raw = _simulate_artinis_brite23() old_info = raw.info.copy() raw.set_montage(montage_brite23) # Check that the montage was actually modified - assert_raises(AssertionError, assert_array_almost_equal, - old_info['chs'][0]['loc'][:9], - raw.info['chs'][0]['loc'][:9]) + assert_raises( + AssertionError, + assert_array_almost_equal, + old_info["chs"][0]["loc"][:9], + raw.info["chs"][0]["loc"][:9], + ) # Check a known location - assert_array_almost_equal(raw.info['chs'][0]['loc'][:3], - [0.068931, 0.046201, 0.072055]) - assert_array_almost_equal(raw.info['chs'][8]['loc'][:3], - [0.055196, 0.082757, 0.052165]) - assert_array_almost_equal(raw.info['chs'][12]['loc'][:3], - [0.033592, 0.102607, 0.047423]) + assert_array_almost_equal( + raw.info["chs"][0]["loc"][:3], [0.068931, 0.046201, 0.072055] + ) + assert_array_almost_equal( + raw.info["chs"][8]["loc"][:3], [0.055196, 0.082757, 0.052165] + ) + assert_array_almost_equal( + raw.info["chs"][12]["loc"][:3], [0.033592, 0.102607, 0.047423] + ) # Check that locations are identical for a pair of channels (all elements # except the 10th which is the wavelength if not hbo and hbr type) - assert_array_almost_equal(raw.info['chs'][0]['loc'][:9], - raw.info['chs'][1]['loc'][:9]) + assert_array_almost_equal( + raw.info["chs"][0]["loc"][:9], raw.info["chs"][1]["loc"][:9] + ) # Test channel variations raw_old = _simulate_artinis_brite23() # Raw missing some channels that are in the montage: pass raw = raw_old.copy() - raw.pick(['S1_D1 hbo', 'S1_D1 hbr']) - raw.set_montage('artinis-brite23') + raw.pick(["S1_D1 hbo", "S1_D1 hbr"]) + raw.set_montage("artinis-brite23") # Unconventional channel pair: pass raw = raw_old.copy() - info_new = create_info(['S11_D1 hbo', 'S11_D1 hbr'], raw.info['sfreq'], - ['hbo', 'hbr']) + info_new = create_info( + ["S11_D1 hbo", "S11_D1 hbr"], raw.info["sfreq"], ["hbo", "hbr"] + ) new = RawArray(np.random.normal(size=(2, len(raw))), info_new) raw.add_channels([new], force_update_info=True) - raw.set_montage('artinis-brite23') + raw.set_montage("artinis-brite23") # Source not in montage: fail raw = raw_old.copy() - info_new = create_info(['S12_D7 hbo', 'S12_D7 hbr'], raw.info['sfreq'], - ['hbo', 'hbr']) + info_new = create_info( + ["S12_D7 hbo", "S12_D7 hbr"], raw.info["sfreq"], ["hbo", "hbr"] + ) new = RawArray(np.random.normal(size=(2, len(raw))), info_new) raw.add_channels([new], force_update_info=True) - with pytest.raises(ValueError, match='is not in list'): - raw.set_montage('artinis-brite23') + with pytest.raises(ValueError, match="is not in list"): + raw.set_montage("artinis-brite23") # Detector not in montage: fail raw = raw_old.copy() - info_new = create_info(['S11_D8 hbo', 'S11_D8 hbr'], raw.info['sfreq'], - ['hbo', 
'hbr']) + info_new = create_info( + ["S11_D8 hbo", "S11_D8 hbr"], raw.info["sfreq"], ["hbo", "hbr"] + ) new = RawArray(np.random.normal(size=(2, len(raw))), info_new) raw.add_channels([new], force_update_info=True) - with pytest.raises(ValueError, match='is not in list'): - raw.set_montage('artinis-brite23') + with pytest.raises(ValueError, match="is not in list"): + raw.set_montage("artinis-brite23") diff --git a/mne/chpi.py b/mne/chpi.py index 9d80fa6efde..3bbddb1647b 100644 --- a/mne/chpi.py +++ b/mne/chpi.py @@ -29,24 +29,49 @@ from .io.kit.constants import KIT from .io.kit.kit import RawKIT as _RawKIT from .io.meas_info import _simplify_info, Info -from .io.pick import (pick_types, pick_channels, pick_channels_regexp, - pick_info, _picks_to_idx) +from .io.pick import ( + pick_types, + pick_channels, + pick_channels_regexp, + pick_info, + _picks_to_idx, +) from .io.proj import Projection, setup_proj from .io.constants import FIFF from .io.ctf.trans import _make_ctf_coord_trans_set -from .forward import (_magnetic_dipole_field_vec, _create_meg_coils, - _concatenate_coils) +from .forward import _magnetic_dipole_field_vec, _create_meg_coils, _concatenate_coils from .cov import make_ad_hoc_cov, compute_whitener from .dipole import _make_guesses from .fixes import jit -from .preprocessing.maxwell import (_sss_basis, _prep_mf_coils, - _regularize_out, _get_mf_picks_fix_mags) -from .transforms import (apply_trans, invert_transform, _angle_between_quats, - quat_to_rot, rot_to_quat, _fit_matched_points, - _quat_to_affine, als_ras_trans) -from .utils import (verbose, logger, use_log_level, _check_fname, warn, - _validate_type, ProgressBar, _check_option, _pl, - _on_missing, _verbose_safe_false) +from .preprocessing.maxwell import ( + _sss_basis, + _prep_mf_coils, + _regularize_out, + _get_mf_picks_fix_mags, +) +from .transforms import ( + apply_trans, + invert_transform, + _angle_between_quats, + quat_to_rot, + rot_to_quat, + _fit_matched_points, + _quat_to_affine, + als_ras_trans, +) +from .utils import ( + verbose, + logger, + use_log_level, + _check_fname, + warn, + _validate_type, + ProgressBar, + _check_option, + _pl, + _on_missing, + _verbose_safe_false, +) # Eventually we should add: # hpicons @@ -57,6 +82,7 @@ # ############################################################################ # Reading from text or FIF file + def read_head_pos(fname): """Read MaxFilter-formatted head position parameters. @@ -80,12 +106,11 @@ def read_head_pos(fname): ----- .. 
versionadded:: 0.12 """ - _check_fname(fname, must_exist=True, overwrite='read') + _check_fname(fname, must_exist=True, overwrite="read") data = np.loadtxt(fname, skiprows=1) # first line is header, skip it data.shape = (-1, 10) # ensure it's the right size even if empty if np.isnan(data).any(): # make sure we didn't do something dumb - raise RuntimeError('positions could not be read properly from %s' - % fname) + raise RuntimeError("positions could not be read properly from %s" % fname) return data @@ -111,14 +136,15 @@ def write_head_pos(fname, pos): _check_fname(fname, overwrite=True) pos = np.array(pos, np.float64) if pos.ndim != 2 or pos.shape[1] != 10: - raise ValueError('pos must be a 2D array of shape (N, 10)') - with open(fname, 'wb') as fid: - fid.write(' Time q1 q2 q3 q4 q5 ' - 'q6 g-value error velocity\n'.encode('ASCII')) + raise ValueError("pos must be a 2D array of shape (N, 10)") + with open(fname, "wb") as fid: + fid.write( + " Time q1 q2 q3 q4 q5 " + "q6 g-value error velocity\n".encode("ASCII") + ) for p in pos: - fmts = ['% 9.3f'] + ['% 8.5f'] * 9 - fid.write(((' ' + ' '.join(fmts) + '\n') - % tuple(p)).encode('ASCII')) + fmts = ["% 9.3f"] + ["% 8.5f"] * 9 + fid.write(((" " + " ".join(fmts) + "\n") % tuple(p)).encode("ASCII")) def head_pos_to_trans_rot_t(quats): @@ -178,15 +204,14 @@ def extract_chpi_locs_ctf(raw, verbose=None): .. versionadded:: 0.20 """ # Pick channels corresponding to the cHPI positions - hpi_picks = pick_channels_regexp(raw.info['ch_names'], 'HLC00[123][123].*') + hpi_picks = pick_channels_regexp(raw.info["ch_names"], "HLC00[123][123].*") # make sure we get 9 channels if len(hpi_picks) != 9: - raise RuntimeError('Could not find all 9 cHPI channels') + raise RuntimeError("Could not find all 9 cHPI channels") # get indices in alphabetical order - sorted_picks = np.array(sorted(hpi_picks, - key=lambda k: raw.info['ch_names'][k])) + sorted_picks = np.array(sorted(hpi_picks, key=lambda k: raw.info["ch_names"][k])) # make picks to match order of dig cardinial ident codes. # LPA (HPIC002[123]-*), NAS(HPIC001[123]-*), RPA(HPIC003[123]-*) @@ -199,7 +224,7 @@ def extract_chpi_locs_ctf(raw, verbose=None): # transforms tmp_trans = _make_ctf_coord_trans_set(None, None) - ctf_dev_dev_t = tmp_trans['t_ctf_dev_dev'] + ctf_dev_dev_t = tmp_trans["t_ctf_dev_dev"] del tmp_trans # find indices where chpi locations change @@ -216,7 +241,7 @@ def extract_chpi_locs_ctf(raw, verbose=None): @verbose -def extract_chpi_locs_kit(raw, stim_channel='MISC 064', *, verbose=None): +def extract_chpi_locs_kit(raw, stim_channel="MISC 064", *, verbose=None): """Extract cHPI locations from KIT data. Parameters @@ -235,34 +260,35 @@ def extract_chpi_locs_kit(raw, stim_channel='MISC 064', *, verbose=None): ----- .. 
versionadded:: 0.23 """ - _validate_type(raw, (_RawKIT,), 'raw') + _validate_type(raw, (_RawKIT,), "raw") stim_chs = [ - raw.info['ch_names'][pick] for pick in pick_types( - raw.info, stim=True, misc=True, ref_meg=False)] - _validate_type(stim_channel, str, 'stim_channel') - _check_option('stim_channel', stim_channel, stim_chs) + raw.info["ch_names"][pick] + for pick in pick_types(raw.info, stim=True, misc=True, ref_meg=False) + ] + _validate_type(stim_channel, str, "stim_channel") + _check_option("stim_channel", stim_channel, stim_chs) idx = raw.ch_names.index(stim_channel) safe_false = _verbose_safe_false() events_on = find_events( - raw, stim_channel=raw.ch_names[idx], output='onset', - verbose=safe_false)[:, 0] + raw, stim_channel=raw.ch_names[idx], output="onset", verbose=safe_false + )[:, 0] events_off = find_events( - raw, stim_channel=raw.ch_names[idx], output='offset', - verbose=safe_false)[:, 0] + raw, stim_channel=raw.ch_names[idx], output="offset", verbose=safe_false + )[:, 0] bad = False if len(events_on) == 0 or len(events_off) == 0: bad = True else: if events_on[-1] > events_off[-1]: events_on = events_on[:-1] - if events_on.size != events_off.size or not \ - (events_on < events_off).all(): + if events_on.size != events_off.size or not (events_on < events_off).all(): bad = True if bad: raise RuntimeError( - f'Could not find appropriate cHPI intervals from {stim_channel}') + f"Could not find appropriate cHPI intervals from {stim_channel}" + ) # use the midpoint for times - times = (events_on + events_off) / (2 * raw.info['sfreq']) + times = (events_on + events_off) / (2 * raw.info["sfreq"]) del events_on, events_off # XXX remove first two rows. It is unknown currently if there is a way to # determine from the con file the number of initial pulses that @@ -271,24 +297,25 @@ def extract_chpi_locs_kit(raw, stim_channel='MISC 064', *, verbose=None): # may just always be 2... 
times = times[2:] n_coils = 5 # KIT always has 5 (hard-coded in reader) - header = raw._raw_extras[0]['dirs'][KIT.DIR_INDEX_CHPI_DATA] - dtype = np.dtype([('good', ' 0 else None # grab codes indicating a coil is active - hpi_on = [coil['event_bits'][0] for coil in hpi_sub['hpi_coils']] + hpi_on = [coil["event_bits"][0] for coil in hpi_sub["hpi_coils"]] # not all HPI coils will actually be used - hpi_on = np.array([hpi_on[hc['number'] - 1] for hc in hpi_coils]) + hpi_on = np.array([hpi_on[hc["number"] - 1] for hc in hpi_coils]) # mask for coils that may be active hpi_mask = np.array([event_bit != 0 for event_bit in hpi_on]) hpi_on = hpi_on[hpi_mask] @@ -366,63 +404,71 @@ def get_chpi_info(info, on_missing='raise', verbose=None): @verbose def _get_hpi_initial_fit(info, adjust=False, verbose=None): """Get HPI fit locations from raw.""" - if info['hpi_results'] is None or len(info['hpi_results']) == 0: - raise RuntimeError('no initial cHPI head localization performed') - - hpi_result = info['hpi_results'][-1] - hpi_dig = sorted([d for d in info['dig'] - if d['kind'] == FIFF.FIFFV_POINT_HPI], - key=lambda x: x['ident']) # ascending (dig) order + if info["hpi_results"] is None or len(info["hpi_results"]) == 0: + raise RuntimeError("no initial cHPI head localization performed") + + hpi_result = info["hpi_results"][-1] + hpi_dig = sorted( + [d for d in info["dig"] if d["kind"] == FIFF.FIFFV_POINT_HPI], + key=lambda x: x["ident"], + ) # ascending (dig) order if len(hpi_dig) == 0: # CTF data, probably - hpi_dig = sorted(hpi_result['dig_points'], key=lambda x: x['ident']) - if all(d['coord_frame'] in (FIFF.FIFFV_COORD_DEVICE, - FIFF.FIFFV_COORD_UNKNOWN) - for d in hpi_dig): + hpi_dig = sorted(hpi_result["dig_points"], key=lambda x: x["ident"]) + if all( + d["coord_frame"] in (FIFF.FIFFV_COORD_DEVICE, FIFF.FIFFV_COORD_UNKNOWN) + for d in hpi_dig + ): for dig in hpi_dig: - dig.update(r=apply_trans(info['dev_head_t'], dig['r']), - coord_frame=FIFF.FIFFV_COORD_HEAD) + dig.update( + r=apply_trans(info["dev_head_t"], dig["r"]), + coord_frame=FIFF.FIFFV_COORD_HEAD, + ) # zero-based indexing, dig->info # CTF does not populate some entries so we use .get here - pos_order = hpi_result.get('order', np.arange(1, len(hpi_dig) + 1)) - 1 - used = hpi_result.get('used', np.arange(len(hpi_dig))) - dist_limit = hpi_result.get('dist_limit', 0.005) - good_limit = hpi_result.get('good_limit', 0.98) - goodness = hpi_result.get('goodness', np.ones(len(hpi_dig))) + pos_order = hpi_result.get("order", np.arange(1, len(hpi_dig) + 1)) - 1 + used = hpi_result.get("used", np.arange(len(hpi_dig))) + dist_limit = hpi_result.get("dist_limit", 0.005) + good_limit = hpi_result.get("good_limit", 0.98) + goodness = hpi_result.get("goodness", np.ones(len(hpi_dig))) # this shouldn't happen, eventually we could add the transforms # necessary to put it in head coords - if not all(d['coord_frame'] == FIFF.FIFFV_COORD_HEAD for d in hpi_dig): - raise RuntimeError('cHPI coordinate frame incorrect') + if not all(d["coord_frame"] == FIFF.FIFFV_COORD_HEAD for d in hpi_dig): + raise RuntimeError("cHPI coordinate frame incorrect") # Give the user some info - logger.info('HPIFIT: %s coils digitized in order %s' - % (len(pos_order), ' '.join(str(o + 1) for o in pos_order))) - logger.debug('HPIFIT: %s coils accepted: %s' - % (len(used), ' '.join(str(h) for h in used))) - hpi_rrs = np.array([d['r'] for d in hpi_dig])[pos_order] + logger.info( + "HPIFIT: %s coils digitized in order %s" + % (len(pos_order), " ".join(str(o + 1) for o in pos_order)) + ) + 
logger.debug( + "HPIFIT: %s coils accepted: %s" % (len(used), " ".join(str(h) for h in used)) + ) + hpi_rrs = np.array([d["r"] for d in hpi_dig])[pos_order] assert len(hpi_rrs) >= 3 # Fitting errors - hpi_rrs_fit = sorted([d for d in info['hpi_results'][-1]['dig_points']], - key=lambda x: x['ident']) - hpi_rrs_fit = np.array([d['r'] for d in hpi_rrs_fit]) + hpi_rrs_fit = sorted( + [d for d in info["hpi_results"][-1]["dig_points"]], key=lambda x: x["ident"] + ) + hpi_rrs_fit = np.array([d["r"] for d in hpi_rrs_fit]) # hpi_result['dig_points'] are in FIFFV_COORD_UNKNOWN coords, but this # is probably a misnomer because it should be FIFFV_COORD_DEVICE for this # to work - assert hpi_result['coord_trans']['to'] == FIFF.FIFFV_COORD_HEAD - hpi_rrs_fit = apply_trans(hpi_result['coord_trans']['trans'], hpi_rrs_fit) - if 'moments' in hpi_result: - logger.debug('Hpi coil moments (%d %d):' - % hpi_result['moments'].shape[::-1]) - for moment in hpi_result['moments']: + assert hpi_result["coord_trans"]["to"] == FIFF.FIFFV_COORD_HEAD + hpi_rrs_fit = apply_trans(hpi_result["coord_trans"]["trans"], hpi_rrs_fit) + if "moments" in hpi_result: + logger.debug("Hpi coil moments (%d %d):" % hpi_result["moments"].shape[::-1]) + for moment in hpi_result["moments"]: logger.debug("%g %g %g" % tuple(moment)) errors = np.linalg.norm(hpi_rrs - hpi_rrs_fit, axis=1) - logger.debug('HPIFIT errors: %s mm.' - % ', '.join('%0.1f' % (1000. * e) for e in errors)) + logger.debug( + "HPIFIT errors: %s mm." % ", ".join("%0.1f" % (1000.0 * e) for e in errors) + ) if errors.sum() < len(errors) * dist_limit: - logger.info('HPI consistency of isotrak and hpifit is OK.') + logger.info("HPI consistency of isotrak and hpifit is OK.") elif not adjust and (len(used) == len(hpi_dig)): - warn('HPI consistency of isotrak and hpifit is poor.') + warn("HPI consistency of isotrak and hpifit is poor.") else: # adjust HPI coil locations using the hpifit transformation for hi, (err, r_fit) in enumerate(zip(errors, hpi_rrs_fit)): @@ -430,24 +476,33 @@ def _get_hpi_initial_fit(info, adjust=False, verbose=None): d = 1000 * err if not adjust: if err >= dist_limit: - warn('Discrepancy of HPI coil %d isotrak and hpifit is ' - '%.1f mm!' % (hi + 1, d)) + warn( + "Discrepancy of HPI coil %d isotrak and hpifit is " + "%.1f mm!" % (hi + 1, d) + ) elif hi + 1 not in used: if goodness[hi] >= good_limit: - logger.info('Note: HPI coil %d isotrak is adjusted by ' - '%.1f mm!' % (hi + 1, d)) + logger.info( + "Note: HPI coil %d isotrak is adjusted by " + "%.1f mm!" % (hi + 1, d) + ) hpi_rrs[hi] = r_fit else: - warn('Discrepancy of HPI coil %d isotrak and hpifit of ' - '%.1f mm was not adjusted!' % (hi + 1, d)) - logger.debug('HP fitting limits: err = %.1f mm, gval = %.3f.' - % (1000 * dist_limit, good_limit)) + warn( + "Discrepancy of HPI coil %d isotrak and hpifit of " + "%.1f mm was not adjusted!" % (hi + 1, d) + ) + logger.debug( + "HP fitting limits: err = %.1f mm, gval = %.3f." 
+ % (1000 * dist_limit, good_limit) + ) return hpi_rrs.astype(float) -def _magnetic_dipole_objective(x, B, B2, coils, whitener, too_close, - return_moment=False): +def _magnetic_dipole_objective( + x, B, B2, coils, whitener, too_close, return_moment=False +): """Project data onto right eigenvectors of whitened forward.""" fwd = _magnetic_dipole_field_vec(x[np.newaxis], coils, too_close) out, u, s, one = _magnetic_dipole_delta(fwd, whitener, B, B2) @@ -478,22 +533,27 @@ def _magnetic_dipole_delta_multi(whitened_fwd_svd, B, B2): def _fit_magnetic_dipole(B_orig, x0, too_close, whitener, coils, guesses): """Fit a single bit of data (x0 = pos).""" from scipy.optimize import fmin_cobyla + B = np.dot(whitener, B_orig) B2 = np.dot(B, B) - objective = partial(_magnetic_dipole_objective, B=B, B2=B2, - coils=coils, whitener=whitener, - too_close=too_close) + objective = partial( + _magnetic_dipole_objective, + B=B, + B2=B2, + coils=coils, + whitener=whitener, + too_close=too_close, + ) if guesses is not None: res0 = objective(x0) - res = _magnetic_dipole_delta_multi( - guesses['whitened_fwd_svd'], B, B2) - assert res.shape == (guesses['rr'].shape[0],) + res = _magnetic_dipole_delta_multi(guesses["whitened_fwd_svd"], B, B2) + assert res.shape == (guesses["rr"].shape[0],) idx = np.argmin(res) if res[idx] < res0: - x0 = guesses['rr'][idx] + x0 = guesses["rr"][idx] x = fmin_cobyla(objective, x0, (), rhobeg=1e-3, rhoend=1e-5, disp=False) gof, moment = objective(x, return_moment=True) - gof = 1. - gof / B2 + gof = 1.0 - gof / B2 return x, gof, moment @@ -515,7 +575,7 @@ def _fit_chpi_quat(coil_dev_rrs, coil_head_rrs): # XXX someday we could choose to weight these points by their goodness # of fit somehow. quat = _fit_matched_points(coil_dev_rrs, coil_head_rrs)[0] - gof = 1. - _chpi_objective(quat, coil_dev_rrs, coil_head_rrs) / denom + gof = 1.0 - _chpi_objective(quat, coil_dev_rrs, coil_head_rrs) / denom return quat, gof @@ -534,7 +594,7 @@ def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, bias=True): # equivalent g values. To avoid this, heavily penalize # large rotations. rotation = _angle_between_quats(this_quat[:3], np.zeros(3)) - check_g = g * max(1. - rotation / np.pi, 0) ** 0.25 + check_g = g * max(1.0 - rotation / np.pi, 0) ** 0.25 else: check_g = g if check_g > best_g: @@ -549,61 +609,77 @@ def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, bias=True): @verbose -def _setup_hpi_amplitude_fitting(info, t_window, remove_aliased=False, - ext_order=1, allow_empty=False, verbose=None): +def _setup_hpi_amplitude_fitting( + info, t_window, remove_aliased=False, ext_order=1, allow_empty=False, verbose=None +): """Generate HPI structure for HPI localization.""" # grab basic info. - on_missing = 'raise' if not allow_empty else 'ignore' + on_missing = "raise" if not allow_empty else "ignore" hpi_freqs, hpi_pick, hpi_ons = get_chpi_info(info, on_missing=on_missing) - _validate_type(t_window, (str, 'numeric'), 't_window') - if info['line_freq'] is not None: - line_freqs = np.arange(info['line_freq'], info['sfreq'] / 3., - info['line_freq']) + _validate_type(t_window, (str, "numeric"), "t_window") + if info["line_freq"] is not None: + line_freqs = np.arange( + info["line_freq"], info["sfreq"] / 3.0, info["line_freq"] + ) else: line_freqs = np.zeros([0]) - logger.info('Line interference frequencies: %s Hz' - % ' '.join(['%d' % lf for lf in line_freqs])) + logger.info( + "Line interference frequencies: %s Hz" + % " ".join(["%d" % lf for lf in line_freqs]) + ) # worry about resampled/filtered data. 
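The line-frequency handling above and the aliasing guard that follows can be reproduced in isolation; all values here are hypothetical, chosen only to illustrate the arithmetic:

import numpy as np

sfreq, line_freq, lowpass = 1000.0, 60.0, 330.0
hpi_freqs = np.array([83.0, 143.0, 203.0, 263.0, 323.0])  # made-up coil freqs

# harmonics of the mains frequency up to sfreq / 3, as in the hunk above
line_freqs = np.arange(line_freq, sfreq / 3.0, line_freq)
print(line_freqs)  # [ 60. 120. 180. 240. 300.]

# coils above the effective lowpass would be aliased; with remove_aliased
# they are dropped, otherwise a RuntimeError is raised (next hunk)
keepers = hpi_freqs <= lowpass
print(keepers.all())  # True for these values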
# What to do e.g. if Raw has been resampled and some of our # HPI freqs would now be aliased - highest = info.get('lowpass') - highest = info['sfreq'] / 2. if highest is None else highest + highest = info.get("lowpass") + highest = info["sfreq"] / 2.0 if highest is None else highest keepers = hpi_freqs <= highest if remove_aliased: hpi_freqs = hpi_freqs[keepers] hpi_ons = hpi_ons[keepers] elif not keepers.all(): - raise RuntimeError('Found HPI frequencies %s above the lowpass ' - '(or Nyquist) frequency %0.1f' - % (hpi_freqs[~keepers].tolist(), highest)) + raise RuntimeError( + "Found HPI frequencies %s above the lowpass " + "(or Nyquist) frequency %0.1f" % (hpi_freqs[~keepers].tolist(), highest) + ) # calculate optimal window length. if isinstance(t_window, str): - _check_option('t_window', t_window, ('auto',), extra='if a string') + _check_option("t_window", t_window, ("auto",), extra="if a string") if len(hpi_freqs): all_freqs = np.concatenate((hpi_freqs, line_freqs)) delta_freqs = np.diff(np.unique(all_freqs)) - t_window = max(5. / all_freqs.min(), 1. / delta_freqs.min()) + t_window = max(5.0 / all_freqs.min(), 1.0 / delta_freqs.min()) else: t_window = 0.2 t_window = float(t_window) if t_window <= 0: - raise ValueError('t_window (%s) must be > 0' % (t_window,)) - logger.info('Using time window: %0.1f ms' % (1000 * t_window,)) - window_nsamp = np.rint(t_window * info['sfreq']).astype(int) - model = _setup_hpi_glm(hpi_freqs, line_freqs, info['sfreq'], window_nsamp) + raise ValueError("t_window (%s) must be > 0" % (t_window,)) + logger.info("Using time window: %0.1f ms" % (1000 * t_window,)) + window_nsamp = np.rint(t_window * info["sfreq"]).astype(int) + model = _setup_hpi_glm(hpi_freqs, line_freqs, info["sfreq"], window_nsamp) inv_model = np.linalg.pinv(model) inv_model_reord = _reorder_inv_model(inv_model, len(hpi_freqs)) proj, proj_op, meg_picks = _setup_ext_proj(info, ext_order) # include mag and grad picks separately, for SNR computations - mag_picks = _picks_to_idx(info, 'mag', allow_empty=True) - grad_picks = _picks_to_idx(info, 'grad', allow_empty=True) + mag_picks = _picks_to_idx(info, "mag", allow_empty=True) + grad_picks = _picks_to_idx(info, "grad", allow_empty=True) # Set up magnetic dipole fits hpi = dict( - meg_picks=meg_picks, mag_picks=mag_picks, grad_picks=grad_picks, - hpi_pick=hpi_pick, model=model, inv_model=inv_model, t_window=t_window, - inv_model_reord=inv_model_reord, on=hpi_ons, n_window=window_nsamp, - proj=proj, proj_op=proj_op, freqs=hpi_freqs, line_freqs=line_freqs) + meg_picks=meg_picks, + mag_picks=mag_picks, + grad_picks=grad_picks, + hpi_pick=hpi_pick, + model=model, + inv_model=inv_model, + t_window=t_window, + inv_model_reord=inv_model_reord, + on=hpi_ons, + n_window=window_nsamp, + proj=proj, + proj_op=proj_op, + freqs=hpi_freqs, + line_freqs=line_freqs, + ) return hpi @@ -613,9 +689,14 @@ def _setup_hpi_glm(hpi_freqs, line_freqs, sfreq, window_nsamp): radians_per_sec = 2 * np.pi * np.arange(window_nsamp, dtype=float) / sfreq f_t = hpi_freqs[np.newaxis, :] * radians_per_sec[:, np.newaxis] l_t = line_freqs[np.newaxis, :] * radians_per_sec[:, np.newaxis] - model = [np.sin(f_t), np.cos(f_t), # hpi freqs - np.sin(l_t), np.cos(l_t), # line freqs - slope, np.ones_like(slope)] # drift, DC + model = [ + np.sin(f_t), + np.cos(f_t), # hpi freqs + np.sin(l_t), + np.cos(l_t), # line freqs + slope, + np.ones_like(slope), + ] # drift, DC return np.hstack(model) @@ -628,34 +709,40 @@ def _reorder_inv_model(inv_model, n_freqs): def _setup_ext_proj(info, ext_order): 
from scipy import linalg - meg_picks = pick_types(info, meg=True, eeg=False, exclude='bads') + + meg_picks = pick_types(info, meg=True, eeg=False, exclude="bads") info = pick_info(_simplify_info(info), meg_picks) # makes a copy _, _, _, _, mag_or_fine = _get_mf_picks_fix_mags( - info, int_order=0, ext_order=ext_order, ignore_ref=True, - verbose='error') - mf_coils = _prep_mf_coils(info, verbose='error') + info, int_order=0, ext_order=ext_order, ignore_ref=True, verbose="error" + ) + mf_coils = _prep_mf_coils(info, verbose="error") ext = _sss_basis( - dict(origin=(0., 0., 0.), int_order=0, ext_order=ext_order), - mf_coils).T + dict(origin=(0.0, 0.0, 0.0), int_order=0, ext_order=ext_order), mf_coils + ).T out_removes = _regularize_out(0, 1, mag_or_fine, []) ext = ext[~np.in1d(np.arange(len(ext)), out_removes)] ext = linalg.orth(ext.T).T assert ext.shape[1] == len(meg_picks) proj = Projection( - kind=FIFF.FIFFV_PROJ_ITEM_HOMOG_FIELD, desc='SSS', active=False, - data=dict(data=ext, ncol=info['nchan'], col_names=info['ch_names'], - nrow=len(ext))) + kind=FIFF.FIFFV_PROJ_ITEM_HOMOG_FIELD, + desc="SSS", + active=False, + data=dict( + data=ext, ncol=info["nchan"], col_names=info["ch_names"], nrow=len(ext) + ), + ) with info._unlock(): - info['projs'] = [proj] + info["projs"] = [proj] proj_op, _ = setup_proj( - info, add_eeg_ref=False, activate=False, verbose=_verbose_safe_false()) + info, add_eeg_ref=False, activate=False, verbose=_verbose_safe_false() + ) assert proj_op.shape == (len(meg_picks),) * 2 return proj, proj_op, meg_picks def _time_prefix(fit_time): """Format log messages.""" - return (' t=%0.3f:' % fit_time).ljust(17) + return (" t=%0.3f:" % fit_time).ljust(17) def _fit_chpi_amplitudes(raw, time_sl, hpi, snr=False): @@ -672,32 +759,43 @@ def _fit_chpi_amplitudes(raw, time_sl, hpi, snr=False): # No need to detrend the data because our model has a DC term with use_log_level(False): # loads good channels - this_data = raw[hpi['meg_picks'], time_sl][0] + this_data = raw[hpi["meg_picks"], time_sl][0] # which HPI coils to use - if hpi['hpi_pick'] is not None: + if hpi["hpi_pick"] is not None: with use_log_level(False): # loads hpi_stim channel - chpi_data = raw[hpi['hpi_pick'], time_sl][0] + chpi_data = raw[hpi["hpi_pick"], time_sl][0] - ons = (np.round(chpi_data).astype(np.int64) & - hpi['on'][:, np.newaxis]).astype(bool) + ons = (np.round(chpi_data).astype(np.int64) & hpi["on"][:, np.newaxis]).astype( + bool + ) n_on = ons.all(axis=-1).sum(axis=0) if not (n_on >= 3).all(): return None if snr: return _fast_fit_snr( - this_data, len(hpi['freqs']), hpi['model'], hpi['inv_model'], - hpi['mag_picks'], hpi['grad_picks']) - return _fast_fit(this_data, hpi['proj_op'], len(hpi['freqs']), - hpi['model'], hpi['inv_model_reord']) + this_data, + len(hpi["freqs"]), + hpi["model"], + hpi["inv_model"], + hpi["mag_picks"], + hpi["grad_picks"], + ) + return _fast_fit( + this_data, + hpi["proj_op"], + len(hpi["freqs"]), + hpi["model"], + hpi["inv_model_reord"], + ) @jit() def _fast_fit(this_data, proj, n_freqs, model, inv_model_reord): # first or last window if this_data.shape[1] != model.shape[0]: - model = model[:this_data.shape[1]] + model = model[: this_data.shape[1]] inv_model_reord = _reorder_inv_model(np.linalg.pinv(model), n_freqs) proj_data = proj @ this_data X = inv_model_reord @ proj_data.T @@ -705,7 +803,7 @@ def _fast_fit(this_data, proj, n_freqs, model, inv_model_reord): sin_fit = np.zeros((n_freqs, X.shape[1])) for fi in range(n_freqs): # use SVD across all sensors to estimate the sinusoid 
phase - u, s, vt = np.linalg.svd(X[2 * fi:2 * fi + 2], full_matrices=False) + u, s, vt = np.linalg.svd(X[2 * fi : 2 * fi + 2], full_matrices=False) # the first component holds the predominant phase direction # (so ignore the second, effectively doing s[1] = 0): sin_fit[fi] = vt[0] * s[0] @@ -716,11 +814,11 @@ def _fast_fit(this_data, proj, n_freqs, model, inv_model_reord): def _fast_fit_snr(this_data, n_freqs, model, inv_model, mag_picks, grad_picks): # first or last window if this_data.shape[1] != model.shape[0]: - model = model[:this_data.shape[1]] + model = model[: this_data.shape[1]] inv_model = np.linalg.pinv(model) coefs = np.ascontiguousarray(inv_model) @ np.ascontiguousarray(this_data.T) # average sin & cos terms (special property of sinusoids: power=A²/2) - hpi_power = (coefs[:n_freqs] ** 2 + coefs[n_freqs:(2 * n_freqs)] ** 2) / 2 + hpi_power = (coefs[:n_freqs] ** 2 + coefs[n_freqs : (2 * n_freqs)] ** 2) / 2 resid = this_data - np.ascontiguousarray((model @ coefs).T) # can't use np.var(..., axis=1) with Numba, so do it manually: resid_mean = np.atleast_2d(resid.sum(axis=1) / resid.shape[1]).T @@ -741,59 +839,70 @@ def _fast_fit_snr(this_data, n_freqs, model, inv_model, mag_picks, grad_picks): def _check_chpi_param(chpi_, name): - if name == 'chpi_locs': + if name == "chpi_locs": want_ndims = dict(times=1, rrs=3, moments=3, gofs=2) extra_keys = list() else: - assert name == 'chpi_amplitudes' + assert name == "chpi_amplitudes" want_ndims = dict(times=1, slopes=3) - extra_keys = ['proj'] + extra_keys = ["proj"] _validate_type(chpi_, dict, name) want_keys = list(want_ndims.keys()) + extra_keys if set(want_keys).symmetric_difference(chpi_): - raise ValueError('%s must be a dict with entries %s, got %s' - % (name, want_keys, sorted(chpi_.keys()))) + raise ValueError( + "%s must be a dict with entries %s, got %s" + % (name, want_keys, sorted(chpi_.keys())) + ) n_times = None for key, want_ndim in want_ndims.items(): - key_str = '%s[%s]' % (name, key) + key_str = "%s[%s]" % (name, key) val = chpi_[key] _validate_type(val, np.ndarray, key_str) shape = val.shape if val.ndim != want_ndim: - raise ValueError('%s must have ndim=%d, got %d' - % (key_str, want_ndim, val.ndim)) - if n_times is None and key != 'proj': + raise ValueError( + "%s must have ndim=%d, got %d" % (key_str, want_ndim, val.ndim) + ) + if n_times is None and key != "proj": n_times = shape[0] - if n_times != shape[0] and key != 'proj': - raise ValueError('%s have inconsistent number of time ' - 'points in %s' % (name, want_keys)) - if name == 'chpi_locs': - n_coils = chpi_['rrs'].shape[1] - for key in ('gofs', 'moments'): + if n_times != shape[0] and key != "proj": + raise ValueError( + "%s have inconsistent number of time " + "points in %s" % (name, want_keys) + ) + if name == "chpi_locs": + n_coils = chpi_["rrs"].shape[1] + for key in ("gofs", "moments"): val = chpi_[key] if val.shape[1] != n_coils: - raise ValueError('chpi_locs["rrs"] had values for %d coils but' - ' chpi_locs["%s"] had values for %d coils' - % (n_coils, key, val.shape[1])) - for key in ('rrs', 'moments'): + raise ValueError( + 'chpi_locs["rrs"] had values for %d coils but' + ' chpi_locs["%s"] had values for %d coils' + % (n_coils, key, val.shape[1]) + ) + for key in ("rrs", "moments"): val = chpi_[key] if val.shape[2] != 3: - raise ValueError('chpi_locs["%s"].shape[2] must be 3, got ' - 'shape %s' % (key, shape)) + raise ValueError( + 'chpi_locs["%s"].shape[2] must be 3, got ' "shape %s" % (key, shape) + ) else: - assert name == 'chpi_amplitudes' - 
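The per-frequency SVD step in _fast_fit above is easy to see in isolation: rows 2*fi and 2*fi+1 of X hold the sin and cos coefficients of coil fi across channels, and the dominant singular component gives the fitted sinusoid. A sketch with random data standing in for the real coefficients:

import numpy as np

rng = np.random.default_rng(1)
n_freqs, n_ch = 2, 6
X = rng.normal(size=(2 * n_freqs, n_ch))  # rows: [sin_f0, cos_f0, sin_f1, cos_f1]

sin_fit = np.zeros((n_freqs, n_ch))
for fi in range(n_freqs):
    u, s, vt = np.linalg.svd(X[2 * fi : 2 * fi + 2], full_matrices=False)
    sin_fit[fi] = vt[0] * s[0]  # keep the dominant phase, drop the second
print(sin_fit.shape)  # (2, 6)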
slopes, proj = chpi_['slopes'], chpi_['proj'] + assert name == "chpi_amplitudes" + slopes, proj = chpi_["slopes"], chpi_["proj"] _validate_type(proj, Projection, 'chpi_amplitudes["proj"]') - n_ch = len(proj['data']['col_names']) + n_ch = len(proj["data"]["col_names"]) if slopes.shape[0] != n_times or slopes.shape[2] != n_ch: - raise ValueError('slopes must have shape[0]==%d and shape[2]==%d,' - ' got shape %s' % (n_times, n_ch, slopes.shape)) + raise ValueError( + "slopes must have shape[0]==%d and shape[2]==%d," + " got shape %s" % (n_times, n_ch, slopes.shape) + ) @verbose -def compute_head_pos(info, chpi_locs, dist_limit=0.005, gof_limit=0.98, - adjust_dig=False, verbose=None): +def compute_head_pos( + info, chpi_locs, dist_limit=0.005, gof_limit=0.98, adjust_dig=False, verbose=None +): """Compute time-varying head positions. Parameters @@ -825,29 +934,30 @@ def compute_head_pos(info, chpi_locs, dist_limit=0.005, gof_limit=0.98, ----- .. versionadded:: 0.20 """ - _check_chpi_param(chpi_locs, 'chpi_locs') - _validate_type(info, Info, 'info') - hpi_dig_head_rrs = _get_hpi_initial_fit(info, adjust=adjust_dig, - verbose='error') + _check_chpi_param(chpi_locs, "chpi_locs") + _validate_type(info, Info, "info") + hpi_dig_head_rrs = _get_hpi_initial_fit(info, adjust=adjust_dig, verbose="error") n_coils = len(hpi_dig_head_rrs) - coil_dev_rrs = apply_trans(invert_transform(info['dev_head_t']), - hpi_dig_head_rrs) - dev_head_t = info['dev_head_t']['trans'] + coil_dev_rrs = apply_trans(invert_transform(info["dev_head_t"]), hpi_dig_head_rrs) + dev_head_t = info["dev_head_t"]["trans"] pos_0 = dev_head_t[:3, 3] - last = dict(quat_fit_time=-0.1, coil_dev_rrs=coil_dev_rrs, - quat=np.concatenate([rot_to_quat(dev_head_t[:3, :3]), - dev_head_t[:3, 3]])) + last = dict( + quat_fit_time=-0.1, + coil_dev_rrs=coil_dev_rrs, + quat=np.concatenate([rot_to_quat(dev_head_t[:3, :3]), dev_head_t[:3, 3]]), + ) del coil_dev_rrs quats = [] for fit_time, this_coil_dev_rrs, g_coils in zip( - *(chpi_locs[key] for key in ('times', 'rrs', 'gofs'))): + *(chpi_locs[key] for key in ("times", "rrs", "gofs")) + ): use_idx = np.where(g_coils >= gof_limit)[0] # # 1. Check number of good ones # if len(use_idx) < 3: - gofs = ', '.join(f"{g:0.2f}" for g in g_coils) + gofs = ", ".join(f"{g:0.2f}" for g in g_coils) warn( f"{_time_prefix(fit_time)}{len(use_idx)}/{n_coils} " "good HPI fits, cannot determine the transformation " @@ -861,7 +971,8 @@ def compute_head_pos(info, chpi_locs, dist_limit=0.005, gof_limit=0.98, # positions) iteratively using different sets of coils. # this_quat, g, use_idx = _fit_chpi_quat_subset( - this_coil_dev_rrs, hpi_dig_head_rrs, use_idx) + this_coil_dev_rrs, hpi_dig_head_rrs, use_idx + ) # # 3. Stop if < 3 good @@ -873,64 +984,87 @@ def compute_head_pos(info, chpi_locs, dist_limit=0.005, gof_limit=0.98, errs = np.linalg.norm(hpi_dig_head_rrs - est_coil_head_rrs, axis=1) n_good = ((g_coils >= gof_limit) & (errs < dist_limit)).sum() if n_good < 3: - warn(_time_prefix(fit_time) + '%s/%s good HPI fits, cannot ' - 'determine the transformation (%s mm/GOF)!' - % (n_good, n_coils, - ', '.join(f'{1000 * e:0.1f}::{g:0.2f}' - for e, g in zip(errs, g_coils)))) + warn( + _time_prefix(fit_time) + "%s/%s good HPI fits, cannot " + "determine the transformation (%s mm/GOF)!" + % ( + n_good, + n_coils, + ", ".join( + f"{1000 * e:0.1f}::{g:0.2f}" for e, g in zip(errs, g_coils) + ), + ) + ) continue # velocities, in device coords, of HPI coils - dt = fit_time - last['quat_fit_time'] - vs = tuple(1000. 
* np.linalg.norm(last['coil_dev_rrs'] - - this_coil_dev_rrs, axis=1) / dt) - logger.info(_time_prefix(fit_time) + - ('%s/%s good HPI fits, movements [mm/s] = ' + - ' / '.join(['% 8.1f'] * n_coils)) - % ((n_good, n_coils) + vs)) + dt = fit_time - last["quat_fit_time"] + vs = tuple( + 1000.0 + * np.linalg.norm(last["coil_dev_rrs"] - this_coil_dev_rrs, axis=1) + / dt + ) + logger.info( + _time_prefix(fit_time) + + ( + "%s/%s good HPI fits, movements [mm/s] = " + + " / ".join(["% 8.1f"] * n_coils) + ) + % ((n_good, n_coils) + vs) + ) # Log results # MaxFilter averages over a 200 ms window for display, but we don't for ii in range(n_coils): if ii in use_idx: - start, end = ' ', '/' + start, end = " ", "/" else: - start, end = '(', ')' - log_str = (' ' + start + - '{0:6.1f} {1:6.1f} {2:6.1f} / ' + - '{3:6.1f} {4:6.1f} {5:6.1f} / ' + - 'g = {6:0.3f} err = {7:4.1f} ' + - end) - vals = np.concatenate((1000 * hpi_dig_head_rrs[ii], - 1000 * est_coil_head_rrs[ii], - [g_coils[ii], 1000 * errs[ii]])) + start, end = "(", ")" + log_str = ( + " " + + start + + "{0:6.1f} {1:6.1f} {2:6.1f} / " + + "{3:6.1f} {4:6.1f} {5:6.1f} / " + + "g = {6:0.3f} err = {7:4.1f} " + + end + ) + vals = np.concatenate( + ( + 1000 * hpi_dig_head_rrs[ii], + 1000 * est_coil_head_rrs[ii], + [g_coils[ii], 1000 * errs[ii]], + ) + ) if len(use_idx) >= 3: if ii <= 2: - log_str += '{8:6.3f} {9:6.3f} {10:6.3f}' - vals = np.concatenate( - (vals, this_dev_head_t[ii, :3])) + log_str += "{8:6.3f} {9:6.3f} {10:6.3f}" + vals = np.concatenate((vals, this_dev_head_t[ii, :3])) elif ii == 3: - log_str += '{8:6.1f} {9:6.1f} {10:6.1f}' - vals = np.concatenate( - (vals, this_dev_head_t[:3, 3] * 1000.)) + log_str += "{8:6.1f} {9:6.1f} {10:6.1f}" + vals = np.concatenate((vals, this_dev_head_t[:3, 3] * 1000.0)) logger.debug(log_str.format(*vals)) # resulting errors in head coil positions - d = np.linalg.norm(last['quat'][3:] - this_quat[3:]) # m - r = _angle_between_quats(last['quat'][:3], this_quat[:3]) / dt + d = np.linalg.norm(last["quat"][3:] - this_quat[3:]) # m + r = _angle_between_quats(last["quat"][:3], this_quat[:3]) / dt v = d / dt # m/s d = 100 * np.linalg.norm(this_quat[3:] - pos_0) # dis from 1st - logger.debug(' #t = %0.3f, #e = %0.2f cm, #g = %0.3f, ' - '#v = %0.2f cm/s, #r = %0.2f rad/s, #d = %0.2f cm' - % (fit_time, 100 * errs.mean(), g, 100 * v, r, d)) - logger.debug(' #t = %0.3f, #q = %s ' - % (fit_time, ' '.join(map('{:8.5f}'.format, this_quat)))) - - quats.append(np.concatenate(([fit_time], this_quat, [g], - [errs[use_idx].mean()], [v]))) - last['quat_fit_time'] = fit_time - last['quat'] = this_quat - last['coil_dev_rrs'] = this_coil_dev_rrs + logger.debug( + " #t = %0.3f, #e = %0.2f cm, #g = %0.3f, " + "#v = %0.2f cm/s, #r = %0.2f rad/s, #d = %0.2f cm" + % (fit_time, 100 * errs.mean(), g, 100 * v, r, d) + ) + logger.debug( + " #t = %0.3f, #q = %s " + % (fit_time, " ".join(map("{:8.5f}".format, this_quat))) + ) + + quats.append( + np.concatenate(([fit_time], this_quat, [g], [errs[use_idx].mean()], [v])) + ) + last["quat_fit_time"] = fit_time + last["quat"] = this_quat + last["coil_dev_rrs"] = this_coil_dev_rrs quats = np.array(quats, np.float64) quats = np.zeros((0, 10)) if quats.size == 0 else quats return quats @@ -941,9 +1075,10 @@ def _fit_chpi_quat_subset(coil_dev_rrs, coil_head_rrs, use_idx): out_idx = use_idx.copy() if len(use_idx) > 3: # try dropping one (recursively) for di in range(len(use_idx)): - this_use_idx = list(use_idx[:di]) + list(use_idx[di + 1:]) + this_use_idx = list(use_idx[:di]) + list(use_idx[di + 1 :]) 
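_fit_chpi_quat_subset (this hunk) performs a recursive leave-one-out search over the usable coils, keeping whichever subset of at least three coils maximizes the goodness of fit. A sketch of that search with a hypothetical stand-in for the scoring (the real code scores subsets via _fit_chpi_quat):

def fit_score(idx):
    # hypothetical stand-in: coil 2 is "bad" and drags the fit down
    return 1.0 - 0.01 * len(idx) - (0.2 if 2 in idx else 0.0)

def best_subset(use_idx):
    g, out = fit_score(use_idx), list(use_idx)
    if len(use_idx) > 3:  # never drop below 3 coils
        for di in range(len(use_idx)):
            sub_g, sub_out = best_subset(use_idx[:di] + use_idx[di + 1:])
            if sub_g > g:
                g, out = sub_g, sub_out
    return g, out

print(best_subset([0, 1, 2, 3]))  # (0.97, [0, 1, 3]): coil 2 dropped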
this_quat, this_g, this_use_idx = _fit_chpi_quat_subset( - coil_dev_rrs, coil_head_rrs, this_use_idx) + coil_dev_rrs, coil_head_rrs, this_use_idx + ) if this_g > g: quat, g, out_idx = this_quat, this_g, this_use_idx return quat, g, np.array(out_idx, int) @@ -956,8 +1091,9 @@ def _unit_quat_constraint(x): @verbose -def compute_chpi_snr(raw, t_step_min=0.01, t_window='auto', ext_order=1, - tmin=0, tmax=None, verbose=None): +def compute_chpi_snr( + raw, t_step_min=0.01, t_window="auto", ext_order=1, tmin=0, tmax=None, verbose=None +): """Compute time-varying estimates of cHPI SNR. Parameters @@ -988,13 +1124,15 @@ def compute_chpi_snr(raw, t_step_min=0.01, t_window='auto', ext_order=1, ----- .. versionadded:: 0.24 """ - return _compute_chpi_amp_or_snr(raw, t_step_min, t_window, ext_order, - tmin, tmax, verbose, snr=True) + return _compute_chpi_amp_or_snr( + raw, t_step_min, t_window, ext_order, tmin, tmax, verbose, snr=True + ) @verbose -def compute_chpi_amplitudes(raw, t_step_min=0.01, t_window='auto', - ext_order=1, tmin=0, tmax=None, verbose=None): +def compute_chpi_amplitudes( + raw, t_step_min=0.01, t_window="auto", ext_order=1, tmin=0, tmax=None, verbose=None +): """Compute time-varying cHPI amplitudes. Parameters @@ -1040,13 +1178,21 @@ def compute_chpi_amplitudes(raw, t_step_min=0.01, t_window='auto', .. versionadded:: 0.20 """ - return _compute_chpi_amp_or_snr(raw, t_step_min, t_window, ext_order, - tmin, tmax, verbose) + return _compute_chpi_amp_or_snr( + raw, t_step_min, t_window, ext_order, tmin, tmax, verbose + ) -def _compute_chpi_amp_or_snr(raw, t_step_min=0.01, t_window='auto', - ext_order=1, tmin=0, tmax=None, verbose=None, - snr=False): +def _compute_chpi_amp_or_snr( + raw, + t_step_min=0.01, + t_window="auto", + ext_order=1, + tmin=0, + tmax=None, + verbose=None, + snr=False, +): """Compute cHPI amplitude or SNR. See compute_chpi_amplitudes for parameter descriptions. One additional @@ -1055,42 +1201,44 @@ def _compute_chpi_amp_or_snr(raw, t_step_min=0.01, t_window='auto', """ hpi = _setup_hpi_amplitude_fitting(raw.info, t_window, ext_order=ext_order) tmin, tmax = raw._tmin_tmax_to_start_stop(tmin, tmax) - tmin = tmin / raw.info['sfreq'] - tmax = tmax / raw.info['sfreq'] - need_win = hpi['t_window'] / 2. - fit_idxs = raw.time_as_index(np.arange( - tmin + need_win, tmax, t_step_min), use_rounding=True) - logger.info('Fitting %d HPI coil locations at up to %s time points ' - '(%0.1f s duration)' - % (len(hpi['freqs']), len(fit_idxs), tmax - tmin)) + tmin = tmin / raw.info["sfreq"] + tmax = tmax / raw.info["sfreq"] + need_win = hpi["t_window"] / 2.0 + fit_idxs = raw.time_as_index( + np.arange(tmin + need_win, tmax, t_step_min), use_rounding=True + ) + logger.info( + "Fitting %d HPI coil locations at up to %s time points " + "(%0.1f s duration)" % (len(hpi["freqs"]), len(fit_idxs), tmax - tmin) + ) del tmin, tmax sin_fits = dict() - sin_fits['proj'] = hpi['proj'] - sin_fits['times'] = np.round(fit_idxs + raw.first_samp - - hpi['n_window'] / 2.) 
/ raw.info['sfreq'] - n_times = len(sin_fits['times']) - n_freqs = len(hpi['freqs']) - n_chans = len(sin_fits['proj']['data']['col_names']) + sin_fits["proj"] = hpi["proj"] + sin_fits["times"] = ( + np.round(fit_idxs + raw.first_samp - hpi["n_window"] / 2.0) / raw.info["sfreq"] + ) + n_times = len(sin_fits["times"]) + n_freqs = len(hpi["freqs"]) + n_chans = len(sin_fits["proj"]["data"]["col_names"]) if snr: - del sin_fits['proj'] - sin_fits['freqs'] = hpi['freqs'] + del sin_fits["proj"] + sin_fits["freqs"] = hpi["freqs"] ch_types = raw.get_channel_types() - grad_offset = 3 if 'mag' in ch_types else 0 - for ch_type in ('mag', 'grad'): + grad_offset = 3 if "mag" in ch_types else 0 + for ch_type in ("mag", "grad"): if ch_type in ch_types: - for key in ('snr', 'power', 'resid'): - cols = 1 if key == 'resid' else n_freqs - sin_fits[f'{ch_type}_{key}'] = np.empty((n_times, cols)) + for key in ("snr", "power", "resid"): + cols = 1 if key == "resid" else n_freqs + sin_fits[f"{ch_type}_{key}"] = np.empty((n_times, cols)) else: - sin_fits['slopes'] = np.empty((n_times, n_freqs, n_chans)) + sin_fits["slopes"] = np.empty((n_times, n_freqs, n_chans)) message = f"cHPI {'SNRs' if snr else 'amplitudes'}" for mi, midpt in enumerate(ProgressBar(fit_idxs, mesg=message)): # # 0. determine samples to fit. # - time_sl = midpt - hpi['n_window'] // 2 - time_sl = slice(max(time_sl, 0), - min(time_sl + hpi['n_window'], len(raw.times))) + time_sl = midpt - hpi["n_window"] // 2 + time_sl = slice(max(time_sl, 0), min(time_sl + hpi["n_window"], len(raw.times))) # # 1. Fit amplitudes for each channel from each of the N sinusoids @@ -1103,22 +1251,28 @@ def _compute_chpi_amp_or_snr(raw, t_step_min=0.01, t_window='auto', # is returned as a (tiled) vector (again, because Numba) so that's # why below we take amps_or_snrs[0, 2] instead of [:, 2] ch_types = raw.get_channel_types() - if 'mag' in ch_types: - sin_fits['mag_snr'][mi] = amps_or_snrs[:, 0] # SNR - sin_fits['mag_power'][mi] = amps_or_snrs[:, 1] # mean power - sin_fits['mag_resid'][mi] = amps_or_snrs[0, 2] # mean resid - if 'grad' in ch_types: - sin_fits['grad_snr'][mi] = amps_or_snrs[:, grad_offset] - sin_fits['grad_power'][mi] = amps_or_snrs[:, grad_offset + 1] - sin_fits['grad_resid'][mi] = amps_or_snrs[0, grad_offset + 2] + if "mag" in ch_types: + sin_fits["mag_snr"][mi] = amps_or_snrs[:, 0] # SNR + sin_fits["mag_power"][mi] = amps_or_snrs[:, 1] # mean power + sin_fits["mag_resid"][mi] = amps_or_snrs[0, 2] # mean resid + if "grad" in ch_types: + sin_fits["grad_snr"][mi] = amps_or_snrs[:, grad_offset] + sin_fits["grad_power"][mi] = amps_or_snrs[:, grad_offset + 1] + sin_fits["grad_resid"][mi] = amps_or_snrs[0, grad_offset + 2] else: - sin_fits['slopes'][mi] = amps_or_snrs + sin_fits["slopes"][mi] = amps_or_snrs return sin_fits @verbose -def compute_chpi_locs(info, chpi_amplitudes, t_step_max=1., too_close='raise', - adjust_dig=False, verbose=None): +def compute_chpi_locs( + info, + chpi_amplitudes, + t_step_max=1.0, + too_close="raise", + adjust_dig=False, + verbose=None, +): """Compute locations of each cHPI coils over time. Parameters @@ -1163,19 +1317,18 @@ def compute_chpi_locs(info, chpi_amplitudes, t_step_max=1., too_close='raise', .. 
versionadded:: 0.20 """ # Set up magnetic dipole fits - _check_option('too_close', too_close, ['raise', 'warning', 'info']) - _check_chpi_param(chpi_amplitudes, 'chpi_amplitudes') - _validate_type(info, Info, 'info') + _check_option("too_close", too_close, ["raise", "warning", "info"]) + _check_chpi_param(chpi_amplitudes, "chpi_amplitudes") + _validate_type(info, Info, "info") sin_fits = chpi_amplitudes # use the old name below del chpi_amplitudes - proj = sin_fits['proj'] - meg_picks = pick_channels( - info['ch_names'], proj['data']['col_names'], ordered=True) + proj = sin_fits["proj"] + meg_picks = pick_channels(info["ch_names"], proj["data"]["col_names"], ordered=True) info = pick_info(info, meg_picks) # makes a copy with info._unlock(): - info['projs'] = [proj] + info["projs"] = [proj] del meg_picks, proj - meg_coils = _concatenate_coils(_create_meg_coils(info['chs'], 'accurate')) + meg_coils = _concatenate_coils(_create_meg_coils(info["chs"], "accurate")) # Set up external model for interference suppression safe_false = _verbose_safe_false() @@ -1184,10 +1337,13 @@ def compute_chpi_locs(info, chpi_amplitudes, t_step_max=1., too_close='raise', # Make some location guesses (1 cm grid) R = np.linalg.norm(meg_coils[0], axis=1).min() - guesses = _make_guesses(dict(R=R, r0=np.zeros(3)), 0.01, 0., 0.005, - verbose=safe_false)[0]['rr'] - logger.info('Computing %d HPI location guesses (1 cm grid in a %0.1f cm ' - 'sphere)' % (len(guesses), R * 100)) + guesses = _make_guesses( + dict(R=R, r0=np.zeros(3)), 0.01, 0.0, 0.005, verbose=safe_false + )[0]["rr"] + logger.info( + "Computing %d HPI location guesses (1 cm grid in a %0.1f cm " + "sphere)" % (len(guesses), R * 100) + ) fwd = _magnetic_dipole_field_vec(guesses, meg_coils, too_close) fwd = np.dot(fwd, whitener.T) fwd.shape = (guesses.shape[0], 3, -1) @@ -1195,51 +1351,58 @@ def compute_chpi_locs(info, chpi_amplitudes, t_step_max=1., too_close='raise', guesses = dict(rr=guesses, whitened_fwd_svd=fwd) del fwd, R - iter_ = list(zip(sin_fits['times'], sin_fits['slopes'])) + iter_ = list(zip(sin_fits["times"], sin_fits["slopes"])) chpi_locs = dict(times=[], rrs=[], gofs=[], moments=[]) # setup last iteration structure hpi_dig_dev_rrs = apply_trans( - invert_transform(info['dev_head_t'])['trans'], - _get_hpi_initial_fit(info, adjust=adjust_dig)) - last = dict(sin_fit=None, coil_fit_time=sin_fits['times'][0] - 1, - coil_dev_rrs=hpi_dig_dev_rrs) + invert_transform(info["dev_head_t"])["trans"], + _get_hpi_initial_fit(info, adjust=adjust_dig), + ) + last = dict( + sin_fit=None, + coil_fit_time=sin_fits["times"][0] - 1, + coil_dev_rrs=hpi_dig_dev_rrs, + ) n_hpi = len(hpi_dig_dev_rrs) del hpi_dig_dev_rrs - for fit_time, sin_fit in ProgressBar(iter_, mesg='cHPI locations '): + for fit_time, sin_fit in ProgressBar(iter_, mesg="cHPI locations "): # skip this window if bad if not np.isfinite(sin_fit).all(): continue # check if data has sufficiently changed - if last['sin_fit'] is not None: # first iteration + if last["sin_fit"] is not None: # first iteration corrs = np.array( - [np.corrcoef(s, lst)[0, 1] - for s, lst in zip(sin_fit, last['sin_fit'])]) + [np.corrcoef(s, lst)[0, 1] for s, lst in zip(sin_fit, last["sin_fit"])] + ) corrs *= corrs # check to see if we need to continue - if fit_time - last['coil_fit_time'] <= t_step_max - 1e-7 and \ - (corrs > 0.98).sum() >= 3: + if ( + fit_time - last["coil_fit_time"] <= t_step_max - 1e-7 + and (corrs > 0.98).sum() >= 3 + ): # don't need to refit data continue # update 'last' sin_fit *before* inplace sign mult - 
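The refit-skipping rule in the compute_chpi_locs loop above is worth seeing on its own: while t_step_max has not elapsed, refitting is skipped whenever the squared correlation with the previous sinusoid fit stays high for at least three coils. A sketch with synthetic amplitudes:

import numpy as np

rng = np.random.default_rng(2)
last_fit = rng.normal(size=(5, 100))  # previous per-coil amplitude fits
sin_fit = last_fit + rng.normal(scale=0.01, size=(5, 100))  # barely changed

corrs = np.array(
    [np.corrcoef(s, lst)[0, 1] for s, lst in zip(sin_fit, last_fit)]
)
corrs *= corrs
print((corrs > 0.98).sum() >= 3)  # True: keep the previous coil positions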
last['sin_fit'] = sin_fit.copy() + last["sin_fit"] = sin_fit.copy() # # 2. Fit magnetic dipole for each coil to obtain coil positions # in device coordinates # - coil_fits = [_fit_magnetic_dipole(f, x0, too_close, whitener, - meg_coils, guesses) - for f, x0 in zip(sin_fit, last['coil_dev_rrs'])] + coil_fits = [ + _fit_magnetic_dipole(f, x0, too_close, whitener, meg_coils, guesses) + for f, x0 in zip(sin_fit, last["coil_dev_rrs"]) + ] rrs, gofs, moments = zip(*coil_fits) - chpi_locs['times'].append(fit_time) - chpi_locs['rrs'].append(rrs) - chpi_locs['gofs'].append(gofs) - chpi_locs['moments'].append(moments) - last['coil_fit_time'] = fit_time - last['coil_dev_rrs'] = rrs - n_times = len(chpi_locs['times']) + chpi_locs["times"].append(fit_time) + chpi_locs["rrs"].append(rrs) + chpi_locs["gofs"].append(gofs) + chpi_locs["moments"].append(moments) + last["coil_fit_time"] = fit_time + last["coil_dev_rrs"] = rrs + n_times = len(chpi_locs["times"]) shapes = dict( times=(n_times,), rrs=(n_times, n_hpi, 3), @@ -1254,17 +1417,32 @@ def compute_chpi_locs(info, chpi_amplitudes, t_step_max=1., too_close='raise', def _chpi_locs_to_times_dig(chpi_locs): """Reformat chpi_locs as list of dig (dict).""" dig = list() - for rrs, gofs in zip(*(chpi_locs[key] for key in ('rrs', 'gofs'))): - dig.append([{'r': rr, 'ident': idx, 'gof': gof, - 'kind': FIFF.FIFFV_POINT_HPI, - 'coord_frame': FIFF.FIFFV_COORD_DEVICE} - for idx, (rr, gof) in enumerate(zip(rrs, gofs), 1)]) - return chpi_locs['times'], dig + for rrs, gofs in zip(*(chpi_locs[key] for key in ("rrs", "gofs"))): + dig.append( + [ + { + "r": rr, + "ident": idx, + "gof": gof, + "kind": FIFF.FIFFV_POINT_HPI, + "coord_frame": FIFF.FIFFV_COORD_DEVICE, + } + for idx, (rr, gof) in enumerate(zip(rrs, gofs), 1) + ] + ) + return chpi_locs["times"], dig @verbose -def filter_chpi(raw, include_line=True, t_step=0.01, t_window='auto', - ext_order=1, allow_line_only=False, verbose=None): +def filter_chpi( + raw, + include_line=True, + t_step=0.01, + t_window="auto", + ext_order=1, + allow_line_only=False, + verbose=None, +): """Remove cHPI and line noise from data. .. note:: This function will only work properly if cHPI was on @@ -1301,73 +1479,80 @@ def filter_chpi(raw, include_line=True, t_step=0.01, t_window='auto', .. 
versionadded:: 0.12 """ - _validate_type(raw, BaseRaw, 'raw') + _validate_type(raw, BaseRaw, "raw") if not raw.preload: - raise RuntimeError('raw data must be preloaded') + raise RuntimeError("raw data must be preloaded") t_step = float(t_step) if t_step <= 0: - raise ValueError('t_step (%s) must be > 0' % (t_step,)) - n_step = int(np.ceil(t_step * raw.info['sfreq'])) - if include_line and raw.info['line_freq'] is None: - raise RuntimeError('include_line=True but raw.info["line_freq"] is ' - 'None, consider setting it to the line frequency') + raise ValueError("t_step (%s) must be > 0" % (t_step,)) + n_step = int(np.ceil(t_step * raw.info["sfreq"])) + if include_line and raw.info["line_freq"] is None: + raise RuntimeError( + 'include_line=True but raw.info["line_freq"] is ' + "None, consider setting it to the line frequency" + ) hpi = _setup_hpi_amplitude_fitting( - raw.info, t_window, remove_aliased=True, ext_order=ext_order, - allow_empty=allow_line_only, verbose=_verbose_safe_false()) + raw.info, + t_window, + remove_aliased=True, + ext_order=ext_order, + allow_empty=allow_line_only, + verbose=_verbose_safe_false(), + ) - fit_idxs = np.arange(0, len(raw.times) + hpi['n_window'] // 2, n_step) - n_freqs = len(hpi['freqs']) + fit_idxs = np.arange(0, len(raw.times) + hpi["n_window"] // 2, n_step) + n_freqs = len(hpi["freqs"]) n_remove = 2 * n_freqs meg_picks = pick_types(raw.info, meg=True, exclude=()) # filter all chs n_times = len(raw.times) - msg = 'Removing %s cHPI' % n_freqs + msg = "Removing %s cHPI" % n_freqs if include_line: - n_remove += 2 * len(hpi['line_freqs']) - msg += ' and %s line harmonic' % len(hpi['line_freqs']) - msg += ' frequencies from %s MEG channels' % len(meg_picks) + n_remove += 2 * len(hpi["line_freqs"]) + msg += " and %s line harmonic" % len(hpi["line_freqs"]) + msg += " frequencies from %s MEG channels" % len(meg_picks) - recon = np.dot(hpi['model'][:, :n_remove], hpi['inv_model'][:n_remove]).T + recon = np.dot(hpi["model"][:, :n_remove], hpi["inv_model"][:n_remove]).T logger.info(msg) chunks = list() # the chunks to subtract last_endpt = 0 - pb = ProgressBar(fit_idxs, mesg='Filtering') + pb = ProgressBar(fit_idxs, mesg="Filtering") for ii, midpt in enumerate(pb): - left_edge = midpt - hpi['n_window'] // 2 - time_sl = slice(max(left_edge, 0), - min(left_edge + hpi['n_window'], len(raw.times))) + left_edge = midpt - hpi["n_window"] // 2 + time_sl = slice( + max(left_edge, 0), min(left_edge + hpi["n_window"], len(raw.times)) + ) this_len = time_sl.stop - time_sl.start - if this_len == hpi['n_window']: + if this_len == hpi["n_window"]: this_recon = recon else: # first or last window - model = hpi['model'][:this_len] + model = hpi["model"][:this_len] inv_model = np.linalg.pinv(model) this_recon = np.dot(model[:, :n_remove], inv_model[:n_remove]).T this_data = raw._data[meg_picks, time_sl] subt_pt = min(midpt + n_step, n_times) if last_endpt != subt_pt: - fit_left_edge = left_edge - time_sl.start + hpi['n_window'] // 2 - fit_sl = slice(fit_left_edge, - fit_left_edge + (subt_pt - last_endpt)) + fit_left_edge = left_edge - time_sl.start + hpi["n_window"] // 2 + fit_sl = slice(fit_left_edge, fit_left_edge + (subt_pt - last_endpt)) chunks.append((subt_pt, np.dot(this_data, this_recon[:, fit_sl]))) last_endpt = subt_pt # Consume (trailing) chunks that are now safe to remove because # our windows will no longer touch them if ii < len(fit_idxs) - 1: - next_left_edge = fit_idxs[ii + 1] - hpi['n_window'] // 2 + next_left_edge = fit_idxs[ii + 1] - hpi["n_window"] // 2 else: 
next_left_edge = np.inf while len(chunks) > 0 and chunks[0][0] <= next_left_edge: right_edge, chunk = chunks.pop(0) - raw._data[meg_picks, - right_edge - chunk.shape[1]:right_edge] -= chunk + raw._data[meg_picks, right_edge - chunk.shape[1] : right_edge] -= chunk return raw def _compute_good_distances(hpi_coil_dists, new_pos, dist_limit=0.005): """Compute good coils based on distances.""" from scipy.spatial.distance import cdist + these_dists = cdist(new_pos, new_pos) these_dists = np.abs(hpi_coil_dists - these_dists) # there is probably a better algorithm for finding the bad ones... @@ -1375,7 +1560,7 @@ def _compute_good_distances(hpi_coil_dists, new_pos, dist_limit=0.005): use_mask = np.ones(len(hpi_coil_dists), bool) while not good: d = these_dists[use_mask][:, use_mask] - d_bad = (d > dist_limit) + d_bad = d > dist_limit good = not d_bad.any() if not good: if use_mask.sum() == 2: @@ -1389,7 +1574,7 @@ def _compute_good_distances(hpi_coil_dists, new_pos, dist_limit=0.005): @verbose -def get_active_chpi(raw, *, on_missing='raise', verbose=None): +def get_active_chpi(raw, *, on_missing="raise", verbose=None): """Determine how many HPI coils were active for a time point. Parameters @@ -1412,10 +1597,14 @@ def get_active_chpi(raw, *, on_missing='raise', verbose=None): system, _ = _get_meg_system(raw.info) # check whether we have a neuromag system - if system not in ['122m', '306m']: - raise NotImplementedError(('Identifying active HPI channels' - ' is not implemented for other systems' - ' than neuromag.')) + if system not in ["122m", "306m"]: + raise NotImplementedError( + ( + "Identifying active HPI channels" + " is not implemented for other systems" + " than neuromag." + ) + ) # extract hpi info chpi_info = get_chpi_info(raw.info, on_missing=on_missing) if len(chpi_info[2]) == 0: diff --git a/mne/commands/mne_anonymize.py b/mne/commands/mne_anonymize.py index 7c858319265..d4b54000b78 100644 --- a/mne/commands/mne_anonymize.py +++ b/mne/commands/mne_anonymize.py @@ -19,7 +19,7 @@ import mne import os.path as op -ANONYMIZE_FILE_PREFIX = 'anon' +ANONYMIZE_FILE_PREFIX = "anon" def mne_anonymize(fif_fname, out_fname, keep_his, daysback, overwrite): @@ -49,8 +49,7 @@ def mne_anonymize(fif_fname, out_fname, keep_his, daysback, overwrite): dir_name = op.split(fif_fname)[0] if out_fname is None: fif_bname = op.basename(fif_fname) - out_fname = op.join(dir_name, - "{}-{}".format(ANONYMIZE_FILE_PREFIX, fif_bname)) + out_fname = op.join(dir_name, "{}-{}".format(ANONYMIZE_FILE_PREFIX, fif_bname)) elif not op.isabs(out_fname): out_fname = op.join(dir_name, out_fname) @@ -63,20 +62,48 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-f", "--file", type="string", dest="file", - help="Name of file to modify.", metavar="FILE", - default=None) - parser.add_option("-o", "--output", type="string", dest="output", - help="Name of anonymized output file." 
- "`anon-` prefix is added to FILE if not given", - metavar="OUTFILE", default=None) - parser.add_option("--keep_his", dest="keep_his", action="store_true", - help="Keep the HIS tag (not advised)", default=False) - parser.add_option("-d", "--daysback", type="int", dest="daysback", - help="Move dates in file backwards by this many days.", - metavar="N_DAYS", default=None) - parser.add_option("--overwrite", dest="overwrite", action="store_true", - help="Overwrite input file.", default=False) + parser.add_option( + "-f", + "--file", + type="string", + dest="file", + help="Name of file to modify.", + metavar="FILE", + default=None, + ) + parser.add_option( + "-o", + "--output", + type="string", + dest="output", + help="Name of anonymized output file." + "`anon-` prefix is added to FILE if not given", + metavar="OUTFILE", + default=None, + ) + parser.add_option( + "--keep_his", + dest="keep_his", + action="store_true", + help="Keep the HIS tag (not advised)", + default=False, + ) + parser.add_option( + "-d", + "--daysback", + type="int", + dest="daysback", + help="Move dates in file backwards by this many days.", + metavar="N_DAYS", + default=None, + ) + parser.add_option( + "--overwrite", + dest="overwrite", + action="store_true", + help="Overwrite input file.", + default=False, + ) options, args = parser.parse_args() if options.file is None: @@ -88,12 +115,12 @@ def run(): keep_his = options.keep_his daysback = options.daysback overwrite = options.overwrite - if not fname.endswith('.fif'): - raise ValueError('%s does not seem to be a .fif file.' % fname) + if not fname.endswith(".fif"): + raise ValueError("%s does not seem to be a .fif file." % fname) mne_anonymize(fname, out_fname, keep_his, daysback, overwrite) -is_main = (__name__ == '__main__') +is_main = __name__ == "__main__" if is_main: run() diff --git a/mne/commands/mne_browse_raw.py b/mne/commands/mne_browse_raw.py index 95b4381cc7b..9c338518e85 100644 --- a/mne/commands/mne_browse_raw.py +++ b/mne/commands/mne_browse_raw.py @@ -24,57 +24,114 @@ def run(): from mne.commands.utils import get_optparser, _add_verbose_flag from mne.viz import _RAW_CLIP_DEF - parser = get_optparser(__file__, usage='usage: %prog raw [options]') - - parser.add_option("--raw", dest="raw_in", - help="Input raw FIF file (can also be specified " - "directly as an argument without the --raw prefix)", - metavar="FILE") - parser.add_option("--proj", dest="proj_in", - help="Projector file", metavar="FILE", - default='') - parser.add_option("--projoff", dest="proj_off", - help="Disable all projectors", - default=False, action="store_true") - parser.add_option("--eve", dest="eve_in", - help="Events file", metavar="FILE", - default='') - parser.add_option("-d", "--duration", dest="duration", type="float", - help="Time window for plotting (s)", - default=10.0) - parser.add_option("-t", "--start", dest="start", type="float", - help="Initial start time for plotting", - default=0.0) - parser.add_option("-n", "--n_channels", dest="n_channels", type="int", - help="Number of channels to plot at a time", - default=20) - parser.add_option("-o", "--order", dest="group_by", - help="Order to use for grouping during plotting " - "('type' or 'original')", default='type') - parser.add_option("-p", "--preload", dest="preload", - help="Preload raw data (for faster navigaton)", - default=False, action="store_true") - parser.add_option("-s", "--show_options", dest="show_options", - help="Show projection options dialog", - default=False) - parser.add_option("--allowmaxshield", 
dest="maxshield", - help="Allow loading MaxShield processed data", - action="store_true") - parser.add_option("--highpass", dest="highpass", type="float", - help="Display high-pass filter corner frequency", - default=-1) - parser.add_option("--lowpass", dest="lowpass", type="float", - help="Display low-pass filter corner frequency", - default=-1) - parser.add_option("--filtorder", dest="filtorder", type="int", - help="Display filtering IIR order (or 0 to use FIR)", - default=4) - parser.add_option("--clipping", dest="clipping", - help="Enable trace clipping mode, either 'clamp' or " - "'transparent'", default=_RAW_CLIP_DEF) - parser.add_option("--filterchpi", dest="filterchpi", - help="Enable filtering cHPI signals.", default=None, - action="store_true") + parser = get_optparser(__file__, usage="usage: %prog raw [options]") + + parser.add_option( + "--raw", + dest="raw_in", + help="Input raw FIF file (can also be specified " + "directly as an argument without the --raw prefix)", + metavar="FILE", + ) + parser.add_option( + "--proj", dest="proj_in", help="Projector file", metavar="FILE", default="" + ) + parser.add_option( + "--projoff", + dest="proj_off", + help="Disable all projectors", + default=False, + action="store_true", + ) + parser.add_option( + "--eve", dest="eve_in", help="Events file", metavar="FILE", default="" + ) + parser.add_option( + "-d", + "--duration", + dest="duration", + type="float", + help="Time window for plotting (s)", + default=10.0, + ) + parser.add_option( + "-t", + "--start", + dest="start", + type="float", + help="Initial start time for plotting", + default=0.0, + ) + parser.add_option( + "-n", + "--n_channels", + dest="n_channels", + type="int", + help="Number of channels to plot at a time", + default=20, + ) + parser.add_option( + "-o", + "--order", + dest="group_by", + help="Order to use for grouping during plotting " "('type' or 'original')", + default="type", + ) + parser.add_option( + "-p", + "--preload", + dest="preload", + help="Preload raw data (for faster navigaton)", + default=False, + action="store_true", + ) + parser.add_option( + "-s", + "--show_options", + dest="show_options", + help="Show projection options dialog", + default=False, + ) + parser.add_option( + "--allowmaxshield", + dest="maxshield", + help="Allow loading MaxShield processed data", + action="store_true", + ) + parser.add_option( + "--highpass", + dest="highpass", + type="float", + help="Display high-pass filter corner frequency", + default=-1, + ) + parser.add_option( + "--lowpass", + dest="lowpass", + type="float", + help="Display low-pass filter corner frequency", + default=-1, + ) + parser.add_option( + "--filtorder", + dest="filtorder", + type="int", + help="Display filtering IIR order (or 0 to use FIR)", + default=4, + ) + parser.add_option( + "--clipping", + dest="clipping", + help="Enable trace clipping mode, either 'clamp' or " "'transparent'", + default=_RAW_CLIP_DEF, + ) + parser.add_option( + "--filterchpi", + dest="filterchpi", + help="Enable filtering cHPI signals.", + default=None, + action="store_true", + ) _add_verbose_flag(parser) options, args = parser.parse_args() @@ -97,7 +154,7 @@ def run(): filtorder = options.filtorder clipping = options.clipping if isinstance(clipping, str): - if clipping.lower() == 'none': + if clipping.lower() == "none": clipping = None else: try: @@ -113,11 +170,11 @@ def run(): kwargs = dict(preload=preload) if maxshield: - kwargs.update(allow_maxshield='yes') + kwargs.update(allow_maxshield="yes") raw = mne.io.read_raw(raw_in, 
**kwargs) if len(proj_in) > 0: projs = mne.read_proj(proj_in) - raw.info['projs'] = projs + raw.info["projs"] = projs if len(eve_in) > 0: events = mne.read_events(eve_in) else: @@ -125,17 +182,27 @@ def run(): if filterchpi: if not preload: - raise RuntimeError( - 'Raw data must be preloaded for chpi, use --preload') + raise RuntimeError("Raw data must be preloaded for chpi, use --preload") raw = mne.chpi.filter_chpi(raw) highpass = None if highpass < 0 or filtorder < 0 else highpass lowpass = None if lowpass < 0 or filtorder < 0 else lowpass - raw.plot(duration=duration, start=start, n_channels=n_channels, - group_by=group_by, show_options=show_options, events=events, - highpass=highpass, lowpass=lowpass, filtorder=filtorder, - clipping=clipping, proj=not proj_off, verbose=verbose, - show=True, block=True) + raw.plot( + duration=duration, + start=start, + n_channels=n_channels, + group_by=group_by, + show_options=show_options, + events=events, + highpass=highpass, + lowpass=lowpass, + filtorder=filtorder, + clipping=clipping, + proj=not proj_off, + verbose=verbose, + show=True, + block=True, + ) mne.utils.run_command_if_main() diff --git a/mne/commands/mne_bti2fiff.py b/mne/commands/mne_bti2fiff.py index db3c37fcd8c..88510626822 100644 --- a/mne/commands/mne_bti2fiff.py +++ b/mne/commands/mne_bti2fiff.py @@ -41,29 +41,57 @@ def run(): parser = get_optparser(__file__) - parser.add_option('-p', '--pdf', dest='pdf_fname', - help='Input data file name', metavar='FILE') - parser.add_option('-c', '--config', dest='config_fname', - help='Input config file name', metavar='FILE', - default='config') - parser.add_option('--head_shape', dest='head_shape_fname', - help='Headshape file name', metavar='FILE', - default='hs_file') - parser.add_option('-o', '--out_fname', dest='out_fname', - help='Name of the resulting fiff file', - default='as_data_fname') - parser.add_option('-r', '--rotation_x', dest='rotation_x', type='float', - help='Compensatory rotation about Neuromag x axis, deg', - default=2.0) - parser.add_option('-T', '--translation', dest='translation', type='str', - help='Default translation, meter', - default=(0.00, 0.02, 0.11)) - parser.add_option('--ecg_ch', dest='ecg_ch', type='str', - help='4D ECG channel name', - default='E31') - parser.add_option('--eog_ch', dest='eog_ch', type='str', - help='4D EOG channel names', - default='E63,E64') + parser.add_option( + "-p", "--pdf", dest="pdf_fname", help="Input data file name", metavar="FILE" + ) + parser.add_option( + "-c", + "--config", + dest="config_fname", + help="Input config file name", + metavar="FILE", + default="config", + ) + parser.add_option( + "--head_shape", + dest="head_shape_fname", + help="Headshape file name", + metavar="FILE", + default="hs_file", + ) + parser.add_option( + "-o", + "--out_fname", + dest="out_fname", + help="Name of the resulting fiff file", + default="as_data_fname", + ) + parser.add_option( + "-r", + "--rotation_x", + dest="rotation_x", + type="float", + help="Compensatory rotation about Neuromag x axis, deg", + default=2.0, + ) + parser.add_option( + "-T", + "--translation", + dest="translation", + type="str", + help="Default translation, meter", + default=(0.00, 0.02, 0.11), + ) + parser.add_option( + "--ecg_ch", dest="ecg_ch", type="str", help="4D ECG channel name", default="E31" + ) + parser.add_option( + "--eog_ch", + dest="eog_ch", + type="str", + help="4D EOG channel names", + default="E63,E64", + ) options, args = parser.parse_args() @@ -78,15 +106,20 @@ def run(): rotation_x = options.rotation_x 
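All of these command modules follow the same optparse pattern, and after the reflow each add_option argument sits on its own line. A self-contained sketch of that pattern, borrowing one option from mne_browse_raw above:

from optparse import OptionParser

parser = OptionParser(usage="usage: %prog raw [options]")
parser.add_option(
    "-d",
    "--duration",
    dest="duration",
    type="float",
    help="Time window for plotting (s)",
    default=10.0,
)
options, args = parser.parse_args(["--duration", "20"])
print(options.duration)  # 20.0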
translation = options.translation ecg_ch = options.ecg_ch - eog_ch = options.ecg_ch.split(',') - - if out_fname == 'as_data_fname': - out_fname = pdf_fname + '_raw.fif' - - raw = read_raw_bti(pdf_fname=pdf_fname, config_fname=config_fname, - head_shape_fname=head_shape_fname, - rotation_x=rotation_x, translation=translation, - ecg_ch=ecg_ch, eog_ch=eog_ch) + eog_ch = options.ecg_ch.split(",") + + if out_fname == "as_data_fname": + out_fname = pdf_fname + "_raw.fif" + + raw = read_raw_bti( + pdf_fname=pdf_fname, + config_fname=config_fname, + head_shape_fname=head_shape_fname, + rotation_x=rotation_x, + translation=translation, + ecg_ch=ecg_ch, + eog_ch=eog_ch, + ) raw.save(out_fname) raw.close() diff --git a/mne/commands/mne_clean_eog_ecg.py b/mne/commands/mne_clean_eog_ecg.py index f722a9fea52..b1ffaa74edd 100644 --- a/mne/commands/mne_clean_eog_ecg.py +++ b/mne/commands/mne_clean_eog_ecg.py @@ -18,10 +18,18 @@ import mne -def clean_ecg_eog(in_fif_fname, out_fif_fname=None, eog=True, ecg=True, - ecg_proj_fname=None, eog_proj_fname=None, - ecg_event_fname=None, eog_event_fname=None, in_path='.', - quiet=False): +def clean_ecg_eog( + in_fif_fname, + out_fif_fname=None, + eog=True, + ecg=True, + ecg_proj_fname=None, + eog_proj_fname=None, + ecg_event_fname=None, + eog_event_fname=None, + in_path=".", + quiet=False, +): """Clean ECG from raw fif file. Parameters @@ -45,65 +53,124 @@ def clean_ecg_eog(in_fif_fname, out_fif_fname=None, eog=True, ecg=True, # Reading fif File raw_in = mne.io.read_raw_fif(in_fif_fname) - if in_fif_fname.endswith('_raw.fif') or in_fif_fname.endswith('-raw.fif'): + if in_fif_fname.endswith("_raw.fif") or in_fif_fname.endswith("-raw.fif"): prefix = in_fif_fname[:-8] else: prefix = in_fif_fname[:-4] if out_fif_fname is None: - out_fif_fname = prefix + '_clean_ecg_eog_raw.fif' + out_fif_fname = prefix + "_clean_ecg_eog_raw.fif" if ecg_proj_fname is None: - ecg_proj_fname = prefix + '_ecg-proj.fif' + ecg_proj_fname = prefix + "_ecg-proj.fif" if eog_proj_fname is None: - eog_proj_fname = prefix + '_eog-proj.fif' + eog_proj_fname = prefix + "_eog-proj.fif" if ecg_event_fname is None: - ecg_event_fname = prefix + '_ecg-eve.fif' + ecg_event_fname = prefix + "_ecg-eve.fif" if eog_event_fname is None: - eog_event_fname = prefix + '_eog-eve.fif' + eog_event_fname = prefix + "_eog-eve.fif" - print('Implementing ECG and EOG artifact rejection on data') + print("Implementing ECG and EOG artifact rejection on data") kwargs = dict() if quiet else dict(stdout=None, stderr=None) if ecg: ecg_events, _, _ = mne.preprocessing.find_ecg_events( - raw_in, reject_by_annotation=True) + raw_in, reject_by_annotation=True + ) print("Writing ECG events in %s" % ecg_event_fname) mne.write_events(ecg_event_fname, ecg_events) - print('Computing ECG projector') - command = ('mne_process_raw', '--cd', in_path, '--raw', in_fif_fname, - '--events', ecg_event_fname, '--makeproj', - '--projtmin', '-0.08', '--projtmax', '0.08', - '--saveprojtag', '_ecg-proj', '--projnmag', '2', - '--projngrad', '1', '--projevent', '999', '--highpass', '5', - '--lowpass', '35', '--projmagrej', '4000', - '--projgradrej', '3000') + print("Computing ECG projector") + command = ( + "mne_process_raw", + "--cd", + in_path, + "--raw", + in_fif_fname, + "--events", + ecg_event_fname, + "--makeproj", + "--projtmin", + "-0.08", + "--projtmax", + "0.08", + "--saveprojtag", + "_ecg-proj", + "--projnmag", + "2", + "--projngrad", + "1", + "--projevent", + "999", + "--highpass", + "5", + "--lowpass", + "35", + "--projmagrej", + "4000", + 
"--projgradrej", + "3000", + ) mne.utils.run_subprocess(command, **kwargs) if eog: eog_events = mne.preprocessing.find_eog_events(raw_in) print("Writing EOG events in %s" % eog_event_fname) mne.write_events(eog_event_fname, eog_events) - print('Computing EOG projector') - command = ('mne_process_raw', '--cd', in_path, '--raw', in_fif_fname, - '--events', eog_event_fname, '--makeproj', - '--projtmin', '-0.15', '--projtmax', '0.15', - '--saveprojtag', '_eog-proj', '--projnmag', '2', - '--projngrad', '2', '--projevent', '998', '--lowpass', '35', - '--projmagrej', '4000', '--projgradrej', '3000') + print("Computing EOG projector") + command = ( + "mne_process_raw", + "--cd", + in_path, + "--raw", + in_fif_fname, + "--events", + eog_event_fname, + "--makeproj", + "--projtmin", + "-0.15", + "--projtmax", + "0.15", + "--saveprojtag", + "_eog-proj", + "--projnmag", + "2", + "--projngrad", + "2", + "--projevent", + "998", + "--lowpass", + "35", + "--projmagrej", + "4000", + "--projgradrej", + "3000", + ) mne.utils.run_subprocess(command, **kwargs) if out_fif_fname is not None: # Applying the ECG EOG projector - print('Applying ECG EOG projector') - command = ('mne_process_raw', '--cd', in_path, '--raw', in_fif_fname, - '--proj', in_fif_fname, '--projoff', '--save', - out_fif_fname, '--filteroff', - '--proj', ecg_proj_fname, '--proj', eog_proj_fname) + print("Applying ECG EOG projector") + command = ( + "mne_process_raw", + "--cd", + in_path, + "--raw", + in_fif_fname, + "--proj", + in_fif_fname, + "--projoff", + "--save", + out_fif_fname, + "--filteroff", + "--proj", + ecg_proj_fname, + "--proj", + eog_proj_fname, + ) mne.utils.run_subprocess(command, **kwargs) - print('Done removing artifacts.') + print("Done removing artifacts.") print("Cleaned raw data saved in: %s" % out_fif_fname) - print('IMPORTANT : Please eye-ball the data !!') + print("IMPORTANT : Please eye-ball the data !!") else: - print('Projection not applied to raw data.') + print("Projection not applied to raw data.") def run(): @@ -112,17 +179,41 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-i", "--in", dest="raw_in", - help="Input raw FIF file", metavar="FILE") - parser.add_option("-o", "--out", dest="raw_out", - help="Output raw FIF file", metavar="FILE", - default=None) - parser.add_option("-e", "--no-eog", dest="eog", action="store_false", - help="Remove EOG", default=True) - parser.add_option("-c", "--no-ecg", dest="ecg", action="store_false", - help="Remove ECG", default=True) - parser.add_option("-q", "--quiet", dest="quiet", action="store_true", - help="Suppress mne_process_raw output", default=False) + parser.add_option( + "-i", "--in", dest="raw_in", help="Input raw FIF file", metavar="FILE" + ) + parser.add_option( + "-o", + "--out", + dest="raw_out", + help="Output raw FIF file", + metavar="FILE", + default=None, + ) + parser.add_option( + "-e", + "--no-eog", + dest="eog", + action="store_false", + help="Remove EOG", + default=True, + ) + parser.add_option( + "-c", + "--no-ecg", + dest="ecg", + action="store_false", + help="Remove ECG", + default=True, + ) + parser.add_option( + "-q", + "--quiet", + dest="quiet", + action="store_true", + help="Suppress mne_process_raw output", + default=False, + ) options, args = parser.parse_args() diff --git a/mne/commands/mne_compare_fiff.py b/mne/commands/mne_compare_fiff.py index b616a3e4072..fe05d636592 100644 --- a/mne/commands/mne_compare_fiff.py +++ b/mne/commands/mne_compare_fiff.py @@ -18,7 +18,8 @@ def run(): """Run command.""" parser = 
mne.commands.utils.get_optparser(
-        __file__, usage='mne compare_fiff <file_a> <file_b>')
+        __file__, usage="mne compare_fiff <file_a> <file_b>"
+    )
     options, args = parser.parse_args()
     if len(args) != 2:
         parser.print_help()
diff --git a/mne/commands/mne_compute_proj_ecg.py b/mne/commands/mne_compute_proj_ecg.py
index c42798be3be..bb366f9d3e2 100644
--- a/mne/commands/mne_compute_proj_ecg.py
+++ b/mne/commands/mne_compute_proj_ecg.py
@@ -24,97 +24,191 @@ def run():
 
     parser = get_optparser(__file__)
 
-    parser.add_option("-i", "--in", dest="raw_in",
-                      help="Input raw FIF file", metavar="FILE")
-    parser.add_option("--tmin", dest="tmin", type="float",
-                      help="Time before event in seconds",
-                      default=-0.2)
-    parser.add_option("--tmax", dest="tmax", type="float",
-                      help="Time after event in seconds",
-                      default=0.4)
-    parser.add_option("-g", "--n-grad", dest="n_grad", type="int",
-                      help="Number of SSP vectors for gradiometers",
-                      default=2)
-    parser.add_option("-m", "--n-mag", dest="n_mag", type="int",
-                      help="Number of SSP vectors for magnetometers",
-                      default=2)
-    parser.add_option("-e", "--n-eeg", dest="n_eeg", type="int",
-                      help="Number of SSP vectors for EEG",
-                      default=2)
-    parser.add_option("--l-freq", dest="l_freq", type="float",
-                      help="Filter low cut-off frequency in Hz",
-                      default=1)
-    parser.add_option("--h-freq", dest="h_freq", type="float",
-                      help="Filter high cut-off frequency in Hz",
-                      default=100)
-    parser.add_option("--ecg-l-freq", dest="ecg_l_freq", type="float",
-                      help="Filter low cut-off frequency in Hz used "
-                      "for ECG event detection",
-                      default=5)
-    parser.add_option("--ecg-h-freq", dest="ecg_h_freq", type="float",
-                      help="Filter high cut-off frequency in Hz used "
-                      "for ECG event detection",
-                      default=35)
-    parser.add_option("-p", "--preload", dest="preload",
-                      help="Temporary file used during computation "
-                      "(to save memory)",
-                      default=True)
-    parser.add_option("-a", "--average", dest="average", action="store_true",
-                      help="Compute SSP after averaging",
-                      default=False)
-    parser.add_option("--proj", dest="proj",
-                      help="Use SSP projections from a fif file.",
-                      default=None)
-    parser.add_option("--filtersize", dest="filter_length", type="int",
-                      help="Number of taps to use for filtering",
-                      default=2048)
-    parser.add_option("-j", "--n-jobs", dest="n_jobs", type="int",
-                      help="Number of jobs to run in parallel",
-                      default=1)
-    parser.add_option("-c", "--channel", dest="ch_name",
-                      help="Channel to use for ECG detection "
-                      "(Required if no ECG found)",
-                      default=None)
-    parser.add_option("--rej-grad", dest="rej_grad", type="float",
-                      help="Gradiometers rejection parameter "
-                      "in fT/cm (peak to peak amplitude)",
-                      default=2000)
-    parser.add_option("--rej-mag", dest="rej_mag", type="float",
-                      help="Magnetometers rejection parameter "
-                      "in fT (peak to peak amplitude)",
-                      default=3000)
-    parser.add_option("--rej-eeg", dest="rej_eeg", type="float",
-                      help="EEG rejection parameter in µV "
-                      "(peak to peak amplitude)",
-                      default=50)
-    parser.add_option("--rej-eog", dest="rej_eog", type="float",
-                      help="EOG rejection parameter in µV "
-                      "(peak to peak amplitude)",
-                      default=250)
-    parser.add_option("--avg-ref", dest="avg_ref", action="store_true",
-                      help="Add EEG average reference proj",
-                      default=False)
-    parser.add_option("--no-proj", dest="no_proj", action="store_true",
-                      help="Exclude the SSP projectors currently "
-                      "in the fiff file",
-                      default=False)
-    parser.add_option("--bad", dest="bad_fname",
-                      help="Text file containing bad channels list "
-                      "(one per line)",
-                      default=None)
-
parser.add_option("--event-id", dest="event_id", type="int", - help="ID to use for events", - default=999) - parser.add_option("--event-raw", dest="raw_event_fname", - help="raw file to use for event detection", - default=None) - parser.add_option("--tstart", dest="tstart", type="float", - help="Start artifact detection after tstart seconds", - default=0.) - parser.add_option("--qrsthr", dest="qrs_threshold", type="string", - help="QRS detection threshold. Between 0 and 1. Can " - "also be 'auto' for automatic selection", - default='auto') + parser.add_option( + "-i", "--in", dest="raw_in", help="Input raw FIF file", metavar="FILE" + ) + parser.add_option( + "--tmin", + dest="tmin", + type="float", + help="Time before event in seconds", + default=-0.2, + ) + parser.add_option( + "--tmax", + dest="tmax", + type="float", + help="Time after event in seconds", + default=0.4, + ) + parser.add_option( + "-g", + "--n-grad", + dest="n_grad", + type="int", + help="Number of SSP vectors for gradiometers", + default=2, + ) + parser.add_option( + "-m", + "--n-mag", + dest="n_mag", + type="int", + help="Number of SSP vectors for magnetometers", + default=2, + ) + parser.add_option( + "-e", + "--n-eeg", + dest="n_eeg", + type="int", + help="Number of SSP vectors for EEG", + default=2, + ) + parser.add_option( + "--l-freq", + dest="l_freq", + type="float", + help="Filter low cut-off frequency in Hz", + default=1, + ) + parser.add_option( + "--h-freq", + dest="h_freq", + type="float", + help="Filter high cut-off frequency in Hz", + default=100, + ) + parser.add_option( + "--ecg-l-freq", + dest="ecg_l_freq", + type="float", + help="Filter low cut-off frequency in Hz used " "for ECG event detection", + default=5, + ) + parser.add_option( + "--ecg-h-freq", + dest="ecg_h_freq", + type="float", + help="Filter high cut-off frequency in Hz used " "for ECG event detection", + default=35, + ) + parser.add_option( + "-p", + "--preload", + dest="preload", + help="Temporary file used during computation " "(to save memory)", + default=True, + ) + parser.add_option( + "-a", + "--average", + dest="average", + action="store_true", + help="Compute SSP after averaging", + default=False, + ) + parser.add_option( + "--proj", dest="proj", help="Use SSP projections from a fif file.", default=None + ) + parser.add_option( + "--filtersize", + dest="filter_length", + type="int", + help="Number of taps to use for filtering", + default=2048, + ) + parser.add_option( + "-j", + "--n-jobs", + dest="n_jobs", + type="int", + help="Number of jobs to run in parallel", + default=1, + ) + parser.add_option( + "-c", + "--channel", + dest="ch_name", + help="Channel to use for ECG detection " "(Required if no ECG found)", + default=None, + ) + parser.add_option( + "--rej-grad", + dest="rej_grad", + type="float", + help="Gradiometers rejection parameter " "in fT/cm (peak to peak amplitude)", + default=2000, + ) + parser.add_option( + "--rej-mag", + dest="rej_mag", + type="float", + help="Magnetometers rejection parameter " "in fT (peak to peak amplitude)", + default=3000, + ) + parser.add_option( + "--rej-eeg", + dest="rej_eeg", + type="float", + help="EEG rejection parameter in µV " "(peak to peak amplitude)", + default=50, + ) + parser.add_option( + "--rej-eog", + dest="rej_eog", + type="float", + help="EOG rejection parameter in µV " "(peak to peak amplitude)", + default=250, + ) + parser.add_option( + "--avg-ref", + dest="avg_ref", + action="store_true", + help="Add EEG average reference proj", + default=False, + ) + parser.add_option( + 
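+        # --no-proj excludes the SSP projectors already stored in the FIF
+        # file, so only the freshly computed ECG projections are written out
+        # (see the help string below).
+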
"--no-proj", + dest="no_proj", + action="store_true", + help="Exclude the SSP projectors currently " "in the fiff file", + default=False, + ) + parser.add_option( + "--bad", + dest="bad_fname", + help="Text file containing bad channels list " "(one per line)", + default=None, + ) + parser.add_option( + "--event-id", + dest="event_id", + type="int", + help="ID to use for events", + default=999, + ) + parser.add_option( + "--event-raw", + dest="raw_event_fname", + help="raw file to use for event detection", + default=None, + ) + parser.add_option( + "--tstart", + dest="tstart", + type="float", + help="Start artifact detection after tstart seconds", + default=0.0, + ) + parser.add_option( + "--qrsthr", + dest="qrs_threshold", + type="string", + help="QRS detection threshold. Between 0 and 1. Can " + "also be 'auto' for automatic selection", + default="auto", + ) options, args = parser.parse_args() @@ -138,10 +232,12 @@ def run(): filter_length = options.filter_length n_jobs = options.n_jobs ch_name = options.ch_name - reject = dict(grad=1e-13 * float(options.rej_grad), - mag=1e-15 * float(options.rej_mag), - eeg=1e-6 * float(options.rej_eeg), - eog=1e-6 * float(options.rej_eog)) + reject = dict( + grad=1e-13 * float(options.rej_grad), + mag=1e-15 * float(options.rej_mag), + eeg=1e-6 * float(options.rej_eeg), + eog=1e-6 * float(options.rej_eog), + ) avg_ref = options.avg_ref no_proj = options.no_proj bad_fname = options.bad_fname @@ -150,30 +246,30 @@ def run(): raw_event_fname = options.raw_event_fname tstart = options.tstart qrs_threshold = options.qrs_threshold - if qrs_threshold != 'auto': + if qrs_threshold != "auto": try: qrs_threshold = float(qrs_threshold) except ValueError: raise ValueError('qrsthr must be "auto" or a float') if bad_fname is not None: - with open(bad_fname, 'r') as fid: + with open(bad_fname, "r") as fid: bads = [w.rstrip() for w in fid.readlines()] - print('Bad channels read : %s' % bads) + print("Bad channels read : %s" % bads) else: bads = [] - if raw_in.endswith('_raw.fif') or raw_in.endswith('-raw.fif'): + if raw_in.endswith("_raw.fif") or raw_in.endswith("-raw.fif"): prefix = raw_in[:-8] else: prefix = raw_in[:-4] - ecg_event_fname = prefix + '_ecg-eve.fif' + ecg_event_fname = prefix + "_ecg-eve.fif" if average: - ecg_proj_fname = prefix + '_ecg_avg-proj.fif' + ecg_proj_fname = prefix + "_ecg_avg-proj.fif" else: - ecg_proj_fname = prefix + '_ecg-proj.fif' + ecg_proj_fname = prefix + "_ecg-proj.fif" raw = mne.io.read_raw_fif(raw_in, preload=preload) @@ -184,10 +280,31 @@ def run(): flat = None projs, events = mne.preprocessing.compute_proj_ecg( - raw, raw_event, tmin, tmax, n_grad, n_mag, n_eeg, l_freq, h_freq, - average, filter_length, n_jobs, ch_name, reject, flat, bads, avg_ref, - no_proj, event_id, ecg_l_freq, ecg_h_freq, tstart, qrs_threshold, - copy=False) + raw, + raw_event, + tmin, + tmax, + n_grad, + n_mag, + n_eeg, + l_freq, + h_freq, + average, + filter_length, + n_jobs, + ch_name, + reject, + flat, + bads, + avg_ref, + no_proj, + event_id, + ecg_l_freq, + ecg_h_freq, + tstart, + qrs_threshold, + copy=False, + ) raw.close() @@ -195,7 +312,7 @@ def run(): raw_event.close() if proj_fname is not None: - print('Including SSP projections from : %s' % proj_fname) + print("Including SSP projections from : %s" % proj_fname) # append the ecg projs, so they are last in the list projs = mne.read_proj(proj_fname) + projs diff --git a/mne/commands/mne_compute_proj_eog.py b/mne/commands/mne_compute_proj_eog.py index 3494ffa47af..42c93513122 100644 --- 
a/mne/commands/mne_compute_proj_eog.py +++ b/mne/commands/mne_compute_proj_eog.py @@ -34,77 +34,184 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-i", "--in", dest="raw_in", - help="Input raw FIF file", metavar="FILE") - parser.add_option("--tmin", dest="tmin", type="float", - help="Time before event in seconds", default=-0.2) - parser.add_option("--tmax", dest="tmax", type="float", - help="Time after event in seconds", default=0.2) - parser.add_option("-g", "--n-grad", dest="n_grad", type="int", - help="Number of SSP vectors for gradiometers", - default=2) - parser.add_option("-m", "--n-mag", dest="n_mag", type="int", - help="Number of SSP vectors for magnetometers", - default=2) - parser.add_option("-e", "--n-eeg", dest="n_eeg", type="int", - help="Number of SSP vectors for EEG", default=2) - parser.add_option("--l-freq", dest="l_freq", type="float", - help="Filter low cut-off frequency in Hz", - default=1) - parser.add_option("--h-freq", dest="h_freq", type="float", - help="Filter high cut-off frequency in Hz", - default=35) - parser.add_option("--eog-l-freq", dest="eog_l_freq", type="float", - help="Filter low cut-off frequency in Hz used for " - "EOG event detection", default=1) - parser.add_option("--eog-h-freq", dest="eog_h_freq", type="float", - help="Filter high cut-off frequency in Hz used for " - "EOG event detection", default=10) - parser.add_option("-p", "--preload", dest="preload", - help="Temporary file used during computation (to " - "save memory)", default=True) - parser.add_option("-a", "--average", dest="average", action="store_true", - help="Compute SSP after averaging", - default=False) - parser.add_option("--proj", dest="proj", - help="Use SSP projections from a fif file.", - default=None) - parser.add_option("--filtersize", dest="filter_length", type="int", - help="Number of taps to use for filtering", - default=2048) - parser.add_option("-j", "--n-jobs", dest="n_jobs", type="int", - help="Number of jobs to run in parallel", default=1) - parser.add_option("--rej-grad", dest="rej_grad", type="float", - help="Gradiometers rejection parameter in fT/cm (peak " - "to peak amplitude)", default=2000) - parser.add_option("--rej-mag", dest="rej_mag", type="float", - help="Magnetometers rejection parameter in fT (peak to " - "peak amplitude)", default=3000) - parser.add_option("--rej-eeg", dest="rej_eeg", type="float", - help="EEG rejection parameter in µV (peak to peak " - "amplitude)", default=50) - parser.add_option("--rej-eog", dest="rej_eog", type="float", - help="EOG rejection parameter in µV (peak to peak " - "amplitude)", default=1e9) - parser.add_option("--avg-ref", dest="avg_ref", action="store_true", - help="Add EEG average reference proj", - default=False) - parser.add_option("--no-proj", dest="no_proj", action="store_true", - help="Exclude the SSP projectors currently in the " - "fiff file", default=False) - parser.add_option("--bad", dest="bad_fname", - help="Text file containing bad channels list " - "(one per line)", default=None) - parser.add_option("--event-id", dest="event_id", type="int", - help="ID to use for events", default=998) - parser.add_option("--event-raw", dest="raw_event_fname", - help="raw file to use for event detection", default=None) - parser.add_option("--tstart", dest="tstart", type="float", - help="Start artifact detection after tstart seconds", - default=0.) 
- parser.add_option("-c", "--channel", dest="ch_name", type="string", - help="Custom EOG channel(s), comma separated", - default=None) + parser.add_option( + "-i", "--in", dest="raw_in", help="Input raw FIF file", metavar="FILE" + ) + parser.add_option( + "--tmin", + dest="tmin", + type="float", + help="Time before event in seconds", + default=-0.2, + ) + parser.add_option( + "--tmax", + dest="tmax", + type="float", + help="Time after event in seconds", + default=0.2, + ) + parser.add_option( + "-g", + "--n-grad", + dest="n_grad", + type="int", + help="Number of SSP vectors for gradiometers", + default=2, + ) + parser.add_option( + "-m", + "--n-mag", + dest="n_mag", + type="int", + help="Number of SSP vectors for magnetometers", + default=2, + ) + parser.add_option( + "-e", + "--n-eeg", + dest="n_eeg", + type="int", + help="Number of SSP vectors for EEG", + default=2, + ) + parser.add_option( + "--l-freq", + dest="l_freq", + type="float", + help="Filter low cut-off frequency in Hz", + default=1, + ) + parser.add_option( + "--h-freq", + dest="h_freq", + type="float", + help="Filter high cut-off frequency in Hz", + default=35, + ) + parser.add_option( + "--eog-l-freq", + dest="eog_l_freq", + type="float", + help="Filter low cut-off frequency in Hz used for " "EOG event detection", + default=1, + ) + parser.add_option( + "--eog-h-freq", + dest="eog_h_freq", + type="float", + help="Filter high cut-off frequency in Hz used for " "EOG event detection", + default=10, + ) + parser.add_option( + "-p", + "--preload", + dest="preload", + help="Temporary file used during computation (to " "save memory)", + default=True, + ) + parser.add_option( + "-a", + "--average", + dest="average", + action="store_true", + help="Compute SSP after averaging", + default=False, + ) + parser.add_option( + "--proj", dest="proj", help="Use SSP projections from a fif file.", default=None + ) + parser.add_option( + "--filtersize", + dest="filter_length", + type="int", + help="Number of taps to use for filtering", + default=2048, + ) + parser.add_option( + "-j", + "--n-jobs", + dest="n_jobs", + type="int", + help="Number of jobs to run in parallel", + default=1, + ) + parser.add_option( + "--rej-grad", + dest="rej_grad", + type="float", + help="Gradiometers rejection parameter in fT/cm (peak " "to peak amplitude)", + default=2000, + ) + parser.add_option( + "--rej-mag", + dest="rej_mag", + type="float", + help="Magnetometers rejection parameter in fT (peak to " "peak amplitude)", + default=3000, + ) + parser.add_option( + "--rej-eeg", + dest="rej_eeg", + type="float", + help="EEG rejection parameter in µV (peak to peak " "amplitude)", + default=50, + ) + parser.add_option( + "--rej-eog", + dest="rej_eog", + type="float", + help="EOG rejection parameter in µV (peak to peak " "amplitude)", + default=1e9, + ) + parser.add_option( + "--avg-ref", + dest="avg_ref", + action="store_true", + help="Add EEG average reference proj", + default=False, + ) + parser.add_option( + "--no-proj", + dest="no_proj", + action="store_true", + help="Exclude the SSP projectors currently in the " "fiff file", + default=False, + ) + parser.add_option( + "--bad", + dest="bad_fname", + help="Text file containing bad channels list " "(one per line)", + default=None, + ) + parser.add_option( + "--event-id", + dest="event_id", + type="int", + help="ID to use for events", + default=998, + ) + parser.add_option( + "--event-raw", + dest="raw_event_fname", + help="raw file to use for event detection", + default=None, + ) + parser.add_option( + "--tstart", + 
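+        # tstart delays artifact detection past the start of the recording
+        # (e.g. to skip an initial noisy segment); the default 0.0 scans the
+        # whole file.
+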
dest="tstart", + type="float", + help="Start artifact detection after tstart seconds", + default=0.0, + ) + parser.add_option( + "-c", + "--channel", + dest="ch_name", + type="string", + help="Custom EOG channel(s), comma separated", + default=None, + ) options, args = parser.parse_args() @@ -127,10 +234,12 @@ def run(): preload = options.preload filter_length = options.filter_length n_jobs = options.n_jobs - reject = dict(grad=1e-13 * float(options.rej_grad), - mag=1e-15 * float(options.rej_mag), - eeg=1e-6 * float(options.rej_eeg), - eog=1e-6 * float(options.rej_eog)) + reject = dict( + grad=1e-13 * float(options.rej_grad), + mag=1e-15 * float(options.rej_mag), + eeg=1e-6 * float(options.rej_eeg), + eog=1e-6 * float(options.rej_eog), + ) avg_ref = options.avg_ref no_proj = options.no_proj bad_fname = options.bad_fname @@ -141,23 +250,23 @@ def run(): ch_name = options.ch_name if bad_fname is not None: - with open(bad_fname, 'r') as fid: + with open(bad_fname, "r") as fid: bads = [w.rstrip() for w in fid.readlines()] - print('Bad channels read : %s' % bads) + print("Bad channels read : %s" % bads) else: bads = [] - if raw_in.endswith('_raw.fif') or raw_in.endswith('-raw.fif'): + if raw_in.endswith("_raw.fif") or raw_in.endswith("-raw.fif"): prefix = raw_in[:-8] else: prefix = raw_in[:-4] - eog_event_fname = prefix + '_eog-eve.fif' + eog_event_fname = prefix + "_eog-eve.fif" if average: - eog_proj_fname = prefix + '_eog_avg-proj.fif' + eog_proj_fname = prefix + "_eog_avg-proj.fif" else: - eog_proj_fname = prefix + '_eog-proj.fif' + eog_proj_fname = prefix + "_eog-proj.fif" raw = mne.io.read_raw_fif(raw_in, preload=preload) @@ -168,13 +277,30 @@ def run(): flat = None projs, events = mne.preprocessing.compute_proj_eog( - raw=raw, raw_event=raw_event, tmin=tmin, tmax=tmax, n_grad=n_grad, - n_mag=n_mag, n_eeg=n_eeg, l_freq=l_freq, h_freq=h_freq, - average=average, filter_length=filter_length, - n_jobs=n_jobs, reject=reject, flat=flat, bads=bads, - avg_ref=avg_ref, no_proj=no_proj, event_id=event_id, - eog_l_freq=eog_l_freq, eog_h_freq=eog_h_freq, - tstart=tstart, ch_name=ch_name, copy=False) + raw=raw, + raw_event=raw_event, + tmin=tmin, + tmax=tmax, + n_grad=n_grad, + n_mag=n_mag, + n_eeg=n_eeg, + l_freq=l_freq, + h_freq=h_freq, + average=average, + filter_length=filter_length, + n_jobs=n_jobs, + reject=reject, + flat=flat, + bads=bads, + avg_ref=avg_ref, + no_proj=no_proj, + event_id=event_id, + eog_l_freq=eog_l_freq, + eog_h_freq=eog_h_freq, + tstart=tstart, + ch_name=ch_name, + copy=False, + ) raw.close() @@ -182,7 +308,7 @@ def run(): raw_event.close() if proj_fname is not None: - print('Including SSP projections from : %s' % proj_fname) + print("Including SSP projections from : %s" % proj_fname) # append the eog projs, so they are last in the list projs = mne.read_proj(proj_fname) + projs @@ -196,6 +322,6 @@ def run(): mne.write_events(eog_event_fname, events) -is_main = (__name__ == '__main__') +is_main = __name__ == "__main__" if is_main: run() diff --git a/mne/commands/mne_coreg.py b/mne/commands/mne_coreg.py index 0e25c1f44de..dad18d278aa 100644 --- a/mne/commands/mne_coreg.py +++ b/mne/commands/mne_coreg.py @@ -22,51 +22,98 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-d", "--subjects-dir", dest="subjects_dir", - default=None, help="Subjects directory") - parser.add_option("-s", "--subject", dest="subject", default=None, - help="Subject name") - parser.add_option("-f", "--fiff", dest="inst", default=None, - help="FIFF file with digitizer data for 
coregistration") - parser.add_option("-t", "--tabbed", dest="tabbed", action="store_true", - default=False, help="Option for small screens: Combine " - "the data source panel and the coregistration panel " - "into a single panel with tabs.") - parser.add_option("--no-guess-mri", dest="guess_mri_subject", - action='store_false', default=None, - help="Prevent the GUI from automatically guessing and " - "changing the MRI subject when a new head shape source " - "file is selected.") - parser.add_option("--head-opacity", type=float, default=None, - dest="head_opacity", - help="The opacity of the head surface, in the range " - "[0, 1].") - parser.add_option("--high-res-head", - action='store_true', default=False, dest="high_res_head", - help="Use a high-resolution head surface.") - parser.add_option("--low-res-head", - action='store_true', default=False, dest="low_res_head", - help="Use a low-resolution head surface.") - parser.add_option('--trans', dest='trans', default=None, - help='Head<->MRI transform FIF file ("-trans.fif")') - parser.add_option('--interaction', - type=str, default=None, dest='interaction', - help='Interaction style to use, can be "trackball" or ' - '"terrain".') - parser.add_option('--scale', - type=float, default=None, dest='scale', - help='Scale factor for the scene.') - parser.add_option('--simple-rendering', action='store_false', - dest='advanced_rendering', - help='Use simplified OpenGL rendering') + parser.add_option( + "-d", + "--subjects-dir", + dest="subjects_dir", + default=None, + help="Subjects directory", + ) + parser.add_option( + "-s", "--subject", dest="subject", default=None, help="Subject name" + ) + parser.add_option( + "-f", + "--fiff", + dest="inst", + default=None, + help="FIFF file with digitizer data for coregistration", + ) + parser.add_option( + "-t", + "--tabbed", + dest="tabbed", + action="store_true", + default=False, + help="Option for small screens: Combine " + "the data source panel and the coregistration panel " + "into a single panel with tabs.", + ) + parser.add_option( + "--no-guess-mri", + dest="guess_mri_subject", + action="store_false", + default=None, + help="Prevent the GUI from automatically guessing and " + "changing the MRI subject when a new head shape source " + "file is selected.", + ) + parser.add_option( + "--head-opacity", + type=float, + default=None, + dest="head_opacity", + help="The opacity of the head surface, in the range " "[0, 1].", + ) + parser.add_option( + "--high-res-head", + action="store_true", + default=False, + dest="high_res_head", + help="Use a high-resolution head surface.", + ) + parser.add_option( + "--low-res-head", + action="store_true", + default=False, + dest="low_res_head", + help="Use a low-resolution head surface.", + ) + parser.add_option( + "--trans", + dest="trans", + default=None, + help='Head<->MRI transform FIF file ("-trans.fif")', + ) + parser.add_option( + "--interaction", + type=str, + default=None, + dest="interaction", + help='Interaction style to use, can be "trackball" or ' '"terrain".', + ) + parser.add_option( + "--scale", + type=float, + default=None, + dest="scale", + help="Scale factor for the scene.", + ) + parser.add_option( + "--simple-rendering", + action="store_false", + dest="advanced_rendering", + help="Use simplified OpenGL rendering", + ) _add_verbose_flag(parser) options, args = parser.parse_args() if options.low_res_head: if options.high_res_head: - raise ValueError("Can't specify --high-res-head and " - "--low-res-head at the same time.") + raise ValueError( + 
"Can't specify --high-res-head and " "--low-res-head at the same time." + ) head_high_res = False elif options.high_res_head: head_high_res = True @@ -81,18 +128,25 @@ def run(): if trans is not None: trans = op.expanduser(trans) import faulthandler + faulthandler.enable() mne.gui.coregistration( - options.tabbed, inst=options.inst, subject=options.subject, + options.tabbed, + inst=options.inst, + subject=options.subject, subjects_dir=subjects_dir, guess_mri_subject=options.guess_mri_subject, - head_opacity=options.head_opacity, head_high_res=head_high_res, - trans=trans, scrollable=True, + head_opacity=options.head_opacity, + head_high_res=head_high_res, + trans=trans, + scrollable=True, interaction=options.interaction, scale=options.scale, advanced_rendering=options.advanced_rendering, - show=True, block=True, - verbose=options.verbose) + show=True, + block=True, + verbose=options.verbose, + ) mne.utils.run_command_if_main() diff --git a/mne/commands/mne_flash_bem.py b/mne/commands/mne_flash_bem.py index 3556b58a78d..8ffaf57b816 100644 --- a/mne/commands/mne_flash_bem.py +++ b/mne/commands/mne_flash_bem.py @@ -41,7 +41,7 @@ def _vararg_callback(option, opt_str, value, parser): break value.append(arg) - del parser.rargs[:len(value)] + del parser.rargs[: len(value)] setattr(parser.values, option.dest, value) @@ -51,45 +51,103 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-s", "--subject", dest="subject", - help="Subject name", default=None) - parser.add_option("-d", "--subjects-dir", dest="subjects_dir", - help="Subjects directory", default=None) - parser.add_option("-3", "--flash30", "--noflash30", dest="flash30", - action="callback", callback=_vararg_callback, - help=("The 30-degree flip angle data. If no argument do " - "not use flash30. If arguments are given, them as " - "file names.")) - parser.add_option("-5", "--flash5", dest="flash5", - action="callback", callback=_vararg_callback, - help=("Path to the multiecho flash 5 images. 
" - "Can be one file or one per echo."),) - parser.add_option("-r", "--registered", dest="registered", - action="store_true", default=False, - help=("Set if the Flash MRI images have already " - "been registered with the T1.mgz file.")) - parser.add_option("-n", "--noconvert", dest="noconvert", - action="store_true", default=False, - help=("[DEPRECATED] Assume that the Flash MRI images " - "have already been converted to mgz files")) - parser.add_option("-u", "--unwarp", dest="unwarp", - action="store_true", default=False, - help=("Run grad_unwarp with -unwarp " - "option on each of the converted data sets")) - parser.add_option("-o", "--overwrite", dest="overwrite", - action="store_true", default=False, - help="Write over existing .surf files in bem folder") - parser.add_option("-v", "--view", dest="show", action="store_true", - help="Show BEM model in 3D for visual inspection", - default=False) - parser.add_option("--copy", dest="copy", - help="Use copies instead of symlinks for surfaces", - action="store_true") - parser.add_option("-p", "--flash-path", dest="flash_path", - default=None, - help="[DEPRECATED] The directory containing flash5.mgz " - "files (defaults to " - "$SUBJECTS_DIR/$SUBJECT/mri/flash/parameter_maps") + parser.add_option( + "-s", "--subject", dest="subject", help="Subject name", default=None + ) + parser.add_option( + "-d", + "--subjects-dir", + dest="subjects_dir", + help="Subjects directory", + default=None, + ) + parser.add_option( + "-3", + "--flash30", + "--noflash30", + dest="flash30", + action="callback", + callback=_vararg_callback, + help=( + "The 30-degree flip angle data. If no argument do " + "not use flash30. If arguments are given, them as " + "file names." + ), + ) + parser.add_option( + "-5", + "--flash5", + dest="flash5", + action="callback", + callback=_vararg_callback, + help=( + "Path to the multiecho flash 5 images. " "Can be one file or one per echo." + ), + ) + parser.add_option( + "-r", + "--registered", + dest="registered", + action="store_true", + default=False, + help=( + "Set if the Flash MRI images have already " + "been registered with the T1.mgz file." 
+ ), + ) + parser.add_option( + "-n", + "--noconvert", + dest="noconvert", + action="store_true", + default=False, + help=( + "[DEPRECATED] Assume that the Flash MRI images " + "have already been converted to mgz files" + ), + ) + parser.add_option( + "-u", + "--unwarp", + dest="unwarp", + action="store_true", + default=False, + help=( + "Run grad_unwarp with -unwarp " + "option on each of the converted data sets" + ), + ) + parser.add_option( + "-o", + "--overwrite", + dest="overwrite", + action="store_true", + default=False, + help="Write over existing .surf files in bem folder", + ) + parser.add_option( + "-v", + "--view", + dest="show", + action="store_true", + help="Show BEM model in 3D for visual inspection", + default=False, + ) + parser.add_option( + "--copy", + dest="copy", + help="Use copies instead of symlinks for surfaces", + action="store_true", + ) + parser.add_option( + "-p", + "--flash-path", + dest="flash_path", + default=None, + help="[DEPRECATED] The directory containing flash5.mgz " + "files (defaults to " + "$SUBJECTS_DIR/$SUBJECT/mri/flash/parameter_maps", + ) options, _ = parser.parse_args() @@ -111,15 +169,26 @@ def run(): if options.subject is None: parser.print_help() - raise RuntimeError('The subject argument must be set') + raise RuntimeError("The subject argument must be set") flash5_img = convert_flash_mris( - subject=subject, subjects_dir=subjects_dir, flash5=flash5, - flash30=flash30, unwarp=unwarp, verbose=True + subject=subject, + subjects_dir=subjects_dir, + flash5=flash5, + flash30=flash30, + unwarp=unwarp, + verbose=True, + ) + make_flash_bem( + subject=subject, + subjects_dir=subjects_dir, + overwrite=overwrite, + show=show, + copy=copy, + register=register, + flash5_img=flash5_img, + verbose=True, ) - make_flash_bem(subject=subject, subjects_dir=subjects_dir, - overwrite=overwrite, show=show, copy=copy, - register=register, flash5_img=flash5_img, verbose=True) mne.utils.run_command_if_main() diff --git a/mne/commands/mne_freeview_bem_surfaces.py b/mne/commands/mne_freeview_bem_surfaces.py index f5a65d9fb79..646049b6616 100644 --- a/mne/commands/mne_freeview_bem_surfaces.py +++ b/mne/commands/mne_freeview_bem_surfaces.py @@ -38,39 +38,40 @@ def freeview_bem_surfaces(subject, subjects_dir, method): subject_dir = op.join(subjects_dir, subject) if not op.isdir(subject_dir): - raise ValueError("Wrong path: '{}'. Check subjects-dir or" - "subject argument.".format(subject_dir)) + raise ValueError( + "Wrong path: '{}'. 
Check subjects-dir or "
+            "subject argument.".format(subject_dir)
+        )
 
     env = os.environ.copy()
-    env['SUBJECT'] = subject
-    env['SUBJECTS_DIR'] = subjects_dir
+    env["SUBJECT"] = subject
+    env["SUBJECTS_DIR"] = subjects_dir
 
-    if 'FREESURFER_HOME' not in env:
-        raise RuntimeError('The FreeSurfer environment needs to be set up.')
+    if "FREESURFER_HOME" not in env:
+        raise RuntimeError("The FreeSurfer environment needs to be set up.")
 
-    mri_dir = op.join(subject_dir, 'mri')
-    bem_dir = op.join(subject_dir, 'bem')
-    mri = op.join(mri_dir, 'T1.mgz')
+    mri_dir = op.join(subject_dir, "mri")
+    bem_dir = op.join(subject_dir, "bem")
+    mri = op.join(mri_dir, "T1.mgz")
 
-    if method == 'watershed':
-        bem_dir = op.join(bem_dir, 'watershed')
-        outer_skin = op.join(bem_dir, '%s_outer_skin_surface' % subject)
-        outer_skull = op.join(bem_dir, '%s_outer_skull_surface' % subject)
-        inner_skull = op.join(bem_dir, '%s_inner_skull_surface' % subject)
+    if method == "watershed":
+        bem_dir = op.join(bem_dir, "watershed")
+        outer_skin = op.join(bem_dir, "%s_outer_skin_surface" % subject)
+        outer_skull = op.join(bem_dir, "%s_outer_skull_surface" % subject)
+        inner_skull = op.join(bem_dir, "%s_inner_skull_surface" % subject)
     else:
-        if method == 'flash':
-            bem_dir = op.join(bem_dir, 'flash')
-        outer_skin = op.join(bem_dir, 'outer_skin.surf')
-        outer_skull = op.join(bem_dir, 'outer_skull.surf')
-        inner_skull = op.join(bem_dir, 'inner_skull.surf')
+        if method == "flash":
+            bem_dir = op.join(bem_dir, "flash")
+        outer_skin = op.join(bem_dir, "outer_skin.surf")
+        outer_skull = op.join(bem_dir, "outer_skull.surf")
+        inner_skull = op.join(bem_dir, "inner_skull.surf")
 
     # put together the command
-    cmd = ['freeview']
+    cmd = ["freeview"]
     cmd += ["--volume", mri]
     cmd += ["--surface", "%s:color=red:edgecolor=red" % inner_skull]
     cmd += ["--surface", "%s:color=yellow:edgecolor=yellow" % outer_skull]
-    cmd += ["--surface",
-            "%s:color=255,170,127:edgecolor=255,170,127" % outer_skin]
+    cmd += ["--surface", "%s:color=255,170,127:edgecolor=255,170,127" % outer_skin]
 
     run_subprocess(cmd, env=env, stdout=sys.stdout)
     print("[done]")
@@ -82,18 +83,27 @@ def run():
 
     parser = get_optparser(__file__)
 
-    subject = os.environ.get('SUBJECT')
+    subject = os.environ.get("SUBJECT")
     subjects_dir = get_subjects_dir()
     if subjects_dir is not None:
         subjects_dir = str(subjects_dir)
 
-    parser.add_option("-s", "--subject", dest="subject",
-                      help="Subject name", default=subject)
-    parser.add_option("-d", "--subjects-dir", dest="subjects_dir",
-                      help="Subjects directory", default=subjects_dir)
-    parser.add_option("-m", "--method", dest="method",
-                      help=("Method used to generate the BEM model. "
-                            "Can be flash or watershed."))
+    parser.add_option(
+        "-s", "--subject", dest="subject", help="Subject name", default=subject
+    )
+    parser.add_option(
+        "-d",
+        "--subjects-dir",
+        dest="subjects_dir",
+        help="Subjects directory",
+        default=subjects_dir,
+    )
+    parser.add_option(
+        "-m",
+        "--method",
+        dest="method",
+        help=("Method used to generate the BEM model. 
" "Can be flash or watershed."), + ) options, args = parser.parse_args() diff --git a/mne/commands/mne_kit2fiff.py b/mne/commands/mne_kit2fiff.py index 1317a154c8c..0c6b4545203 100644 --- a/mne/commands/mne_kit2fiff.py +++ b/mne/commands/mne_kit2fiff.py @@ -29,33 +29,50 @@ def run(): parser = get_optparser(__file__) - parser.add_option('--input', dest='input_fname', - help='Input data file name', metavar='filename') - parser.add_option('--mrk', dest='mrk_fname', - help='MEG Marker file name', metavar='filename') - parser.add_option('--elp', dest='elp_fname', - help='Headshape points file name', metavar='filename') - parser.add_option('--hsp', dest='hsp_fname', - help='Headshape file name', metavar='filename') - parser.add_option('--stim', dest='stim', - help='Colon Separated Stimulus Trigger Channels', - metavar='chs') - parser.add_option('--slope', dest='slope', help='Slope direction', - metavar='slope') - parser.add_option('--stimthresh', dest='stimthresh', default=1, - help='Threshold value for trigger channels', - metavar='value') - parser.add_option('--output', dest='out_fname', - help='Name of the resulting fiff file', - metavar='filename') - parser.add_option('--debug', dest='debug', action='store_true', - default=False, - help='Set logging level for terminal output to debug') + parser.add_option( + "--input", dest="input_fname", help="Input data file name", metavar="filename" + ) + parser.add_option( + "--mrk", dest="mrk_fname", help="MEG Marker file name", metavar="filename" + ) + parser.add_option( + "--elp", dest="elp_fname", help="Headshape points file name", metavar="filename" + ) + parser.add_option( + "--hsp", dest="hsp_fname", help="Headshape file name", metavar="filename" + ) + parser.add_option( + "--stim", + dest="stim", + help="Colon Separated Stimulus Trigger Channels", + metavar="chs", + ) + parser.add_option("--slope", dest="slope", help="Slope direction", metavar="slope") + parser.add_option( + "--stimthresh", + dest="stimthresh", + default=1, + help="Threshold value for trigger channels", + metavar="value", + ) + parser.add_option( + "--output", + dest="out_fname", + help="Name of the resulting fiff file", + metavar="filename", + ) + parser.add_option( + "--debug", + dest="debug", + action="store_true", + default=False, + help="Set logging level for terminal output to debug", + ) options, args = parser.parse_args() if options.debug: - mne.set_log_level('debug') + mne.set_log_level("debug") input_fname = options.input_fname if input_fname is None: @@ -63,8 +80,8 @@ def run(): from mne_kit_gui import kit2fiff # noqa except ImportError: raise ImportError( - 'The mne-kit-gui package is required, install it using ' - 'conda or pip') from None + "The mne-kit-gui package is required, install it using " "conda or pip" + ) from None kit2fiff() sys.exit(0) @@ -77,11 +94,17 @@ def run(): out_fname = options.out_fname if isinstance(stim, str): - stim = map(int, stim.split(':')) - - raw = read_raw_kit(input_fname=input_fname, mrk=mrk_fname, elp=elp_fname, - hsp=hsp_fname, stim=stim, slope=slope, - stimthresh=stimthresh) + stim = map(int, stim.split(":")) + + raw = read_raw_kit( + input_fname=input_fname, + mrk=mrk_fname, + elp=elp_fname, + hsp=hsp_fname, + stim=stim, + slope=slope, + stimthresh=stimthresh, + ) raw.save(out_fname) raw.close() diff --git a/mne/commands/mne_make_scalp_surfaces.py b/mne/commands/mne_make_scalp_surfaces.py index 9da7941384c..c5bf03e06a0 100644 --- a/mne/commands/mne_make_scalp_surfaces.py +++ b/mne/commands/mne_make_scalp_surfaces.py @@ -27,29 
+27,60 @@ def run(): from mne.commands.utils import get_optparser, _add_verbose_flag parser = get_optparser(__file__) - subjects_dir = mne.get_config('SUBJECTS_DIR') + subjects_dir = mne.get_config("SUBJECTS_DIR") - parser.add_option('-o', '--overwrite', dest='overwrite', - action='store_true', - help='Overwrite previously computed surface') - parser.add_option('-s', '--subject', dest='subject', - help='The name of the subject', type='str') - parser.add_option('-m', '--mri', dest='mri', type='str', default='T1.mgz', - help='The MRI file to process using mkheadsurf.') - parser.add_option('-f', '--force', dest='force', action='store_true', - help='Force creation of the surface even if it has ' - 'some topological defects.') - parser.add_option('-t', '--threshold', dest='threshold', type='int', - default=20, help='Threshold value to use with the MRI.') - parser.add_option("-d", "--subjects-dir", dest="subjects_dir", - help="Subjects directory", default=subjects_dir) - parser.add_option("-n", "--no-decimate", dest="no_decimate", - help="Disable medium and sparse decimations " - "(dense only)", action='store_true') + parser.add_option( + "-o", + "--overwrite", + dest="overwrite", + action="store_true", + help="Overwrite previously computed surface", + ) + parser.add_option( + "-s", "--subject", dest="subject", help="The name of the subject", type="str" + ) + parser.add_option( + "-m", + "--mri", + dest="mri", + type="str", + default="T1.mgz", + help="The MRI file to process using mkheadsurf.", + ) + parser.add_option( + "-f", + "--force", + dest="force", + action="store_true", + help="Force creation of the surface even if it has " + "some topological defects.", + ) + parser.add_option( + "-t", + "--threshold", + dest="threshold", + type="int", + default=20, + help="Threshold value to use with the MRI.", + ) + parser.add_option( + "-d", + "--subjects-dir", + dest="subjects_dir", + help="Subjects directory", + default=subjects_dir, + ) + parser.add_option( + "-n", + "--no-decimate", + dest="no_decimate", + help="Disable medium and sparse decimations " "(dense only)", + action="store_true", + ) _add_verbose_flag(parser) options, args = parser.parse_args() - subject = vars(options).get('subject', os.getenv('SUBJECT')) + subject = vars(options).get("subject", os.getenv("SUBJECT")) subjects_dir = options.subjects_dir if subject is None or subjects_dir is None: parser.print_help() @@ -62,7 +93,8 @@ def run(): no_decimate=options.no_decimate, threshold=options.threshold, mri=options.mri, - verbose=options.verbose) + verbose=options.verbose, + ) mne.utils.run_command_if_main() diff --git a/mne/commands/mne_maxfilter.py b/mne/commands/mne_maxfilter.py index 4825b4d5553..182a2c6254b 100644 --- a/mne/commands/mne_maxfilter.py +++ b/mne/commands/mne_maxfilter.py @@ -25,71 +25,157 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-i", "--in", dest="in_fname", - help="Input raw FIF file", metavar="FILE") - parser.add_option("-o", dest="out_fname", - help="Output FIF file (if not set, suffix '_sss' will " - "be used)", metavar="FILE", default=None) - parser.add_option("--origin", dest="origin", - help="Head origin in mm, or a filename to read the " - "origin from. 
If not set it will be estimated from " - "headshape points", default=None) - parser.add_option("--origin-out", dest="origin_out", - help="Filename to use for computed origin", default=None) - parser.add_option("--frame", dest="frame", type="string", - help="Coordinate frame for head center ('device' or " - "'head')", default="device") - parser.add_option("--bad", dest="bad", type="string", - help="List of static bad channels", - default=None) - parser.add_option("--autobad", dest="autobad", type="string", - help="Set automated bad channel detection ('on', 'off', " - "'n')", default="off") - parser.add_option("--skip", dest="skip", - help="Skips raw data sequences, time intervals pairs in " - "s, e.g.: 0 30 120 150", default=None) - parser.add_option("--force", dest="force", action="store_true", - help="Ignore program warnings", - default=False) - parser.add_option("--st", dest="st", action="store_true", - help="Apply the time-domain MaxST extension", - default=False) - parser.add_option("--buflen", dest="st_buflen", type="float", - help="MaxSt buffer length in s", - default=16.0) - parser.add_option("--corr", dest="st_corr", type="float", - help="MaxSt subspace correlation", - default=0.96) - parser.add_option("--trans", dest="mv_trans", - help="Transforms the data into the coil definitions of " - "in_fname, or into the default frame", default=None) - parser.add_option("--movecomp", dest="mv_comp", action="store_true", - help="Estimates and compensates head movements in " - "continuous raw data", default=False) - parser.add_option("--headpos", dest="mv_headpos", action="store_true", - help="Estimates and stores head position parameters, " - "but does not compensate movements", default=False) - parser.add_option("--hp", dest="mv_hp", type="string", - help="Stores head position data in an ascii file", - default=None) - parser.add_option("--hpistep", dest="mv_hpistep", type="float", - help="Sets head position update interval in ms", - default=None) - parser.add_option("--hpisubt", dest="mv_hpisubt", type="string", - help="Subtracts hpi signals: sine amplitudes, amp + " - "baseline, or switch off", default=None) - parser.add_option("--nohpicons", dest="mv_hpicons", action="store_false", - help="Do not check initial consistency isotrak vs " - "hpifit", default=True) - parser.add_option("--linefreq", dest="linefreq", type="float", - help="Sets the basic line interference frequency (50 or " - "60 Hz)", default=None) - parser.add_option("--nooverwrite", dest="overwrite", action="store_false", - help="Do not overwrite output file if it already exists", - default=True) - parser.add_option("--args", dest="mx_args", type="string", - help="Additional command line arguments to pass to " - "MaxFilter", default="") + parser.add_option( + "-i", "--in", dest="in_fname", help="Input raw FIF file", metavar="FILE" + ) + parser.add_option( + "-o", + dest="out_fname", + help="Output FIF file (if not set, suffix '_sss' will " "be used)", + metavar="FILE", + default=None, + ) + parser.add_option( + "--origin", + dest="origin", + help="Head origin in mm, or a filename to read the " + "origin from. 
If not set it will be estimated from " + "headshape points", + default=None, + ) + parser.add_option( + "--origin-out", + dest="origin_out", + help="Filename to use for computed origin", + default=None, + ) + parser.add_option( + "--frame", + dest="frame", + type="string", + help="Coordinate frame for head center ('device' or " "'head')", + default="device", + ) + parser.add_option( + "--bad", + dest="bad", + type="string", + help="List of static bad channels", + default=None, + ) + parser.add_option( + "--autobad", + dest="autobad", + type="string", + help="Set automated bad channel detection ('on', 'off', " "'n')", + default="off", + ) + parser.add_option( + "--skip", + dest="skip", + help="Skips raw data sequences, time intervals pairs in " + "s, e.g.: 0 30 120 150", + default=None, + ) + parser.add_option( + "--force", + dest="force", + action="store_true", + help="Ignore program warnings", + default=False, + ) + parser.add_option( + "--st", + dest="st", + action="store_true", + help="Apply the time-domain MaxST extension", + default=False, + ) + parser.add_option( + "--buflen", + dest="st_buflen", + type="float", + help="MaxSt buffer length in s", + default=16.0, + ) + parser.add_option( + "--corr", + dest="st_corr", + type="float", + help="MaxSt subspace correlation", + default=0.96, + ) + parser.add_option( + "--trans", + dest="mv_trans", + help="Transforms the data into the coil definitions of " + "in_fname, or into the default frame", + default=None, + ) + parser.add_option( + "--movecomp", + dest="mv_comp", + action="store_true", + help="Estimates and compensates head movements in " "continuous raw data", + default=False, + ) + parser.add_option( + "--headpos", + dest="mv_headpos", + action="store_true", + help="Estimates and stores head position parameters, " + "but does not compensate movements", + default=False, + ) + parser.add_option( + "--hp", + dest="mv_hp", + type="string", + help="Stores head position data in an ascii file", + default=None, + ) + parser.add_option( + "--hpistep", + dest="mv_hpistep", + type="float", + help="Sets head position update interval in ms", + default=None, + ) + parser.add_option( + "--hpisubt", + dest="mv_hpisubt", + type="string", + help="Subtracts hpi signals: sine amplitudes, amp + " "baseline, or switch off", + default=None, + ) + parser.add_option( + "--nohpicons", + dest="mv_hpicons", + action="store_false", + help="Do not check initial consistency isotrak vs " "hpifit", + default=True, + ) + parser.add_option( + "--linefreq", + dest="linefreq", + type="float", + help="Sets the basic line interference frequency (50 or " "60 Hz)", + default=None, + ) + parser.add_option( + "--nooverwrite", + dest="overwrite", + action="store_false", + help="Do not overwrite output file if it already exists", + default=True, + ) + parser.add_option( + "--args", + dest="mx_args", + type="string", + help="Additional command line arguments to pass to " "MaxFilter", + default="", + ) options, args = parser.parse_args() @@ -121,30 +207,48 @@ def run(): overwrite = options.overwrite mx_args = options.mx_args - if in_fname.endswith('_raw.fif') or in_fname.endswith('-raw.fif'): + if in_fname.endswith("_raw.fif") or in_fname.endswith("-raw.fif"): prefix = in_fname[:-8] else: prefix = in_fname[:-4] if out_fname is None: if st: - out_fname = prefix + '_tsss.fif' + out_fname = prefix + "_tsss.fif" else: - out_fname = prefix + '_sss.fif' + out_fname = prefix + "_sss.fif" if origin is not None and os.path.exists(origin): - with open(origin, 'r') as fid: + with 
open(origin, "r") as fid: origin = fid.readlines()[0].strip() origin = mne.preprocessing.apply_maxfilter( - in_fname, out_fname, origin, frame, - bad, autobad, skip, force, st, st_buflen, st_corr, mv_trans, - mv_comp, mv_headpos, mv_hp, mv_hpistep, mv_hpisubt, mv_hpicons, - linefreq, mx_args, overwrite) + in_fname, + out_fname, + origin, + frame, + bad, + autobad, + skip, + force, + st, + st_buflen, + st_corr, + mv_trans, + mv_comp, + mv_headpos, + mv_hp, + mv_hpistep, + mv_hpisubt, + mv_hpicons, + linefreq, + mx_args, + overwrite, + ) if origin_out is not None: - with open(origin_out, 'w') as fid: - fid.write(origin + '\n') + with open(origin_out, "w") as fid: + fid.write(origin + "\n") mne.utils.run_command_if_main() diff --git a/mne/commands/mne_prepare_bem_model.py b/mne/commands/mne_prepare_bem_model.py index da308bb737e..ae43ae9533a 100644 --- a/mne/commands/mne_prepare_bem_model.py +++ b/mne/commands/mne_prepare_bem_model.py @@ -20,18 +20,25 @@ def run(): parser = get_optparser(__file__) - parser.add_option('--bem', dest='bem_fname', - help='The name of the file containing the ' - 'triangulations of the BEM surfaces and the ' - 'conductivities of the compartments. The standard ' - 'ending for this file is -bem.fif.', - metavar="FILE") - parser.add_option('--sol', dest='bem_sol_fname', - help='The name of the resulting file containing BEM ' - 'solution (geometry matrix). It uses the linear ' - 'collocation approach. The file should end with ' - '-bem-sof.fif.', - metavar='FILE', default=None) + parser.add_option( + "--bem", + dest="bem_fname", + help="The name of the file containing the " + "triangulations of the BEM surfaces and the " + "conductivities of the compartments. The standard " + "ending for this file is -bem.fif.", + metavar="FILE", + ) + parser.add_option( + "--sol", + dest="bem_sol_fname", + help="The name of the resulting file containing BEM " + "solution (geometry matrix). It uses the linear " + "collocation approach. 
The file should end with "
+        "-bem-sol.fif.",
+        metavar="FILE",
+        default=None,
+    )
     _add_verbose_flag(parser)
 
     options, args = parser.parse_args()
@@ -45,10 +52,9 @@ def run():
 
     if bem_sol_fname is None:
         base, _ = os.path.splitext(bem_fname)
-        bem_sol_fname = base + '-sol.fif'
+        bem_sol_fname = base + "-sol.fif"
 
-    bem_model = mne.read_bem_surfaces(bem_fname, patch_stats=False,
-                                      verbose=verbose)
+    bem_model = mne.read_bem_surfaces(bem_fname, patch_stats=False, verbose=verbose)
     bem_solution = mne.make_bem_solution(bem_model, verbose=verbose)
     mne.write_bem_solution(bem_sol_fname, bem_solution)
diff --git a/mne/commands/mne_report.py b/mne/commands/mne_report.py
index 2d96570f26f..79818d52bab 100644
--- a/mne/commands/mne_report.py
+++ b/mne/commands/mne_report.py
@@ -78,7 +78,7 @@
 @verbose
 def log_elapsed(t, verbose=None):
     """Log elapsed time."""
-    logger.info('Report complete in %s seconds' % round(t, 1))
+    logger.info("Report complete in %s seconds" % round(t, 1))
 
 
 def run():
@@ -87,36 +87,72 @@ def run():
 
     parser = get_optparser(__file__)
 
-    parser.add_option("-p", "--path", dest="path",
-                      help="Path to folder who MNE-Report must be created")
-    parser.add_option("-i", "--info", dest="info_fname",
-                      help="File from which info dictionary is to be read",
-                      metavar="FILE")
-    parser.add_option("-c", "--cov", dest="cov_fname",
-                      help="File from which noise covariance is to be read",
-                      metavar="FILE")
-    parser.add_option("--bmin", dest="bmin",
-                      help="Time at which baseline correction starts for "
-                      "evokeds", default=None)
-    parser.add_option("--bmax", dest="bmax",
-                      help="Time at which baseline correction stops for "
-                      "evokeds", default=None)
-    parser.add_option("-d", "--subjects-dir", dest="subjects_dir",
-                      help="The subjects directory")
-    parser.add_option("-s", "--subject", dest="subject",
-                      help="The subject name")
-    parser.add_option("--no-browser", dest="no_browser", action='store_false',
-                      help="Do not open MNE-Report in browser")
-    parser.add_option("--overwrite", dest="overwrite", action='store_false',
-                      help="Overwrite html report if it already exists")
-    parser.add_option("-j", "--jobs", dest="n_jobs", help="Number of jobs to"
-                      " run in parallel")
-    parser.add_option("-m", "--mri-decim", type="int", dest="mri_decim",
-                      default=2, help="Integer factor used to decimate "
-                      "BEM plots")
-    parser.add_option("--image-format", type="str", dest="image_format",
-                      default='png', help="Image format to use "
-                      "(can be 'png' or 'svg')")
+    parser.add_option(
+        "-p",
+        "--path",
+        dest="path",
+        help="Path to folder where the MNE-Report must be created",
+    )
+    parser.add_option(
+        "-i",
+        "--info",
+        dest="info_fname",
+        help="File from which info dictionary is to be read",
+        metavar="FILE",
+    )
+    parser.add_option(
+        "-c",
+        "--cov",
+        dest="cov_fname",
+        help="File from which noise covariance is to be read",
+        metavar="FILE",
+    )
+    parser.add_option(
+        "--bmin",
+        dest="bmin",
+        help="Time at which baseline correction starts for " "evokeds",
+        default=None,
+    )
+    parser.add_option(
+        "--bmax",
+        dest="bmax",
+        help="Time at which baseline correction stops for " "evokeds",
+        default=None,
+    )
+    parser.add_option(
+        "-d", "--subjects-dir", dest="subjects_dir", help="The subjects directory"
+    )
+    parser.add_option("-s", "--subject", dest="subject", help="The subject name")
+    parser.add_option(
+        "--no-browser",
+        dest="no_browser",
+        action="store_false",
+        help="Do not open MNE-Report in browser",
+    )
+    parser.add_option(
+        "--overwrite",
+        dest="overwrite",
+        action="store_false",
+        help="Overwrite html 
report if it already exists",
+    )
+    parser.add_option(
+        "-j", "--jobs", dest="n_jobs", help="Number of jobs to" " run in parallel"
+    )
+    parser.add_option(
+        "-m",
+        "--mri-decim",
+        type="int",
+        dest="mri_decim",
+        default=2,
+        help="Integer factor used to decimate " "BEM plots",
+    )
+    parser.add_option(
+        "--image-format",
+        type="str",
+        dest="image_format",
+        default="png",
+        help="Image format to use " "(can be 'png' or 'svg')",
+    )
     _add_verbose_flag(parser)
 
     options, args = parser.parse_args()
@@ -144,12 +180,16 @@ def run():
         baseline = (bmin, bmax)
 
     t0 = time.time()
-    report = Report(info_fname, subjects_dir=subjects_dir,
-                    subject=subject, baseline=baseline,
-                    cov_fname=cov_fname, verbose=verbose,
-                    image_format=image_format)
-    report.parse_folder(path, verbose=verbose, n_jobs=n_jobs,
-                        mri_decim=mri_decim)
+    report = Report(
+        info_fname,
+        subjects_dir=subjects_dir,
+        subject=subject,
+        baseline=baseline,
+        cov_fname=cov_fname,
+        verbose=verbose,
+        image_format=image_format,
+    )
+    report.parse_folder(path, verbose=verbose, n_jobs=n_jobs, mri_decim=mri_decim)
     log_elapsed(time.time() - t0, verbose=verbose)
 
     report.save(open_browser=open_browser, overwrite=overwrite)
diff --git a/mne/commands/mne_setup_forward_model.py b/mne/commands/mne_setup_forward_model.py
index 239decefbfe..df7fc5fff4b 100644
--- a/mne/commands/mne_setup_forward_model.py
+++ b/mne/commands/mne_setup_forward_model.py
@@ -21,51 +21,66 @@ def run():
 
     parser = get_optparser(__file__)
 
-    parser.add_option("-s", "--subject",
-                      dest="subject",
-                      help="Subject name (required)",
-                      default=None)
-    parser.add_option("--model",
-                      dest="model",
-                      help="Output file name. Use a name <dir>/<name>-bem.fif",
-                      default=None,
-                      type='string')
-    parser.add_option('--ico',
-                      dest='ico',
-                      help='The surface ico downsampling to use, e.g. '
-                      ' 5=20484, 4=5120, 3=1280. If None, no subsampling'
-                      ' is applied.',
-                      default=None,
-                      type='int')
-    parser.add_option('--brainc',
-                      dest='brainc',
-                      help='Defines the brain compartment conductivity. '
-                      'The default value is 0.3 S/m.',
-                      default=0.3,
-                      type='float')
-    parser.add_option('--skullc',
-                      dest='skullc',
-                      help='Defines the skull compartment conductivity. '
-                      'The default value is 0.006 S/m.',
-                      default=None,
-                      type='float')
-    parser.add_option('--scalpc',
-                      dest='scalpc',
-                      help='Defines the scalp compartment conductivity. '
-                      'The default value is 0.3 S/m.',
-                      default=None,
-                      type='float')
-    parser.add_option('--homog',
-                      dest='homog',
-                      help='Use a single compartment model (brain only) '
-                      'instead a three layer one (scalp, skull, and '
-                      ' brain). If this flag is specified, the options '
-                      '--skullc and --scalpc are irrelevant.',
-                      default=None, action="store_true")
-    parser.add_option('-d', '--subjects-dir',
-                      dest='subjects_dir',
-                      help='Subjects directory',
-                      default=None)
+    parser.add_option(
+        "-s", "--subject", dest="subject", help="Subject name (required)", default=None
+    )
+    parser.add_option(
+        "--model",
+        dest="model",
+        help="Output file name. Use a name <dir>/<name>-bem.fif",
+        default=None,
+        type="string",
+    )
+    parser.add_option(
+        "--ico",
+        dest="ico",
+        help="The surface ico downsampling to use, e.g. "
+        " 5=20484, 4=5120, 3=1280. If None, no subsampling"
+        " is applied.",
+        default=None,
+        type="int",
+    )
+    parser.add_option(
+        "--brainc",
+        dest="brainc",
+        help="Defines the brain compartment conductivity. "
+        "The default value is 0.3 S/m.",
+        default=0.3,
+        type="float",
+    )
+    parser.add_option(
+        "--skullc",
+        dest="skullc",
+        help="Defines the skull compartment conductivity. 
" + "The default value is 0.006 S/m.", + default=None, + type="float", + ) + parser.add_option( + "--scalpc", + dest="scalpc", + help="Defines the scalp compartment conductivity. " + "The default value is 0.3 S/m.", + default=None, + type="float", + ) + parser.add_option( + "--homog", + dest="homog", + help="Use a single compartment model (brain only) " + "instead a three layer one (scalp, skull, and " + " brain). If this flag is specified, the options " + "--skullc and --scalpc are irrelevant.", + default=None, + action="store_true", + ) + parser.add_option( + "-d", + "--subjects-dir", + dest="subjects_dir", + help="Subjects directory", + default=None, + ) _add_verbose_flag(parser) options, args = parser.parse_args() @@ -85,11 +100,15 @@ def run(): # Parse conductivity option if homog is True: if skullc is not None: - warn('Trying to set the skull conductivity for a single layer ' - 'model. To use a 3 layer model, do not set the --homog flag.') + warn( + "Trying to set the skull conductivity for a single layer " + "model. To use a 3 layer model, do not set the --homog flag." + ) if scalpc is not None: - warn('Trying to set the scalp conductivity for a single layer ' - 'model. To use a 3 layer model, do not set the --homog flag.') + warn( + "Trying to set the scalp conductivity for a single layer " + "model. To use a 3 layer model, do not set the --homog flag." + ) # Single layer conductivity = [brainc] else: @@ -99,17 +118,19 @@ def run(): scalpc = 0.3 conductivity = [brainc, skullc, scalpc] # Create source space - bem_model = mne.make_bem_model(subject, - ico=ico, - conductivity=conductivity, - subjects_dir=subjects_dir, - verbose=verbose) + bem_model = mne.make_bem_model( + subject, + ico=ico, + conductivity=conductivity, + subjects_dir=subjects_dir, + verbose=verbose, + ) # Generate filename if fname is None: - n_faces = list(str(len(surface['tris'])) for surface in bem_model) - fname = subject + '-' + '-'.join(n_faces) + '-bem.fif' + n_faces = list(str(len(surface["tris"])) for surface in bem_model) + fname = subject + "-" + "-".join(n_faces) + "-bem.fif" else: - if not (fname.endswith('-bem.fif') or fname.endswith('_bem.fif')): + if not (fname.endswith("-bem.fif") or fname.endswith("_bem.fif")): fname = fname + "-bem.fif" # Save to subject's directory subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) diff --git a/mne/commands/mne_setup_source_space.py b/mne/commands/mne_setup_source_space.py index e8b14b78db3..49bf0b9ed06 100644 --- a/mne/commands/mne_setup_source_space.py +++ b/mne/commands/mne_setup_source_space.py @@ -22,64 +22,89 @@ def run(): """Run command.""" from mne.commands.utils import get_optparser, _add_verbose_flag + parser = get_optparser(__file__) - parser.add_option('-s', '--subject', - dest='subject', - help='Subject name (required)', - default=None) - parser.add_option('--src', dest='fname', - help='Output file name. Use a name /-src.fif', - metavar='FILE', default=None) - parser.add_option('--morph', - dest='subject_to', - help='morph the source space to this subject', - default=None) - parser.add_option('--surf', - dest='surface', - help='The surface to use. (default to white)', - default='white', - type='string') - parser.add_option('--spacing', - dest='spacing', - help='Specifies the approximate grid spacing of the ' - 'source space in mm. 
(default to 7mm)', - default=None, - type='int') - parser.add_option('--ico', - dest='ico', - help='use the recursively subdivided icosahedron ' - 'to create the source space.', - default=None, - type='int') - parser.add_option('--oct', - dest='oct', - help='use the recursively subdivided octahedron ' - 'to create the source space.', - default=None, - type='int') - parser.add_option('-d', '--subjects-dir', - dest='subjects_dir', - help='Subjects directory', - default=None) - parser.add_option('-n', '--n-jobs', - dest='n_jobs', - help='The number of jobs to run in parallel ' - '(default 1). Requires the joblib package. ' - 'Will use at most 2 jobs' - ' (one for each hemisphere).', - default=1, - type='int') - parser.add_option('--add-dist', - dest='add_dist', - help='Add distances. Can be "True", "False", or "patch" ' - 'to only compute cortical patch statistics (like the ' - '--cps option in MNE-C; requires SciPy >= 1.3)', - default='True') - parser.add_option('-o', '--overwrite', - dest='overwrite', - help='to write over existing files', - default=None, action="store_true") + parser.add_option( + "-s", "--subject", dest="subject", help="Subject name (required)", default=None + ) + parser.add_option( + "--src", + dest="fname", + help="Output file name. Use a name /-src.fif", + metavar="FILE", + default=None, + ) + parser.add_option( + "--morph", + dest="subject_to", + help="morph the source space to this subject", + default=None, + ) + parser.add_option( + "--surf", + dest="surface", + help="The surface to use. (default to white)", + default="white", + type="string", + ) + parser.add_option( + "--spacing", + dest="spacing", + help="Specifies the approximate grid spacing of the " + "source space in mm. (default to 7mm)", + default=None, + type="int", + ) + parser.add_option( + "--ico", + dest="ico", + help="use the recursively subdivided icosahedron " + "to create the source space.", + default=None, + type="int", + ) + parser.add_option( + "--oct", + dest="oct", + help="use the recursively subdivided octahedron " "to create the source space.", + default=None, + type="int", + ) + parser.add_option( + "-d", + "--subjects-dir", + dest="subjects_dir", + help="Subjects directory", + default=None, + ) + parser.add_option( + "-n", + "--n-jobs", + dest="n_jobs", + help="The number of jobs to run in parallel " + "(default 1). Requires the joblib package. " + "Will use at most 2 jobs" + " (one for each hemisphere).", + default=1, + type="int", + ) + parser.add_option( + "--add-dist", + dest="add_dist", + help='Add distances. 
Can be "True", "False", or "patch" ' + "to only compute cortical patch statistics (like the " + "--cps option in MNE-C; requires SciPy >= 1.3)", + default="True", + ) + parser.add_option( + "-o", + "--overwrite", + dest="overwrite", + help="to write over existing files", + default=None, + action="store_true", + ) _add_verbose_flag(parser) options, args = parser.parse_args() @@ -98,8 +123,8 @@ def run(): surface = options.surface n_jobs = options.n_jobs add_dist = options.add_dist - _check_option('add_dist', add_dist, ('True', 'False', 'patch')) - add_dist = {'True': True, 'False': False, 'patch': 'patch'}[add_dist] + _check_option("add_dist", add_dist, ("True", "False", "patch")) + add_dist = {"True": True, "False": False, "patch": "patch"}[add_dist] verbose = True if options.verbose is not None else False overwrite = True if options.overwrite is not None else False @@ -107,10 +132,10 @@ def run(): spacing_options = [ico, oct, spacing] n_options = len([x for x in spacing_options if x is not None]) if n_options > 1: - raise ValueError('Only one spacing option can be set at the same time') + raise ValueError("Only one spacing option can be set at the same time") elif n_options == 0: # Default to oct6 - use_spacing = 'oct6' + use_spacing = "oct6" elif n_options == 1: if ico is not None: use_spacing = "ico" + str(ico) @@ -121,23 +146,31 @@ def run(): # Generate filename if fname is None: if subject_to is None: - fname = subject + '-' + str(use_spacing) + '-src.fif' + fname = subject + "-" + str(use_spacing) + "-src.fif" else: - fname = (subject_to + '-' + subject + '-' + - str(use_spacing) + '-src.fif') + fname = subject_to + "-" + subject + "-" + str(use_spacing) + "-src.fif" else: - if not (fname.endswith('_src.fif') or fname.endswith('-src.fif')): + if not (fname.endswith("_src.fif") or fname.endswith("-src.fif")): fname = fname + "-src.fif" # Create source space - src = mne.setup_source_space(subject=subject, spacing=use_spacing, - surface=surface, subjects_dir=subjects_dir, - n_jobs=n_jobs, add_dist=add_dist, - verbose=verbose) + src = mne.setup_source_space( + subject=subject, + spacing=use_spacing, + surface=surface, + subjects_dir=subjects_dir, + n_jobs=n_jobs, + add_dist=add_dist, + verbose=verbose, + ) # Morph source space if --morph is set if subject_to is not None: - src = mne.morph_source_spaces(src, subject_to=subject_to, - subjects_dir=subjects_dir, - surf=surface, verbose=verbose) + src = mne.morph_source_spaces( + src, + subject_to=subject_to, + subjects_dir=subjects_dir, + surf=surface, + verbose=verbose, + ) # Save source space to file src.save(fname=fname, overwrite=overwrite) diff --git a/mne/commands/mne_show_fiff.py b/mne/commands/mne_show_fiff.py index be31cde2ad8..ed6fccdf89e 100644 --- a/mne/commands/mne_show_fiff.py +++ b/mne/commands/mne_show_fiff.py @@ -24,10 +24,14 @@ def run(): """Run command.""" - parser = mne.commands.utils.get_optparser( - __file__, usage='mne show_fiff ') - parser.add_option("-t", "--tag", dest="tag", - help="provide information about this tag", metavar="TAG") + parser = mne.commands.utils.get_optparser(__file__, usage="mne show_fiff ") + parser.add_option( + "-t", + "--tag", + dest="tag", + help="provide information about this tag", + metavar="TAG", + ) options, args = parser.parse_args() if len(args) != 1: parser.print_help() diff --git a/mne/commands/mne_show_info.py b/mne/commands/mne_show_info.py index 44e1fa79141..dc39491fb6c 100644 --- a/mne/commands/mne_show_info.py +++ b/mne/commands/mne_show_info.py @@ -17,8 +17,7 @@ def run(): 
"""Run command.""" - parser = mne.commands.utils.get_optparser( - __file__, usage='mne show_info ') + parser = mne.commands.utils.get_optparser(__file__, usage="mne show_info ") options, args = parser.parse_args() if len(args) != 1: parser.print_help() @@ -26,8 +25,8 @@ def run(): fname = args[0] - if not fname.endswith('.fif'): - raise ValueError('%s does not seem to be a .fif file.' % fname) + if not fname.endswith(".fif"): + raise ValueError("%s does not seem to be a .fif file." % fname) info = mne.io.read_info(fname) print("File : %s" % fname) diff --git a/mne/commands/mne_surf2bem.py b/mne/commands/mne_surf2bem.py index 4cb5ade9662..93a154b2477 100644 --- a/mne/commands/mne_surf2bem.py +++ b/mne/commands/mne_surf2bem.py @@ -25,12 +25,19 @@ def run(): parser = get_optparser(__file__) - parser.add_option("-s", "--surf", dest="surf", - help="Surface in Freesurfer format", metavar="FILE") - parser.add_option("-f", "--fif", dest="fif", - help="FIF file produced", metavar="FILE") - parser.add_option("-i", "--id", dest="id", default=4, - help=("Surface Id (e.g. 4 for head surface)")) + parser.add_option( + "-s", "--surf", dest="surf", help="Surface in Freesurfer format", metavar="FILE" + ) + parser.add_option( + "-f", "--fif", dest="fif", help="FIF file produced", metavar="FILE" + ) + parser.add_option( + "-i", + "--id", + dest="id", + default=4, + help=("Surface Id (e.g. 4 for head surface)"), + ) options, args = parser.parse_args() @@ -39,8 +46,7 @@ def run(): sys.exit(1) print("Converting %s to BEM FIF file." % options.surf) - surf = mne.bem._surfaces_to_bem([options.surf], [int(options.id)], - sigmas=[1]) + surf = mne.bem._surfaces_to_bem([options.surf], [int(options.id)], sigmas=[1]) mne.write_bem_surfaces(options.fif, surf) diff --git a/mne/commands/mne_sys_info.py b/mne/commands/mne_sys_info.py index a09994de8f9..075ff446681 100644 --- a/mne/commands/mne_sys_info.py +++ b/mne/commands/mne_sys_info.py @@ -17,17 +17,31 @@ def run(): """Run command.""" - parser = mne.commands.utils.get_optparser(__file__, usage='mne sys_info') - parser.add_option('-p', '--show-paths', dest='show_paths', - help='Show module paths', action='store_true') - parser.add_option('-d', '--developer', dest='developer', - help='Show additional developer module information', - action='store_true') - parser.add_option('-a', '--ascii', dest='unicode', - help='Use ASCII instead of unicode symbols', - action='store_false', default=True) + parser = mne.commands.utils.get_optparser(__file__, usage="mne sys_info") + parser.add_option( + "-p", + "--show-paths", + dest="show_paths", + help="Show module paths", + action="store_true", + ) + parser.add_option( + "-d", + "--developer", + dest="developer", + help="Show additional developer module information", + action="store_true", + ) + parser.add_option( + "-a", + "--ascii", + dest="unicode", + help="Use ASCII instead of unicode symbols", + action="store_false", + default=True, + ) options, args = parser.parse_args() - dependencies = 'developer' if options.developer else 'user' + dependencies = "developer" if options.developer else "user" if len(args) != 0: parser.print_help() sys.exit(1) @@ -35,7 +49,7 @@ def run(): mne.sys_info( show_paths=options.show_paths, dependencies=dependencies, - unicode=options.unicode + unicode=options.unicode, ) diff --git a/mne/commands/mne_watershed_bem.py b/mne/commands/mne_watershed_bem.py index b69a2801fd6..c182c7a0ded 100644 --- a/mne/commands/mne_watershed_bem.py +++ b/mne/commands/mne_watershed_bem.py @@ -23,35 +23,73 @@ def run(): parser 
= get_optparser(__file__) - parser.add_option("-s", "--subject", dest="subject", - help="Subject name (required)", default=None) - parser.add_option("-d", "--subjects-dir", dest="subjects_dir", - help="Subjects directory", default=None) - parser.add_option("-o", "--overwrite", dest="overwrite", - help="Write over existing files", action="store_true") - parser.add_option("-v", "--volume", dest="volume", - help="Defaults to T1", default='T1') - parser.add_option("-a", "--atlas", dest="atlas", - help="Specify the --atlas option for mri_watershed", - default=False, action="store_true") - parser.add_option("-g", "--gcaatlas", dest="gcaatlas", - help="Specify the --brain_atlas option for " - "mri_watershed", default=False, action="store_true") - parser.add_option("-p", "--preflood", dest="preflood", - help="Change the preflood height", default=None) - parser.add_option("--copy", dest="copy", - help="Use copies instead of symlinks for surfaces", - action="store_true") - parser.add_option("-t", "--T1", dest="T1", - help="Whether or not to pass the -T1 flag " - "(can be true, false, 0, or 1). " - "By default it takes the same value as gcaatlas.", - default=None) - parser.add_option("-b", "--brainmask", dest="brainmask", - help="The filename for the brainmask output file " - "relative to the " - "$SUBJECTS_DIR/$SUBJECT/bem/watershed/ directory.", - default="ws") + parser.add_option( + "-s", "--subject", dest="subject", help="Subject name (required)", default=None + ) + parser.add_option( + "-d", + "--subjects-dir", + dest="subjects_dir", + help="Subjects directory", + default=None, + ) + parser.add_option( + "-o", + "--overwrite", + dest="overwrite", + help="Write over existing files", + action="store_true", + ) + parser.add_option( + "-v", "--volume", dest="volume", help="Defaults to T1", default="T1" + ) + parser.add_option( + "-a", + "--atlas", + dest="atlas", + help="Specify the --atlas option for mri_watershed", + default=False, + action="store_true", + ) + parser.add_option( + "-g", + "--gcaatlas", + dest="gcaatlas", + help="Specify the --brain_atlas option for " "mri_watershed", + default=False, + action="store_true", + ) + parser.add_option( + "-p", + "--preflood", + dest="preflood", + help="Change the preflood height", + default=None, + ) + parser.add_option( + "--copy", + dest="copy", + help="Use copies instead of symlinks for surfaces", + action="store_true", + ) + parser.add_option( + "-t", + "--T1", + dest="T1", + help="Whether or not to pass the -T1 flag " + "(can be true, false, 0, or 1). 
" + "By default it takes the same value as gcaatlas.", + default=None, + ) + parser.add_option( + "-b", + "--brainmask", + dest="brainmask", + help="The filename for the brainmask output file " + "relative to the " + "$SUBJECTS_DIR/$SUBJECT/bem/watershed/ directory.", + default="ws", + ) _add_verbose_flag(parser) options, args = parser.parse_args() @@ -72,14 +110,23 @@ def run(): T1 = options.T1 if T1 is not None: T1 = T1.lower() - _check_option("--T1", T1, ('true', 'false', '0', '1')) - T1 = T1 in ('true', '1') + _check_option("--T1", T1, ("true", "false", "0", "1")) + T1 = T1 in ("true", "1") verbose = options.verbose - make_watershed_bem(subject=subject, subjects_dir=subjects_dir, - overwrite=overwrite, volume=volume, atlas=atlas, - gcaatlas=gcaatlas, preflood=preflood, copy=copy, - T1=T1, brainmask=brainmask, verbose=verbose) + make_watershed_bem( + subject=subject, + subjects_dir=subjects_dir, + overwrite=overwrite, + volume=volume, + atlas=atlas, + gcaatlas=gcaatlas, + preflood=preflood, + copy=copy, + T1=T1, + brainmask=brainmask, + verbose=verbose, + ) mne.utils.run_command_if_main() diff --git a/mne/commands/mne_what.py b/mne/commands/mne_what.py index 5d281facd0c..ab4a9d5ea8f 100644 --- a/mne/commands/mne_what.py +++ b/mne/commands/mne_what.py @@ -17,7 +17,8 @@ def run(): """Run command.""" from mne.commands.utils import get_optparser - parser = get_optparser(__file__, usage='usage: %prog fname [fname2 ...]') + + parser = get_optparser(__file__, usage="usage: %prog fname [fname2 ...]") options, args = parser.parse_args() for arg in args: print(mne.what(arg)) diff --git a/mne/commands/tests/test_commands.py b/mne/commands/tests/test_commands.py index 995edae59b9..c3bac034339 100644 --- a/mne/commands/tests/test_commands.py +++ b/mne/commands/tests/test_commands.py @@ -9,48 +9,74 @@ from numpy.testing import assert_equal, assert_allclose import mne -from mne import (concatenate_raws, read_bem_surfaces, read_surface, - read_source_spaces, read_bem_solution) +from mne import ( + concatenate_raws, + read_bem_surfaces, + read_surface, + read_source_spaces, + read_bem_solution, +) from mne.bem import ConductorModel, convert_flash_mris -from mne.commands import (mne_browse_raw, mne_bti2fiff, mne_clean_eog_ecg, - mne_compute_proj_ecg, mne_compute_proj_eog, - mne_coreg, mne_kit2fiff, - mne_make_scalp_surfaces, mne_maxfilter, - mne_report, mne_surf2bem, mne_watershed_bem, - mne_compare_fiff, mne_flash_bem, mne_show_fiff, - mne_show_info, mne_what, mne_setup_source_space, - mne_setup_forward_model, mne_anonymize, - mne_prepare_bem_model, mne_sys_info) +from mne.commands import ( + mne_browse_raw, + mne_bti2fiff, + mne_clean_eog_ecg, + mne_compute_proj_ecg, + mne_compute_proj_eog, + mne_coreg, + mne_kit2fiff, + mne_make_scalp_surfaces, + mne_maxfilter, + mne_report, + mne_surf2bem, + mne_watershed_bem, + mne_compare_fiff, + mne_flash_bem, + mne_show_fiff, + mne_show_info, + mne_what, + mne_setup_source_space, + mne_setup_forward_model, + mne_anonymize, + mne_prepare_bem_model, + mne_sys_info, +) from mne.datasets import testing from mne.io import read_raw_fif, read_info -from mne.utils import (requires_mne, requires_freesurfer, ArgvSetter, - _stamp_to_dt, _record_warnings) +from mne.utils import ( + requires_mne, + requires_freesurfer, + ArgvSetter, + _stamp_to_dt, + _record_warnings, +) -base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data') -raw_fname = op.join(base_dir, 'test_raw.fif') +base_dir = op.join(op.dirname(__file__), "..", "..", "io", "tests", "data") 
+raw_fname = op.join(base_dir, "test_raw.fif") testing_path = testing.data_path(download=False) -subjects_dir = op.join(testing_path, 'subjects') -bem_model_fname = op.join(testing_path, 'subjects', - 'sample', 'bem', 'sample-320-320-320-bem.fif') +subjects_dir = op.join(testing_path, "subjects") +bem_model_fname = op.join( + testing_path, "subjects", "sample", "bem", "sample-320-320-320-bem.fif" +) def check_usage(module, force_help=False): """Ensure we print usage.""" - args = ('--help',) if force_help else () + args = ("--help",) if force_help else () with ArgvSetter(args) as out: try: module.run() except SystemExit: pass - assert 'Usage: ' in out.stdout.getvalue() + assert "Usage: " in out.stdout.getvalue() @pytest.mark.slowtest def test_browse_raw(): """Test mne browse_raw.""" check_usage(mne_browse_raw) - with ArgvSetter(('--raw', raw_fname)): + with ArgvSetter(("--raw", raw_fname)): with _record_warnings(): # mpl show warning mne_browse_raw.run() @@ -60,7 +86,7 @@ def test_what(): check_usage(mne_browse_raw) with ArgvSetter((raw_fname,)) as out: mne_what.run() - assert 'raw' == out.stdout.getvalue().strip() + assert "raw" == out.stdout.getvalue().strip() def test_bti2fiff(): @@ -78,7 +104,7 @@ def test_show_fiff(): check_usage(mne_show_fiff) with ArgvSetter((raw_fname,)): mne_show_fiff.run() - with ArgvSetter((raw_fname, '--tag=102')): + with ArgvSetter((raw_fname, "--tag=102")): mne_show_fiff.run() @@ -87,42 +113,40 @@ def test_clean_eog_ecg(tmp_path): """Test mne clean_eog_ecg.""" check_usage(mne_clean_eog_ecg) tempdir = str(tmp_path) - raw = concatenate_raws([read_raw_fif(f) - for f in [raw_fname, raw_fname, raw_fname]]) - raw.info['bads'] = ['MEG 2443'] + raw = concatenate_raws([read_raw_fif(f) for f in [raw_fname, raw_fname, raw_fname]]) + raw.info["bads"] = ["MEG 2443"] use_fname = op.join(tempdir, op.basename(raw_fname)) raw.save(use_fname) - with ArgvSetter(('-i', use_fname, '--quiet')): + with ArgvSetter(("-i", use_fname, "--quiet")): mne_clean_eog_ecg.run() - for key, count in (('proj', 2), ('-eve', 3)): - fnames = glob.glob(op.join(tempdir, '*%s.fif' % key)) + for key, count in (("proj", 2), ("-eve", 3)): + fnames = glob.glob(op.join(tempdir, "*%s.fif" % key)) assert len(fnames) == count @pytest.mark.slowtest -@pytest.mark.parametrize('fun', (mne_compute_proj_ecg, mne_compute_proj_eog)) +@pytest.mark.parametrize("fun", (mne_compute_proj_ecg, mne_compute_proj_eog)) def test_compute_proj_exg(tmp_path, fun): """Test mne compute_proj_ecg/eog.""" check_usage(fun) tempdir = str(tmp_path) use_fname = op.join(tempdir, op.basename(raw_fname)) - bad_fname = op.join(tempdir, 'bads.txt') - with open(bad_fname, 'w') as fid: - fid.write('MEG 2443\n') + bad_fname = op.join(tempdir, "bads.txt") + with open(bad_fname, "w") as fid: + fid.write("MEG 2443\n") shutil.copyfile(raw_fname, use_fname) - with ArgvSetter(('-i', use_fname, '--bad=' + bad_fname, - '--rej-eeg', '150')): + with ArgvSetter(("-i", use_fname, "--bad=" + bad_fname, "--rej-eeg", "150")): with _record_warnings(): # samples, sometimes fun.run() - fnames = glob.glob(op.join(tempdir, '*proj.fif')) + fnames = glob.glob(op.join(tempdir, "*proj.fif")) assert len(fnames) == 1 - fnames = glob.glob(op.join(tempdir, '*-eve.fif')) + fnames = glob.glob(op.join(tempdir, "*-eve.fif")) assert len(fnames) == 1 def test_coreg(): """Test mne coreg.""" - assert hasattr(mne_coreg, 'run') + assert hasattr(mne_coreg, "run") def test_kit2fiff(): @@ -136,60 +160,73 @@ def test_kit2fiff(): @testing.requires_testing_data def 
test_make_scalp_surfaces(tmp_path, monkeypatch): """Test mne make_scalp_surfaces.""" - pytest.importorskip('nibabel') - pytest.importorskip('pyvista') + pytest.importorskip("nibabel") + pytest.importorskip("pyvista") check_usage(mne_make_scalp_surfaces) - has = 'SUBJECTS_DIR' in os.environ + has = "SUBJECTS_DIR" in os.environ # Copy necessary files to avoid FreeSurfer call tempdir = str(tmp_path) - surf_path = op.join(subjects_dir, 'sample', 'surf') - surf_path_new = op.join(tempdir, 'sample', 'surf') - os.mkdir(op.join(tempdir, 'sample')) + surf_path = op.join(subjects_dir, "sample", "surf") + surf_path_new = op.join(tempdir, "sample", "surf") + os.mkdir(op.join(tempdir, "sample")) os.mkdir(surf_path_new) - subj_dir = op.join(tempdir, 'sample', 'bem') + subj_dir = op.join(tempdir, "sample", "bem") os.mkdir(subj_dir) - cmd = ('-s', 'sample', '--subjects-dir', tempdir) + cmd = ("-s", "sample", "--subjects-dir", tempdir) monkeypatch.setattr( - mne.bem, 'decimate_surface', - lambda points, triangles, n_triangles: (points, triangles)) - dense_fname = op.join(subj_dir, 'sample-head-dense.fif') - medium_fname = op.join(subj_dir, 'sample-head-medium.fif') + mne.bem, + "decimate_surface", + lambda points, triangles, n_triangles: (points, triangles), + ) + dense_fname = op.join(subj_dir, "sample-head-dense.fif") + medium_fname = op.join(subj_dir, "sample-head-medium.fif") with ArgvSetter(cmd, disable_stdout=False, disable_stderr=False): - monkeypatch.delenv('FREESURFER_HOME') - with pytest.raises(RuntimeError, match='The FreeSurfer environ'): + monkeypatch.delenv("FREESURFER_HOME") + with pytest.raises(RuntimeError, match="The FreeSurfer environ"): mne_make_scalp_surfaces.run() - shutil.copy(op.join(surf_path, 'lh.seghead'), surf_path_new) - monkeypatch.setenv('FREESURFER_HOME', tempdir) + shutil.copy(op.join(surf_path, "lh.seghead"), surf_path_new) + monkeypatch.setenv("FREESURFER_HOME", tempdir) mne_make_scalp_surfaces.run() assert op.isfile(dense_fname) assert op.isfile(medium_fname) - with pytest.raises(OSError, match='overwrite'): + with pytest.raises(OSError, match="overwrite"): mne_make_scalp_surfaces.run() # actually check the outputs head_py = read_bem_surfaces(dense_fname) assert_equal(len(head_py), 1) head_py = head_py[0] - head_c = read_bem_surfaces(op.join(subjects_dir, 'sample', 'bem', - 'sample-head-dense.fif'))[0] - assert_allclose(head_py['rr'], head_c['rr']) + head_c = read_bem_surfaces( + op.join(subjects_dir, "sample", "bem", "sample-head-dense.fif") + )[0] + assert_allclose(head_py["rr"], head_c["rr"]) if not has: - assert 'SUBJECTS_DIR' not in os.environ + assert "SUBJECTS_DIR" not in os.environ def test_maxfilter(): """Test mne maxfilter.""" check_usage(mne_maxfilter) - with ArgvSetter(('-i', raw_fname, '--st', '--movecomp', '--linefreq', '60', - '--trans', raw_fname)) as out: + with ArgvSetter( + ( + "-i", + raw_fname, + "--st", + "--movecomp", + "--linefreq", + "60", + "--trans", + raw_fname, + ) + ) as out: with pytest.warns(RuntimeWarning, match="Don't use"): - os.environ['_MNE_MAXFILTER_TEST'] = 'true' + os.environ["_MNE_MAXFILTER_TEST"] = "true" try: mne_maxfilter.run() finally: - del os.environ['_MNE_MAXFILTER_TEST'] + del os.environ["_MNE_MAXFILTER_TEST"] out = out.stdout.getvalue() - for check in ('maxfilter', '-trans', '-movecomp'): + for check in ("maxfilter", "-trans", "-movecomp"): assert check in out, check @@ -197,16 +234,29 @@ def test_maxfilter(): @testing.requires_testing_data def test_report(tmp_path): """Test mne report.""" - pytest.importorskip('nibabel') 
+ pytest.importorskip("nibabel") check_usage(mne_report) tempdir = str(tmp_path) use_fname = op.join(tempdir, op.basename(raw_fname)) shutil.copyfile(raw_fname, use_fname) - with ArgvSetter(('-p', tempdir, '-i', use_fname, '-d', subjects_dir, - '-s', 'sample', '--no-browser', '-m', '30')): + with ArgvSetter( + ( + "-p", + tempdir, + "-i", + use_fname, + "-d", + subjects_dir, + "-s", + "sample", + "--no-browser", + "-m", + "30", + ) + ): with _record_warnings(): # contour levels mne_report.run() - fnames = glob.glob(op.join(tempdir, '*.html')) + fnames = glob.glob(op.join(tempdir, "*.html")) assert len(fnames) == 1 @@ -218,48 +268,48 @@ def test_surf2bem(): @pytest.mark.timeout(900) # took ~400 s on a local test @pytest.mark.slowtest @pytest.mark.ultraslowtest -@requires_freesurfer('mri_watershed') +@requires_freesurfer("mri_watershed") @testing.requires_testing_data def test_watershed_bem(tmp_path): """Test mne watershed bem.""" - pytest.importorskip('nibabel') + pytest.importorskip("nibabel") check_usage(mne_watershed_bem) # from T1.mgz Mdc = np.array([[-1, 0, 0], [0, 0, -1], [0, 1, 0]]) Pxyz_c = np.array([-5.273613, 9.039085, -27.287964]) # Copy necessary files to tempdir tempdir = str(tmp_path) - mridata_path = op.join(subjects_dir, 'sample', 'mri') - subject_path_new = op.join(tempdir, 'sample') - mridata_path_new = op.join(subject_path_new, 'mri') + mridata_path = op.join(subjects_dir, "sample", "mri") + subject_path_new = op.join(tempdir, "sample") + mridata_path_new = op.join(subject_path_new, "mri") os.makedirs(mridata_path_new) - new_fname = op.join(mridata_path_new, 'T1.mgz') - shutil.copyfile(op.join(mridata_path, 'T1.mgz'), new_fname) + new_fname = op.join(mridata_path_new, "T1.mgz") + shutil.copyfile(op.join(mridata_path, "T1.mgz"), new_fname) old_mode = os.stat(new_fname).st_mode os.chmod(new_fname, 0) - args = ('-d', tempdir, '-s', 'sample', '-o') - with pytest.raises(PermissionError, match=r'read permissions.*T1\.mgz'): + args = ("-d", tempdir, "-s", "sample", "-o") + with pytest.raises(PermissionError, match=r"read permissions.*T1\.mgz"): with ArgvSetter(args): mne_watershed_bem.run() os.chmod(new_fname, old_mode) - for s in ('outer_skin', 'outer_skull', 'inner_skull'): - assert not op.isfile(op.join(subject_path_new, 'bem', '%s.surf' % s)) + for s in ("outer_skin", "outer_skull", "inner_skull"): + assert not op.isfile(op.join(subject_path_new, "bem", "%s.surf" % s)) with ArgvSetter(args): mne_watershed_bem.run() kwargs = dict(rtol=1e-5, atol=1e-5) - for s in ('outer_skin', 'outer_skull', 'inner_skull'): - rr, tris, vol_info = read_surface(op.join(subject_path_new, 'bem', - '%s.surf' % s), - read_metadata=True) + for s in ("outer_skin", "outer_skull", "inner_skull"): + rr, tris, vol_info = read_surface( + op.join(subject_path_new, "bem", "%s.surf" % s), read_metadata=True + ) assert_equal(len(tris), 20480) assert_equal(tris.min(), 0) assert_equal(rr.shape[0], tris.max() + 1) # compare the volume info to the mgz header - assert_allclose(vol_info['xras'], Mdc[0], **kwargs) - assert_allclose(vol_info['yras'], Mdc[1], **kwargs) - assert_allclose(vol_info['zras'], Mdc[2], **kwargs) - assert_allclose(vol_info['cras'], Pxyz_c, **kwargs) + assert_allclose(vol_info["xras"], Mdc[0], **kwargs) + assert_allclose(vol_info["yras"], Mdc[1], **kwargs) + assert_allclose(vol_info["zras"], Mdc[2], **kwargs) + assert_allclose(vol_info["cras"], Pxyz_c, **kwargs) @pytest.mark.timeout(180) # took ~70 s locally @@ -272,33 +322,38 @@ def test_flash_bem(tmp_path): check_usage(mne_flash_bem, 
force_help=True) # Copy necessary files to tempdir tempdir = Path(str(tmp_path)) - mridata_path = Path(subjects_dir) / 'sample' / 'mri' - subject_path_new = tempdir / 'sample' - mridata_path_new = subject_path_new / 'mri' - flash_path = mridata_path_new / 'flash' + mridata_path = Path(subjects_dir) / "sample" / "mri" + subject_path_new = tempdir / "sample" + mridata_path_new = subject_path_new / "mri" + flash_path = mridata_path_new / "flash" flash_path.mkdir(parents=True, exist_ok=True) - bem_path = mridata_path_new / 'bem' + bem_path = mridata_path_new / "bem" bem_path.mkdir(parents=True, exist_ok=True) - shutil.copyfile(op.join(mridata_path, 'T1.mgz'), - op.join(mridata_path_new, 'T1.mgz')) - shutil.copyfile(op.join(mridata_path, 'brain.mgz'), - op.join(mridata_path_new, 'brain.mgz')) + shutil.copyfile( + op.join(mridata_path, "T1.mgz"), op.join(mridata_path_new, "T1.mgz") + ) + shutil.copyfile( + op.join(mridata_path, "brain.mgz"), op.join(mridata_path_new, "brain.mgz") + ) # Copy the available mri/flash/mef*.mgz files from the dataset for kind in (5, 30): - in_fname = mridata_path / "flash" / f'mef{kind:02d}.mgz' - in_fname_echo = flash_path / f'mef{kind:02d}_001.mgz' + in_fname = mridata_path / "flash" / f"mef{kind:02d}.mgz" + in_fname_echo = flash_path / f"mef{kind:02d}_001.mgz" shutil.copyfile(in_fname, flash_path / in_fname_echo.name) # Test mne flash_bem with --noconvert option # (since there are no DICOM Flash images in dataset) - for s in ('outer_skin', 'outer_skull', 'inner_skull'): - assert not op.isfile(subject_path_new / 'bem' / f'{s}.surf') + for s in ("outer_skin", "outer_skull", "inner_skull"): + assert not op.isfile(subject_path_new / "bem" / f"{s}.surf") # First test without flash30 - with ArgvSetter(('-d', tempdir, '-s', 'sample', '-n', '-r', '-3'), - disable_stdout=False, disable_stderr=False): + with ArgvSetter( + ("-d", tempdir, "-s", "sample", "-n", "-r", "-3"), + disable_stdout=False, + disable_stderr=False, + ): mne_flash_bem.run() - for s in ('outer_skin', 'outer_skull', 'inner_skull'): - surf_path = subject_path_new / 'bem' / f'{s}.surf' + for s in ("outer_skin", "outer_skull", "inner_skull"): + surf_path = subject_path_new / "bem" / f"{s}.surf" assert surf_path.exists() surf_path.unlink() # cleanup shutil.rmtree(flash_path / "parameter_maps") # remove old files @@ -313,22 +368,33 @@ def test_flash_bem(tmp_path): # Test with flash5 and flash30 shutil.rmtree(flash_path) # first remove old files - with ArgvSetter(('-d', tempdir, '-s', 'sample', '-n', - '-3', str(mridata_path / "flash" / 'mef30.mgz'), - '-5', str(mridata_path / "flash" / 'mef05.mgz')), - disable_stdout=False, disable_stderr=False): + with ArgvSetter( + ( + "-d", + tempdir, + "-s", + "sample", + "-n", + "-3", + str(mridata_path / "flash" / "mef30.mgz"), + "-5", + str(mridata_path / "flash" / "mef05.mgz"), + ), + disable_stdout=False, + disable_stderr=False, + ): mne_flash_bem.run() kwargs = dict(rtol=1e-5, atol=1e-5) - for s in ('outer_skin', 'outer_skull', 'inner_skull'): - rr, tris = read_surface(op.join(subject_path_new, 'bem', - '%s.surf' % s)) + for s in ("outer_skin", "outer_skull", "inner_skull"): + rr, tris = read_surface(op.join(subject_path_new, "bem", "%s.surf" % s)) assert_equal(len(tris), 5120) assert_equal(tris.min(), 0) assert_equal(rr.shape[0], tris.max() + 1) # compare to the testing flash surfaces - rr_c, tris_c = read_surface(op.join(subjects_dir, 'sample', 'bem', - '%s.surf' % s)) + rr_c, tris_c = read_surface( + op.join(subjects_dir, "sample", "bem", "%s.surf" % s) + ) 
assert_allclose(rr, rr_c, **kwargs) assert_allclose(tris, tris_c, **kwargs) @@ -336,29 +402,80 @@ def test_flash_bem(tmp_path): @testing.requires_testing_data def test_setup_source_space(tmp_path): """Test mne setup_source_space.""" - pytest.importorskip('nibabel') + pytest.importorskip("nibabel") check_usage(mne_setup_source_space, force_help=True) # Using the sample dataset use_fname = op.join(tmp_path, "sources-src.fif") # Test command - with ArgvSetter(('--src', use_fname, '-d', subjects_dir, - '-s', 'sample', '--morph', 'sample', - '--add-dist', 'False', '--ico', '3', '--verbose')): + with ArgvSetter( + ( + "--src", + use_fname, + "-d", + subjects_dir, + "-s", + "sample", + "--morph", + "sample", + "--add-dist", + "False", + "--ico", + "3", + "--verbose", + ) + ): mne_setup_source_space.run() src = read_source_spaces(use_fname) assert len(src) == 2 with pytest.raises(Exception): - with ArgvSetter(('--src', use_fname, '-d', subjects_dir, - '-s', 'sample', '--ico', '3', '--oct', '3')): + with ArgvSetter( + ( + "--src", + use_fname, + "-d", + subjects_dir, + "-s", + "sample", + "--ico", + "3", + "--oct", + "3", + ) + ): assert mne_setup_source_space.run() with pytest.raises(Exception): - with ArgvSetter(('--src', use_fname, '-d', subjects_dir, - '-s', 'sample', '--ico', '3', '--spacing', '10')): + with ArgvSetter( + ( + "--src", + use_fname, + "-d", + subjects_dir, + "-s", + "sample", + "--ico", + "3", + "--spacing", + "10", + ) + ): assert mne_setup_source_space.run() with pytest.raises(Exception): - with ArgvSetter(('--src', use_fname, '-d', subjects_dir, - '-s', 'sample', '--ico', '3', '--spacing', '10', - '--oct', '3')): + with ArgvSetter( + ( + "--src", + use_fname, + "-d", + subjects_dir, + "-s", + "sample", + "--ico", + "3", + "--spacing", + "10", + "--oct", + "3", + ) + ): assert mne_setup_source_space.run() @@ -366,17 +483,29 @@ def test_setup_source_space(tmp_path): @testing.requires_testing_data def test_setup_forward_model(tmp_path): """Test mne setup_forward_model.""" - pytest.importorskip('nibabel') + pytest.importorskip("nibabel") check_usage(mne_setup_forward_model, force_help=True) # Using the sample dataset use_fname = op.join(tmp_path, "model-bem.fif") # Test command - with ArgvSetter(('--model', use_fname, '-d', subjects_dir, '--homog', - '-s', 'sample', '--ico', '3', '--verbose')): + with ArgvSetter( + ( + "--model", + use_fname, + "-d", + subjects_dir, + "--homog", + "-s", + "sample", + "--ico", + "3", + "--verbose", + ) + ): mne_setup_forward_model.run() model = read_bem_surfaces(use_fname) assert len(model) == 1 - sol_fname = op.splitext(use_fname)[0] + '-sol.fif' + sol_fname = op.splitext(use_fname)[0] + "-sol.fif" read_bem_solution(sol_fname) @@ -388,8 +517,9 @@ def test_mne_prepare_bem_model(tmp_path): # Using the sample dataset bem_solution_fname = op.join(tmp_path, "bem_solution-bem-sol.fif") # Test command - with ArgvSetter(('--bem', bem_model_fname, '--sol', bem_solution_fname, - '--verbose')): + with ArgvSetter( + ("--bem", bem_model_fname, "--sol", bem_solution_fname, "--verbose") + ): mne_prepare_bem_model.run() bem_solution = read_bem_solution(bem_solution_fname) assert isinstance(bem_solution, ConductorModel) @@ -406,19 +536,19 @@ def test_sys_info(): """Test mne show_info.""" check_usage(mne_sys_info, force_help=True) with ArgvSetter((raw_fname,)): - with pytest.raises(SystemExit, match='1'): + with pytest.raises(SystemExit, match="1"): mne_sys_info.run() with ArgvSetter() as out: mne_sys_info.run() - assert 'numpy' in out.stdout.getvalue() + assert 
"numpy" in out.stdout.getvalue() def test_anonymize(tmp_path): """Test mne anonymize.""" check_usage(mne_anonymize) - out_fname = op.join(tmp_path, 'anon_test_raw.fif') - with ArgvSetter(('-f', raw_fname, '-o', out_fname)): + out_fname = op.join(tmp_path, "anon_test_raw.fif") + with ArgvSetter(("-f", raw_fname, "-o", out_fname)): mne_anonymize.run() info = read_info(out_fname) assert op.exists(out_fname) - assert info['meas_date'] == _stamp_to_dt((946684800, 0)) + assert info["meas_date"] == _stamp_to_dt((946684800, 0)) diff --git a/mne/commands/utils.py b/mne/commands/utils.py index 415f513cad1..80d04ab1729 100644 --- a/mne/commands/utils.py +++ b/mne/commands/utils.py @@ -16,9 +16,13 @@ def _add_verbose_flag(parser): - parser.add_option("--verbose", dest='verbose', - help="Enable verbose mode (printing of log messages).", - default=None, action="store_true") + parser.add_option( + "--verbose", + dest="verbose", + help="Enable verbose mode (printing of log messages).", + default=None, + action="store_true", + ) def load_module(name, path): @@ -38,31 +42,32 @@ def load_module(name, path): """ from importlib.util import spec_from_file_location, module_from_spec + spec = spec_from_file_location(name, path) mod = module_from_spec(spec) spec.loader.exec_module(mod) return mod -def get_optparser(cmdpath, usage=None, prog_prefix='mne', version=None): +def get_optparser(cmdpath, usage=None, prog_prefix="mne", version=None): """Create OptionParser with cmd specific settings (e.g., prog value).""" # Fetch description - mod = load_module('__temp', cmdpath) + mod = load_module("__temp", cmdpath) if mod.__doc__: doc, description, epilog = mod.__doc__, None, None - doc_lines = doc.split('\n') + doc_lines = doc.split("\n") description = doc_lines[0] if len(doc_lines) > 1: - epilog = '\n'.join(doc_lines[1:]) + epilog = "\n".join(doc_lines[1:]) # Get the name of the command command = os.path.basename(cmdpath) command, _ = os.path.splitext(command) - command = command[len(prog_prefix) + 1:] # +1 is for `_` character + command = command[len(prog_prefix) + 1 :] # +1 is for `_` character # Set prog - prog = prog_prefix + ' {}'.format(command) + prog = prog_prefix + " {}".format(command) # Set version if version is None: @@ -70,10 +75,9 @@ def get_optparser(cmdpath, usage=None, prog_prefix='mne', version=None): # monkey patch OptionParser to not wrap epilog OptionParser.format_epilog = lambda self, formatter: self.epilog - parser = OptionParser(prog=prog, - version=version, - description=description, - epilog=epilog, usage=usage) + parser = OptionParser( + prog=prog, version=version, description=description, epilog=epilog, usage=usage + ) return parser @@ -81,8 +85,7 @@ def get_optparser(cmdpath, usage=None, prog_prefix='mne', version=None): def main(): """Entrypoint for mne usage.""" mne_bin_dir = op.dirname(op.dirname(__file__)) - valid_commands = sorted(glob.glob(op.join(mne_bin_dir, - 'commands', 'mne_*.py'))) + valid_commands = sorted(glob.glob(op.join(mne_bin_dir, "commands", "mne_*.py"))) valid_commands = [c.split(op.sep)[-1][4:-3] for c in valid_commands] def print_help(): # noqa @@ -102,6 +105,6 @@ def print_help(): # noqa print_help() else: cmd = sys.argv[1] - cmd = importlib.import_module('.mne_%s' % (cmd,), 'mne.commands') + cmd = importlib.import_module(".mne_%s" % (cmd,), "mne.commands") sys.argv = sys.argv[1:] cmd.run() diff --git a/mne/conftest.py b/mne/conftest.py index 72e95b6e788..c99d2eeebe1 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -26,81 +26,93 @@ from mne.fixes import 
has_numba, _compare_version from mne.io import read_raw_fif, read_raw_ctf, read_raw_nirx, read_raw_snirf from mne.stats import cluster_level -from mne.utils import (_pl, _assert_no_instances, numerics, Bunch, - _check_qt_version, _TempDir, check_version) +from mne.utils import ( + _pl, + _assert_no_instances, + numerics, + Bunch, + _check_qt_version, + _TempDir, + check_version, +) # data from sample dataset from mne.viz._figure import use_browser_backend from mne.viz.backends._utils import _init_mne_qtapp test_path = testing.data_path(download=False) -s_path = op.join(test_path, 'MEG', 'sample') -fname_evoked = op.join(s_path, 'sample_audvis_trunc-ave.fif') -fname_cov = op.join(s_path, 'sample_audvis_trunc-cov.fif') -fname_fwd = op.join(s_path, 'sample_audvis_trunc-meg-eeg-oct-4-fwd.fif') -fname_fwd_full = op.join(s_path, 'sample_audvis_trunc-meg-eeg-oct-6-fwd.fif') -bem_path = op.join(test_path, 'subjects', 'sample', 'bem') -fname_bem = op.join(bem_path, 'sample-1280-bem.fif') -fname_aseg = op.join(test_path, 'subjects', 'sample', 'mri', 'aseg.mgz') -subjects_dir = op.join(test_path, 'subjects') -fname_src = op.join(bem_path, 'sample-oct-4-src.fif') -fname_trans = op.join(s_path, 'sample_audvis_trunc-trans.fif') - -ctf_dir = op.join(test_path, 'CTF') -fname_ctf_continuous = op.join(ctf_dir, 'testdata_ctf.ds') - -nirx_path = test_path / 'NIRx' -snirf_path = test_path / 'SNIRF' -nirsport2 = nirx_path / 'nirsport_v2' / 'aurora_recording _w_short_and_acc' -nirsport2_snirf = ( - snirf_path / 'NIRx' / 'NIRSport2' / '1.0.3' / - '2021-05-05_001.snirf') -nirsport2_2021_9 = nirx_path / 'nirsport_v2' / 'aurora_2021_9' +s_path = op.join(test_path, "MEG", "sample") +fname_evoked = op.join(s_path, "sample_audvis_trunc-ave.fif") +fname_cov = op.join(s_path, "sample_audvis_trunc-cov.fif") +fname_fwd = op.join(s_path, "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif") +fname_fwd_full = op.join(s_path, "sample_audvis_trunc-meg-eeg-oct-6-fwd.fif") +bem_path = op.join(test_path, "subjects", "sample", "bem") +fname_bem = op.join(bem_path, "sample-1280-bem.fif") +fname_aseg = op.join(test_path, "subjects", "sample", "mri", "aseg.mgz") +subjects_dir = op.join(test_path, "subjects") +fname_src = op.join(bem_path, "sample-oct-4-src.fif") +fname_trans = op.join(s_path, "sample_audvis_trunc-trans.fif") + +ctf_dir = op.join(test_path, "CTF") +fname_ctf_continuous = op.join(ctf_dir, "testdata_ctf.ds") + +nirx_path = test_path / "NIRx" +snirf_path = test_path / "SNIRF" +nirsport2 = nirx_path / "nirsport_v2" / "aurora_recording _w_short_and_acc" +nirsport2_snirf = snirf_path / "NIRx" / "NIRSport2" / "1.0.3" / "2021-05-05_001.snirf" +nirsport2_2021_9 = nirx_path / "nirsport_v2" / "aurora_2021_9" nirsport2_20219_snirf = ( - snirf_path / 'NIRx' / 'NIRSport2' / '2021.9' / - '2021-10-01_002.snirf') + snirf_path / "NIRx" / "NIRSport2" / "2021.9" / "2021-10-01_002.snirf" +) # data from mne.io.tests.data -base_dir = op.join(op.dirname(__file__), 'io', 'tests', 'data') -fname_raw_io = op.join(base_dir, 'test_raw.fif') -fname_event_io = op.join(base_dir, 'test-eve.fif') -fname_cov_io = op.join(base_dir, 'test-cov.fif') -fname_evoked_io = op.join(base_dir, 'test-ave.fif') +base_dir = op.join(op.dirname(__file__), "io", "tests", "data") +fname_raw_io = op.join(base_dir, "test_raw.fif") +fname_event_io = op.join(base_dir, "test-eve.fif") +fname_cov_io = op.join(base_dir, "test-cov.fif") +fname_evoked_io = op.join(base_dir, "test-ave.fif") event_id, tmin, tmax = 1, -0.1, 1.0 -vv_layout = read_layout('Vectorview-all') +vv_layout = 
read_layout("Vectorview-all") -collect_ignore = [ - 'export/_brainvision.py', - 'export/_eeglab.py', - 'export/_edf.py'] +collect_ignore = ["export/_brainvision.py", "export/_eeglab.py", "export/_edf.py"] def pytest_configure(config): """Configure pytest options.""" # Markers - for marker in ('slowtest', 'ultraslowtest', 'pgtest', 'allow_unclosed', - 'allow_unclosed_pyside2'): - config.addinivalue_line('markers', marker) + for marker in ( + "slowtest", + "ultraslowtest", + "pgtest", + "allow_unclosed", + "allow_unclosed_pyside2", + ): + config.addinivalue_line("markers", marker) # Fixtures - for fixture in ('matplotlib_config', 'close_all', 'check_verbose', - 'qt_config', 'protect_config'): - config.addinivalue_line('usefixtures', fixture) + for fixture in ( + "matplotlib_config", + "close_all", + "check_verbose", + "qt_config", + "protect_config", + ): + config.addinivalue_line("usefixtures", fixture) # pytest-qt uses PYTEST_QT_API, but let's make it respect qtpy's QT_API # if present - if os.getenv('PYTEST_QT_API') is None and os.getenv('QT_API') is not None: - os.environ['PYTEST_QT_API'] = os.environ['QT_API'] + if os.getenv("PYTEST_QT_API") is None and os.getenv("QT_API") is not None: + os.environ["PYTEST_QT_API"] = os.environ["QT_API"] # Warnings # - Once SciPy updates not to have non-integer and non-tuple errors (1.2.0) # we should remove them from here. # - This list should also be considered alongside reset_warnings in # doc/conf.py. - if os.getenv('MNE_IGNORE_WARNINGS_IN_TESTS', '') != 'true': - first_kind = 'error' + if os.getenv("MNE_IGNORE_WARNINGS_IN_TESTS", "") != "true": + first_kind = "error" else: - first_kind = 'always' + first_kind = "always" warning_lines = f" {first_kind}::" warning_lines += r""" # matplotlib->traitlets (notebook) @@ -143,10 +155,10 @@ def pytest_configure(config): # h5py ignore:`product` is deprecated as of NumPy.*:DeprecationWarning """ # noqa: E501 - for warning_line in warning_lines.split('\n'): + for warning_line in warning_lines.split("\n"): warning_line = warning_line.strip() - if warning_line and not warning_line.startswith('#'): - config.addinivalue_line('filterwarnings', warning_line) + if warning_line and not warning_line.startswith("#"): + config.addinivalue_line("filterwarnings", warning_line) # Have to be careful with autouse=True, but this is just an int comparison @@ -160,9 +172,10 @@ def check_verbose(request): try: assert mne.utils.logger.level == starting_level except AssertionError: - pytest.fail('.'.join([request.module.__name__, - request.function.__name__]) + - ' modifies logger.level') + pytest.fail( + ".".join([request.module.__name__, request.function.__name__]) + + " modifies logger.level" + ) @pytest.fixture(autouse=True) @@ -170,8 +183,9 @@ def close_all(): """Close all matplotlib plots, regardless of test status.""" # This adds < 1 µS in local testing, and we have ~2500 tests, so ~2 ms max import matplotlib.pyplot as plt + yield - plt.close('all') + plt.close("all") @pytest.fixture(autouse=True) @@ -180,44 +194,46 @@ def add_mne(doctest_namespace): doctest_namespace["mne"] = mne -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def verbose_debug(): """Run a test with debug verbosity.""" - with mne.utils.use_log_level('debug'): + with mne.utils.use_log_level("debug"): yield -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def qt_config(): """Configure the Qt backend for viz tests.""" - os.environ['_MNE_BROWSER_NO_BLOCK'] = 'true' + os.environ["_MNE_BROWSER_NO_BLOCK"] = "true" 
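The pytest_configure hook above turns every non-empty, non-comment line of warning_lines into a filterwarnings ini entry. A self-contained sketch of that parsing step (the filter entries below are illustrative, and config stands in for pytest's Config object, which provides addinivalue_line):

    # Each stripped, non-comment line becomes one filterwarnings entry.
    example_warning_lines = """
        error::
        # third-party deprecation noise (illustrative entry)
        ignore:.*is deprecated.*:DeprecationWarning
    """

    def register_warning_filters(config, lines=example_warning_lines):
        for line in lines.split("\n"):
            line = line.strip()
            if line and not line.startswith("#"):
                config.addinivalue_line("filterwarnings", line)
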
-@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def matplotlib_config(): """Configure matplotlib for viz tests.""" import matplotlib from matplotlib import cbook + # Allow for easy interactive debugging with a call like: # # $ MNE_MPL_TESTING_BACKEND=Qt5Agg pytest mne/viz/tests/test_raw.py -k annotation -x --pdb # noqa: E501 # try: - want = os.environ['MNE_MPL_TESTING_BACKEND'] + want = os.environ["MNE_MPL_TESTING_BACKEND"] except KeyError: - want = 'agg' # don't pop up windows + want = "agg" # don't pop up windows with warnings.catch_warnings(record=True): # ignore warning - warnings.filterwarnings('ignore') + warnings.filterwarnings("ignore") matplotlib.use(want, force=True) import matplotlib.pyplot as plt + assert plt.get_backend() == want # overwrite some params that can horribly slow down tests that # users might have changed locally (but should not otherwise affect # functionality) plt.ioff() - plt.rcParams['figure.dpi'] = 100 + plt.rcParams["figure.dpi"] = 100 try: - plt.rcParams['figure.raise_window'] = False + plt.rcParams["figure.raise_window"] = False except KeyError: # MPL < 3.3 pass @@ -231,21 +247,22 @@ def __init__(self, exception_handler=None, signals=None): cbook.CallbackRegistry = CallbackRegistryReraise -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def azure_windows(): """Determine if running on Azure Windows.""" - return (os.getenv('AZURE_CI_WINDOWS', 'false').lower() == 'true' and - sys.platform.startswith('win')) + return os.getenv( + "AZURE_CI_WINDOWS", "false" + ).lower() == "true" and sys.platform.startswith("win") -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def raw_orig(): """Get raw data without any change to it from mne.io.tests.data.""" raw = read_raw_fif(fname_raw_io, preload=True) return raw -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def raw(): """ Get raw data and pick channels to reduce load for testing. @@ -254,21 +271,21 @@ def raw(): """ raw = read_raw_fif(fname_raw_io, preload=True) # Throws a warning about a changed unit. 
- with pytest.warns(RuntimeWarning, match='unit'): - raw.set_channel_types({raw.ch_names[0]: 'ias'}) + with pytest.warns(RuntimeWarning, match="unit"): + raw.set_channel_types({raw.ch_names[0]: "ias"}) raw.pick_channels(raw.ch_names[:9]) raw.info.normalize_proj() # Fix projectors after subselection return raw -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def raw_ctf(): """Get ctf raw data from mne.io.tests.data.""" raw_ctf = read_raw_ctf(fname_ctf_continuous, preload=True) return raw_ctf -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def events(): """Get events from mne.io.tests.data.""" return read_events(fname_event_io) @@ -278,13 +295,22 @@ def _get_epochs(stop=5, meg=True, eeg=False, n_chan=20): """Get epochs.""" raw = read_raw_fif(fname_raw_io) events = read_events(fname_event_io) - picks = pick_types(raw.info, meg=meg, eeg=eeg, stim=False, - ecg=False, eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=meg, eeg=eeg, stim=False, ecg=False, eog=False, exclude="bads" + ) # Use a subset of channels for plotting speed picks = np.round(np.linspace(0, len(picks) + 1, n_chan)).astype(int) - with pytest.warns(RuntimeWarning, match='projection'): - epochs = Epochs(raw, events[:stop], event_id, tmin, tmax, picks=picks, - proj=False, preload=False) + with pytest.warns(RuntimeWarning, match="projection"): + epochs = Epochs( + raw, + events[:stop], + event_id, + tmin, + tmax, + picks=picks, + proj=False, + preload=False, + ) epochs.info.normalize_proj() # avoid warnings return epochs @@ -311,12 +337,13 @@ def epochs_full(): return _get_epochs(None).load_data() -@pytest.fixture(scope='session', params=[testing._pytest_param()]) +@pytest.fixture(scope="session", params=[testing._pytest_param()]) def _evoked(): # This one is session scoped, so be sure not to modify it (use evoked # instead) - evoked = mne.read_evokeds(fname_evoked, condition='Left Auditory', - baseline=(None, 0)) + evoked = mne.read_evokeds( + fname_evoked, condition="Left Auditory", baseline=(None, 0) + ) evoked.crop(0, 0.2) return evoked @@ -327,7 +354,7 @@ def evoked(_evoked): return _evoked.copy() -@pytest.fixture(scope='function', params=[testing._pytest_param()]) +@pytest.fixture(scope="function", params=[testing._pytest_param()]) def noise_cov(): """Get a noise cov from the testing dataset.""" return mne.read_cov(fname_cov) @@ -339,45 +366,44 @@ def noise_cov_io(): return mne.read_cov(fname_cov_io) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def bias_params_free(evoked, noise_cov): """Provide inputs for free bias functions.""" fwd = mne.read_forward_solution(fname_fwd) return _bias_params(evoked, noise_cov, fwd) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def bias_params_fixed(evoked, noise_cov): """Provide inputs for fixed bias functions.""" fwd = mne.read_forward_solution(fname_fwd) - mne.convert_forward_solution( - fwd, force_fixed=True, surf_ori=True, copy=False) + mne.convert_forward_solution(fwd, force_fixed=True, surf_ori=True, copy=False) return _bias_params(evoked, noise_cov, fwd) def _bias_params(evoked, noise_cov, fwd): evoked.pick_types(meg=True, eeg=True, exclude=()) # restrict to limited set of verts (small src here) and one hemi for speed - vertices = [fwd['src'][0]['vertno'].copy(), []] + vertices = [fwd["src"][0]["vertno"].copy(), []] stc = mne.SourceEstimate( - np.zeros((sum(len(v) for v in vertices), 1)), vertices, 0, 1) + np.zeros((sum(len(v) for v in vertices), 1)), vertices, 0, 1 + ) fwd = 
mne.forward.restrict_forward_to_stc(fwd, stc)
-    assert fwd['sol']['row_names'] == noise_cov['names']
-    assert noise_cov['names'] == evoked.ch_names
-    evoked = mne.EvokedArray(fwd['sol']['data'].copy(), evoked.info)
+    assert fwd["sol"]["row_names"] == noise_cov["names"]
+    assert noise_cov["names"] == evoked.ch_names
+    evoked = mne.EvokedArray(fwd["sol"]["data"].copy(), evoked.info)
     data_cov = noise_cov.copy()
-    data = fwd['sol']['data'] @ fwd['sol']['data'].T
+    data = fwd["sol"]["data"] @ fwd["sol"]["data"].T
     data *= 1e-14  # 100 nAm at each source, effectively (1e-18 would be 1 nAm)
     # This is rank-deficient, so let's make it actually positive semidefinite
     # by regularizing a tiny bit
-    data.flat[::data.shape[0] + 1] += mne.make_ad_hoc_cov(evoked.info)['data']
+    data.flat[:: data.shape[0] + 1] += mne.make_ad_hoc_cov(evoked.info)["data"]
     # Do our projection
-    proj, _, _ = mne.io.proj.make_projector(
-        data_cov['projs'], data_cov['names'])
+    proj, _, _ = mne.io.proj.make_projector(data_cov["projs"], data_cov["names"])
     data = proj @ data @ proj.T
-    data_cov['data'][:] = data
-    assert data_cov['data'].shape[0] == len(noise_cov['names'])
-    want = np.arange(fwd['sol']['data'].shape[1])
+    data_cov["data"][:] = data
+    assert data_cov["data"].shape[0] == len(noise_cov["names"])
+    want = np.arange(fwd["sol"]["data"].shape[1])
     if not mne.forward.is_fixed_orient(fwd):
         want //= 3
     return evoked, fwd, noise_cov, data_cov, want
@@ -393,42 +419,42 @@ def garbage_collect():

 @pytest.fixture
 def mpl_backend(garbage_collect):
     """Use for epochs/ica when not implemented with pyqtgraph yet."""
-    with use_browser_backend('matplotlib') as backend:
+    with use_browser_backend("matplotlib") as backend:
         yield backend
         backend._close_all()


 # Skip functions or modules for mne-qt-browser < 0.2.0
-pre_2_0_skip_modules = ['mne.viz.tests.test_epochs',
-                        'mne.viz.tests.test_ica']
-pre_2_0_skip_funcs = ['test_plot_raw_white',
-                      'test_plot_raw_selection']
+pre_2_0_skip_modules = ["mne.viz.tests.test_epochs", "mne.viz.tests.test_ica"]
+pre_2_0_skip_funcs = ["test_plot_raw_white", "test_plot_raw_selection"]


 def _check_pyqtgraph(request):
     # Check Qt
     qt_version, api = _check_qt_version(return_api=True)
-    if (not qt_version) or _compare_version(qt_version, '<', '5.12'):
-        pytest.skip(f'Qt API {api} has version {qt_version} '
-                    f'but pyqtgraph needs >= 5.12!')
+    if (not qt_version) or _compare_version(qt_version, "<", "5.12"):
+        pytest.skip(
+            f"Qt API {api} has version {qt_version} " f"but pyqtgraph needs >= 5.12!"
+        )
     try:
         import mne_qt_browser  # noqa: F401
+
         # Check mne-qt-browser version
-        lower_2_0 = _compare_version(mne_qt_browser.__version__, '<', '0.2.0')
+        lower_2_0 = _compare_version(mne_qt_browser.__version__, "<", "0.2.0")
         m_name = request.function.__module__
         f_name = request.function.__name__
         if lower_2_0 and m_name in pre_2_0_skip_modules:
-            pytest.skip(f'Test-Module "{m_name}" was skipped for'
-                        f' mne-qt-browser < 0.2.0')
+            pytest.skip(
+                f'Test-Module "{m_name}" was skipped for' f" mne-qt-browser < 0.2.0"
+            )
         elif lower_2_0 and f_name in pre_2_0_skip_funcs:
-            pytest.skip(f'Test "{f_name}" was skipped for '
-                        f'mne-qt-browser < 0.2.0')
+            pytest.skip(f'Test "{f_name}" was skipped for ' f"mne-qt-browser < 0.2.0")
     except Exception:
-        pytest.skip('Requires mne_qt_browser')
+        pytest.skip("Requires mne_qt_browser")
     else:
         ver = mne_qt_browser.__version__
-        if api != 'PyQt5' and _compare_version(ver, '<=', '0.2.6'):
-            pytest.skip(f'mne_qt_browser {ver} requires PyQt5, API is {api}')
+        if api != "PyQt5" and _compare_version(ver, "<=", "0.2.6"):
+            pytest.skip(f"mne_qt_browser {ver} requires PyQt5, API is {api}")


 @pytest.fixture
@@ -436,35 +462,39 @@ def pg_backend(request, garbage_collect):
     """Use for pyqtgraph-specific test-functions."""
     _check_pyqtgraph(request)
     from mne_qt_browser._pg_figure import MNEQtBrowser
-    with use_browser_backend('qt') as backend:
+
+    with use_browser_backend("qt") as backend:
         backend._close_all()
         yield backend
         backend._close_all()
         # This shouldn't be necessary, but let's make sure nothing is stale
         import mne_qt_browser
+
         mne_qt_browser._browser_instances.clear()
-        if check_version('mne_qt_browser', min_version='0.4'):
-            _assert_no_instances(
-                MNEQtBrowser, f'Closure of {request.node.name}')
+        if check_version("mne_qt_browser", min_version="0.4"):
+            _assert_no_instances(MNEQtBrowser, f"Closure of {request.node.name}")


-@pytest.fixture(params=[
-    'matplotlib',
-    pytest.param('qt', marks=pytest.mark.pgtest),
-])
+@pytest.fixture(
+    params=[
+        "matplotlib",
+        pytest.param("qt", marks=pytest.mark.pgtest),
+    ]
+)
 def browser_backend(request, garbage_collect, monkeypatch):
     """Parametrizes the name of the browser backend."""
     backend_name = request.param
-    if backend_name == 'qt':
+    if backend_name == "qt":
         _check_pyqtgraph(request)
     with use_browser_backend(backend_name) as backend:
         backend._close_all()
-        monkeypatch.setenv('MNE_BROWSE_RAW_SIZE', '10,10')
+        monkeypatch.setenv("MNE_BROWSE_RAW_SIZE", "10,10")
         yield backend
         backend._close_all()
-        if backend_name == 'qt':
+        if backend_name == "qt":
             # This shouldn't be necessary, but let's make sure nothing is stale
             import mne_qt_browser
+
             mne_qt_browser._browser_instances.clear()
@@ -506,9 +536,11 @@ def renderer_interactive(request, options_3d):
 @contextmanager
 def _use_backend(backend_name, interactive):
     from mne.viz.backends.renderer import _use_test_3d_backend
+
     _check_skip_backend(backend_name)
     with _use_test_3d_backend(backend_name, interactive=interactive):
         from mne.viz.backends import renderer
+
         try:
             yield renderer
         finally:
@@ -516,34 +548,39 @@ def _use_backend(backend_name, interactive):

 def _check_skip_backend(name):
-    from mne.viz.backends.tests._utils import (has_pyvista,
-                                               has_imageio_ffmpeg,
-                                               has_pyvistaqt)
+    from mne.viz.backends.tests._utils import (
+        has_pyvista,
+        has_imageio_ffmpeg,
+        has_pyvistaqt,
+    )
     from mne.viz.backends._utils import _notebook_vtk_works
+
     if not has_pyvista():
         pytest.skip("Test skipped, requires pyvista.")
     if not has_imageio_ffmpeg():
         pytest.skip("Test skipped, requires imageio-ffmpeg")
-    if name == 'pyvistaqt':
+    if name == "pyvistaqt":
         if not _check_qt_version():
             pytest.skip("Test skipped, requires Qt.")
         if not has_pyvistaqt():
             pytest.skip("Test skipped, requires pyvistaqt")
     else:
-        assert name == 'notebook', name
+        assert name == "notebook", name
         if not _notebook_vtk_works():
             pytest.skip("Test skipped, requires working notebook vtk")


-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def pixel_ratio():
     """Get the pixel ratio."""
     from mne.viz.backends.tests._utils import has_pyvista
+
     # _check_qt_version will init an app for us, so no need for us to do it
     if not has_pyvista() or not _check_qt_version():
-        return 1.
+        return 1.0
     from qtpy.QtWidgets import QMainWindow
     from qtpy.QtCore import Qt
+
     app = _init_mne_qtapp()
     app.processEvents()
     window = QMainWindow()
@@ -553,10 +590,10 @@ def pixel_ratio():
     return ratio


-@pytest.fixture(scope='function', params=[testing._pytest_param()])
+@pytest.fixture(scope="function", params=[testing._pytest_param()])
 def subjects_dir_tmp(tmp_path):
     """Copy MNE-testing-data subjects_dir to a temp dir for manipulation."""
-    for key in ('sample', 'fsaverage'):
+    for key in ("sample", "fsaverage"):
         shutil.copytree(op.join(subjects_dir, key), str(tmp_path / key))
     return str(tmp_path)
@@ -564,59 +601,64 @@ def subjects_dir_tmp(tmp_path):

 @pytest.fixture(params=[testing._pytest_param()])
 def subjects_dir_tmp_few(tmp_path):
     """Copy fewer files to a tmp_path."""
-    subjects_path = tmp_path / 'subjects'
+    subjects_path = tmp_path / "subjects"
     os.mkdir(subjects_path)
     # add fsaverage
-    create_default_subject(subjects_dir=subjects_path, fs_home=test_path,
-                           verbose=True)
+    create_default_subject(subjects_dir=subjects_path, fs_home=test_path, verbose=True)
     # add sample (with few files)
-    sample_path = subjects_path / 'sample'
-    os.makedirs(sample_path / 'bem')
-    for dirname in ('mri', 'surf'):
+    sample_path = subjects_path / "sample"
+    os.makedirs(sample_path / "bem")
+    for dirname in ("mri", "surf"):
         shutil.copytree(
-            test_path / 'subjects' / 'sample' / dirname, sample_path / dirname)
+            test_path / "subjects" / "sample" / dirname, sample_path / dirname
+        )
     return subjects_path


 # Scoping these as session will make things faster, but need to make sure
 # not to modify them in-place in the tests, so keep them private
-@pytest.fixture(scope='session', params=[testing._pytest_param()])
+@pytest.fixture(scope="session", params=[testing._pytest_param()])
 def _evoked_cov_sphere(_evoked):
     """Compute a small evoked/cov/sphere combo for use with forwards."""
     evoked = _evoked.copy().pick_types(meg=True)
     evoked.pick_channels(evoked.ch_names[::4])
     assert len(evoked.ch_names) == 77
     cov = mne.read_cov(fname_cov)
-    sphere = mne.make_sphere_model('auto', 'auto', evoked.info)
+    sphere = mne.make_sphere_model("auto", "auto", evoked.info)
     return evoked, cov, sphere


-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def _fwd_surf(_evoked_cov_sphere):
     """Compute a forward for a surface source space."""
     evoked, cov, sphere = _evoked_cov_sphere
     src_surf = mne.read_source_spaces(fname_src)
     return mne.make_forward_solution(
-        evoked.info, fname_trans, src_surf, sphere, mindist=5.0)
+        evoked.info, fname_trans, src_surf, sphere, mindist=5.0
+    )


-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def _fwd_subvolume(_evoked_cov_sphere):
     """Compute a forward for a volume source space."""
-    pytest.importorskip('nibabel')
+    pytest.importorskip("nibabel")
     evoked, cov, sphere = _evoked_cov_sphere
-    volume_labels = ['Left-Cerebellum-Cortex', 'right-Cerebellum-Cortex']
-    with pytest.raises(ValueError,
-                       match=r"Did you mean one of \['Right-Cere"):
+    volume_labels = ["Left-Cerebellum-Cortex", "right-Cerebellum-Cortex"]
+    with pytest.raises(ValueError, match=r"Did you mean one of \['Right-Cere"):
         mne.setup_volume_source_space(
-            'sample', pos=20., volume_label=volume_labels,
-            subjects_dir=subjects_dir)
-    volume_labels[1] = 'R' + volume_labels[1][1:]
+            "sample", pos=20.0, volume_label=volume_labels, subjects_dir=subjects_dir
+        )
+    volume_labels[1] = "R" + volume_labels[1][1:]
     src_vol = mne.setup_volume_source_space(
-        'sample', pos=20., volume_label=volume_labels,
-        subjects_dir=subjects_dir, add_interpolator=False)
+        "sample",
+        pos=20.0,
+        volume_label=volume_labels,
+        subjects_dir=subjects_dir,
+        add_interpolator=False,
+    )
     return mne.make_forward_solution(
-        evoked.info, fname_trans, src_vol, sphere, mindist=5.0)
+        evoked.info, fname_trans, src_vol, sphere, mindist=5.0
+    )


 @pytest.fixture
@@ -625,52 +667,50 @@ def fwd_volume_small(_fwd_subvolume):
     return _fwd_subvolume.copy()


-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def _all_src_types_fwd(_fwd_surf, _fwd_subvolume):
     """Create all three forward types (surf, vol, mixed)."""
-    fwds = dict(
-        surface=_fwd_surf.copy(),
-        volume=_fwd_subvolume.copy())
-    with pytest.raises(RuntimeError,
-                       match='Invalid source space with kinds'):
-        fwds['volume']['src'] + fwds['surface']['src']
+    fwds = dict(surface=_fwd_surf.copy(), volume=_fwd_subvolume.copy())
+    with pytest.raises(RuntimeError, match="Invalid source space with kinds"):
+        fwds["volume"]["src"] + fwds["surface"]["src"]

     # mixed (4)
-    fwd = fwds['surface'].copy()
-    f2 = fwds['volume'].copy()
+    fwd = fwds["surface"].copy()
+    f2 = fwds["volume"].copy()
     del _fwd_surf, _fwd_subvolume
-    for keys, axis in [(('source_rr',), 0),
-                       (('source_nn',), 0),
-                       (('sol', 'data'), 1),
-                       (('_orig_sol',), 1)]:
+    for keys, axis in [
+        (("source_rr",), 0),
+        (("source_nn",), 0),
+        (("sol", "data"), 1),
+        (("_orig_sol",), 1),
+    ]:
         a, b = fwd, f2
         key = keys[0]
         if len(keys) > 1:
             a, b = a[key], b[key]
             key = keys[1]
         a[key] = np.concatenate([a[key], b[key]], axis=axis)
-    fwd['sol']['ncol'] = fwd['sol']['data'].shape[1]
-    fwd['nsource'] = fwd['sol']['ncol'] // 3
-    fwd['src'] = fwd['src'] + f2['src']
-    fwds['mixed'] = fwd
+    fwd["sol"]["ncol"] = fwd["sol"]["data"].shape[1]
+    fwd["nsource"] = fwd["sol"]["ncol"] // 3
+    fwd["src"] = fwd["src"] + f2["src"]
+    fwds["mixed"] = fwd
     return fwds


-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def _all_src_types_inv_evoked(_evoked_cov_sphere, _all_src_types_fwd):
     """Compute inverses for all source types."""
     evoked, cov, _ = _evoked_cov_sphere
     invs = dict()
     for kind, fwd in _all_src_types_fwd.items():
-        assert fwd['src'].kind == kind
-        with pytest.warns(RuntimeWarning, match='has been reduced'):
-            invs[kind] = mne.minimum_norm.make_inverse_operator(
-                evoked.info, fwd, cov)
+        assert fwd["src"].kind == kind
+        with pytest.warns(RuntimeWarning, match="has been reduced"):
+            invs[kind] = mne.minimum_norm.make_inverse_operator(evoked.info, fwd, cov)
     return invs, evoked


-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
 def all_src_types_inv_evoked(_all_src_types_inv_evoked):
     """All source types of inverses, allowing for possible modification."""
     invs, evoked = _all_src_types_inv_evoked
@@ -679,42 +719,48 @@ def all_src_types_inv_evoked(_all_src_types_inv_evoked):
     return invs, evoked

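A note on the `_all_src_types_fwd` fixture above: it splices the surface and volume forwards by walking (key-path, axis) pairs and concatenating the arrays found at each nested dict location. A minimal sketch of that pattern with toy arrays (names and shapes here are illustrative, not real MNE forward contents):

    import numpy as np

    a = {"source_rr": np.zeros((4, 3)), "sol": {"data": np.ones((2, 12))}}
    b = {"source_rr": np.ones((2, 3)), "sol": {"data": np.zeros((2, 6))}}
    # walk one or two keys deep, then concatenate along the stated axis
    for keys, axis in [(("source_rr",), 0), (("sol", "data"), 1)]:
        tgt, src = a, b
        key = keys[0]
        if len(keys) > 1:
            tgt, src = tgt[key], src[key]
            key = keys[1]
        tgt[key] = np.concatenate([tgt[key], src[key]], axis=axis)
    assert a["source_rr"].shape == (6, 3)
    assert a["sol"]["data"].shape == (2, 18)

Source positions stack along rows (axis 0) while gain matrices stack along columns (axis 1), which is why each key path carries its own axis.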
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
 def mixed_fwd_cov_evoked(_evoked_cov_sphere, _all_src_types_fwd):
     """Compute a mixed-source-space forward with matching cov and evoked."""
     evoked, cov, _ = _evoked_cov_sphere
-    return _all_src_types_fwd['mixed'].copy(), cov.copy(), evoked.copy()
+    return _all_src_types_fwd["mixed"].copy(), cov.copy(), evoked.copy()


-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 @pytest.mark.slowtest
 @pytest.mark.parametrize(params=[testing._pytest_param()])
 def src_volume_labels():
     """Create a 7mm source space with labels."""
-    pytest.importorskip('nibabel')
+    pytest.importorskip("nibabel")
     volume_labels = mne.get_volume_labels_from_aseg(fname_aseg)
-    with pytest.warns(RuntimeWarning, match='Found no usable.*Left-vessel.*'):
+    with pytest.warns(RuntimeWarning, match="Found no usable.*Left-vessel.*"):
         src = mne.setup_volume_source_space(
-            'sample', 7., mri='aseg.mgz', volume_label=volume_labels,
-            add_interpolator=False, bem=fname_bem,
-            subjects_dir=subjects_dir)
+            "sample",
+            7.0,
+            mri="aseg.mgz",
+            volume_label=volume_labels,
+            add_interpolator=False,
+            bem=fname_bem,
+            subjects_dir=subjects_dir,
+        )
     lut, _ = mne.read_freesurfer_lut()
     assert len(volume_labels) == 46
-    assert volume_labels[0] == 'Unknown'
-    assert lut['Unknown'] == 0  # it will be excluded during label gen
+    assert volume_labels[0] == "Unknown"
+    assert lut["Unknown"] == 0  # it will be excluded during label gen
     return src, tuple(volume_labels), lut


 def _fail(*args, **kwargs):
     __tracebackhide__ = True
-    raise AssertionError('Test should not download')
+    raise AssertionError("Test should not download")


-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
 def download_is_error(monkeypatch):
     """Prevent downloading by raising an error when it's attempted."""
     import pooch
-    monkeypatch.setattr(pooch, 'retrieve', _fail)
+
+    monkeypatch.setattr(pooch, "retrieve", _fail)
     yield
@@ -722,14 +768,14 @@ def fake_retrieve(monkeypatch, download_is_error):
     """Monkeypatch pooch.retrieve to avoid downloading (just touch files)."""
     import pooch
+
     my_func = _FakeFetch()
-    monkeypatch.setattr(pooch, 'retrieve', my_func)
-    monkeypatch.setattr(pooch, 'create', my_func)
+    monkeypatch.setattr(pooch, "retrieve", my_func)
+    monkeypatch.setattr(pooch, "create", my_func)
     yield my_func


 class _FakeFetch:
-
     def __init__(self):
         self.call_args_list = list()
@@ -739,15 +785,15 @@ def call_count(self):

     # Wrapper for pooch.retrieve(...) and pooch.create(...)
     def __call__(self, *args, **kwargs):
-        assert 'path' in kwargs
-        if 'fname' in kwargs:  # pooch.retrieve(...)
+        assert "path" in kwargs
+        if "fname" in kwargs:  # pooch.retrieve(...)
             self.call_args_list.append((args, kwargs))
-            path = Path(kwargs['path'], kwargs['fname'])
+            path = Path(kwargs["path"], kwargs["fname"])
             path.parent.mkdir(parents=True, exist_ok=True)
-            path.write_text('test')
+            path.write_text("test")
             return path
         else:  # pooch.create(...) has been called
-            self.path = kwargs['path']
+            self.path = kwargs["path"]
             return self

     # Wrappers for Pooch instances (e.g., in eegbci we pooch.create)
@@ -761,20 +807,21 @@ def load_registry(self, registry):

 # We can't use monkeypatch because its scope (function-level) conflicts with
 # the requests fixture (module-level), so we live with a module-scoped version
 # that uses mock
-@pytest.fixture(scope='module')
+@pytest.fixture(scope="module")
 def options_3d():
     """Disable advanced 3d rendering."""
     with mock.patch.dict(
-        os.environ, {
+        os.environ,
+        {
             "MNE_3D_OPTION_ANTIALIAS": "false",
             "MNE_3D_OPTION_DEPTH_PEELING": "false",
             "MNE_3D_OPTION_SMOOTH_SHADING": "false",
-        }
+        },
     ):
         yield


-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def protect_config():
     """Protect ~/.mne."""
     temp = _TempDir()
@@ -786,23 +833,23 @@ def protect_config():
 def brain_gc(request):
     """Ensure that brain can be properly garbage collected."""
     keys = (
-        'renderer_interactive',
-        'renderer_interactive_pyvistaqt',
-        'renderer',
-        'renderer_pyvistaqt',
-        'renderer_notebook',
+        "renderer_interactive",
+        "renderer_interactive_pyvistaqt",
+        "renderer",
+        "renderer_pyvistaqt",
+        "renderer_notebook",
     )
     assert set(request.fixturenames) & set(keys) != set()
     for key in keys:
         if key in request.fixturenames:
-            is_pv = \
-                request.getfixturevalue(key)._get_3d_backend() == 'pyvistaqt'
+            is_pv = request.getfixturevalue(key)._get_3d_backend() == "pyvistaqt"
             close_func = request.getfixturevalue(key).backend._close_all
             break
     if not is_pv:
         yield
         return
     from mne.viz import Brain
+
     ignore = set(id(o) for o in gc.get_objects())
     yield
     close_func()
@@ -810,10 +857,10 @@ def brain_gc(request):
     try:
         outcome = request.node.harvest_rep_call
     except Exception:
-        outcome = 'failed'
-    if outcome != 'passed':
+        outcome = "failed"
+    if outcome != "passed":
         return
-    _assert_no_instances(Brain, 'after')
+    _assert_no_instances(Brain, "after")
     # Check VTK
     objs = gc.get_objects()
     bad = list()
@@ -823,11 +870,11 @@ def brain_gc(request):
         except Exception:  # old Python, probably
             pass
         else:
-            if name.startswith('vtk') and id(o) not in ignore:
+            if name.startswith("vtk") and id(o) not in ignore:
                 bad.append(name)
         del o
     del objs, ignore, Brain
-    assert len(bad) == 0, 'VTK objects linger:\n' + '\n'.join(bad)
+    assert len(bad) == 0, "VTK objects linger:\n" + "\n".join(bad)


 _files = list()
@@ -838,26 +885,26 @@ def pytest_sessionfinish(session, exitstatus):
     n = session.config.option.durations
     if n is None:
         return
-    print('\n')
+    print("\n")
     try:
         import pytest_harvest
     except ImportError:
-        print('Module-level timings require pytest-harvest')
+        print("Module-level timings require pytest-harvest")
         return
     # get the number to print
     res = pytest_harvest.get_session_synthesis_dct(session)
     files = dict()
     for key, val in res.items():
-        parts = Path(key.split(':')[0]).parts
+        parts = Path(key.split(":")[0]).parts
         # split mne/tests/test_whatever.py into separate categories since these
         # are essentially submodule-level tests. Keeping just [:3] works,
         # except for mne/viz where we want level-4 granularity
-        split_submodules = (('mne', 'viz'), ('mne', 'preprocessing'))
-        parts = parts[:4 if parts[:2] in split_submodules else 3]
-        if not parts[-1].endswith('.py'):
-            parts = parts + ('',)
-        file_key = '/'.join(parts)
-        files[file_key] = files.get(file_key, 0) + val['pytest_duration_s']
+        split_submodules = (("mne", "viz"), ("mne", "preprocessing"))
+        parts = parts[: 4 if parts[:2] in split_submodules else 3]
+        if not parts[-1].endswith(".py"):
+            parts = parts + ("",)
+        file_key = "/".join(parts)
+        files[file_key] = files.get(file_key, 0) + val["pytest_duration_s"]
     files = sorted(list(files.items()), key=lambda x: x[1])[::-1]
     # print
     _files[:] = files[:n]
@@ -868,36 +915,38 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
     writer = terminalreporter
     n = len(_files)
     if n:
-        writer.line('')  # newline
-        writer.write_sep('=', f'slowest {n} test module{_pl(n)}')
+        writer.line("")  # newline
+        writer.write_sep("=", f"slowest {n} test module{_pl(n)}")
         names, timings = zip(*_files)
-        timings = [f'{timing:0.2f}s total' for timing in timings]
+        timings = [f"{timing:0.2f}s total" for timing in timings]
         rjust = max(len(timing) for timing in timings)
         timings = [timing.rjust(rjust) for timing in timings]
         for name, timing in zip(names, timings):
-            writer.line(f'{timing.ljust(15)}{name}')
+            writer.line(f"{timing.ljust(15)}{name}")


-@pytest.fixture(scope="function", params=('Numba', 'NumPy'))
+@pytest.fixture(scope="function", params=("Numba", "NumPy"))
 def numba_conditional(monkeypatch, request):
     """Test both code paths on machines that have Numba."""
-    assert request.param in ('Numba', 'NumPy')
-    if request.param == 'NumPy' and has_numba:
+    assert request.param in ("Numba", "NumPy")
+    if request.param == "NumPy" and has_numba:
         monkeypatch.setattr(
-            cluster_level, '_get_buddies', cluster_level._get_buddies_fallback)
+            cluster_level, "_get_buddies", cluster_level._get_buddies_fallback
+        )
         monkeypatch.setattr(
-            cluster_level, '_get_selves', cluster_level._get_selves_fallback)
+            cluster_level, "_get_selves", cluster_level._get_selves_fallback
+        )
         monkeypatch.setattr(
-            cluster_level, '_where_first', cluster_level._where_first_fallback)
-        monkeypatch.setattr(
-            numerics, '_arange_div', numerics._arange_div_fallback)
-    if request.param == 'Numba' and not has_numba:
-        pytest.skip('Numba not installed')
+            cluster_level, "_where_first", cluster_level._where_first_fallback
+        )
+        monkeypatch.setattr(numerics, "_arange_div", numerics._arange_div_fallback)
+    if request.param == "Numba" and not has_numba:
+        pytest.skip("Numba not installed")
     yield request.param


 # Create one nbclient and reuse it
-@pytest.fixture(scope='session')
+@pytest.fixture(scope="session")
 def _nbclient():
     try:
         import nbformat
@@ -906,9 +955,10 @@ def _nbclient():
         from jupyter_client import AsyncKernelManager
         from nbclient import NotebookClient
         from ipywidgets import Button  # noqa
         import ipyvtklink  # noqa
     except Exception as exc:
-        return pytest.skip(f'Skipping Notebook test: {exc}')
+        return pytest.skip(f"Skipping Notebook test: {exc}")
     km = AsyncKernelManager(config=None)
-    nb = nbformat.reads("""
+    nb = nbformat.reads(
+        """
 {
   "cells": [
     {
@@ -934,7 +984,9 @@ def _nbclient():
   },
   "nbformat": 4,
   "nbformat_minor": 4
-}""", as_version=4)
+}""",
+        as_version=4,
+    )
     client = NotebookClient(nb, km=km)
     yield client
     try:
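The `download_is_error` and `fake_retrieve` fixtures earlier in this file share one trick: `monkeypatch.setattr(pooch, "retrieve", ...)` swaps the downloader for a callable that only records the call and touches the target file. A self-contained sketch of the same idea outside pytest (assuming only that `pooch.retrieve(..., path=..., fname=...)` returns the downloaded path, as the fixtures above do):

    from pathlib import Path
    import pooch

    class FakeRetrieve:
        """Record calls and create a placeholder file instead of downloading."""

        def __init__(self):
            self.call_args_list = []

        def __call__(self, *args, **kwargs):
            self.call_args_list.append((args, kwargs))
            path = Path(kwargs["path"], kwargs["fname"])
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text("test")  # stand-in content, no network access
            return path

    fake = FakeRetrieve()
    original, pooch.retrieve = pooch.retrieve, fake
    # ... exercise download-dependent code, then restore:
    pooch.retrieve = original

`monkeypatch.setattr` does the save-and-restore automatically at test teardown, which is why the fixtures need none of the manual bookkeeping shown here.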
@@ -943,7 +995,7 @@ def _nbclient():
         pass


-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
 def nbexec(_nbclient):
     """Execute Python code in a notebook."""
     # Adapted/simplified from nbclient/client.py (BSD-3-Clause)
@@ -953,7 +1005,7 @@ def nbexec(_nbclient):
     def execute(code, reset=False):
         _nbclient.reset_execution_trackers()
         with _nbclient.setup_kernel():
             assert _nbclient.kc is not None
-            cell = Bunch(cell_type='code', metadata={}, source=dedent(code))
+            cell = Bunch(cell_type="code", metadata={}, source=dedent(code))
             _nbclient.execute_cell(cell, 0, execution_count=0)
             _nbclient.set_widgets_metadata()
@@ -962,15 +1014,15 @@ def nbexec(_nbclient):

 def pytest_runtest_call(item):
     """Run notebook code written in Python."""
-    if 'nbexec' in getattr(item, 'fixturenames', ()):
-        nbexec = item.funcargs['nbexec']
-        code = inspect.getsource(getattr(item.module, item.name.split('[')[0]))
+    if "nbexec" in getattr(item, "fixturenames", ()):
+        nbexec = item.funcargs["nbexec"]
+        code = inspect.getsource(getattr(item.module, item.name.split("[")[0]))
         code = code.splitlines()
         ci = 0
         for ci, c in enumerate(code):
-            if c.startswith('    '):  # actual content
+            if c.startswith("    "):  # actual content
                 break
-        code = '\n'.join(code[ci:])
+        code = "\n".join(code[ci:])

         def run(nbexec=nbexec, code=code):
             nbexec(code)
@@ -979,27 +1031,32 @@ def run(nbexec=nbexec, code=code):
         return


-@pytest.mark.filterwarnings('ignore:.*Extraction of measurement.*:')
-@pytest.fixture(params=(
-    [nirsport2, nirsport2_snirf, testing._pytest_param()],
-    [nirsport2_2021_9, nirsport2_20219_snirf, testing._pytest_param()],
-))
+@pytest.mark.filterwarnings("ignore:.*Extraction of measurement.*:")
+@pytest.fixture(
+    params=(
+        [nirsport2, nirsport2_snirf, testing._pytest_param()],
+        [nirsport2_2021_9, nirsport2_20219_snirf, testing._pytest_param()],
+    )
+)
 def nirx_snirf(request):
     """Return a (raw_nirx, raw_snirf) matched pair."""
-    pytest.importorskip('h5py')
+    pytest.importorskip("h5py")
     skipper = request.param[2].marks[0].mark
     if skipper.args[0]:  # will skip
-        pytest.skip(skipper.kwargs['reason'])
-    return (read_raw_nirx(request.param[0], preload=True),
-            read_raw_snirf(request.param[1], preload=True))
+        pytest.skip(skipper.kwargs["reason"])
+    return (
+        read_raw_nirx(request.param[0], preload=True),
+        read_raw_snirf(request.param[1], preload=True),
+    )


 @pytest.fixture
 def qt_windows_closed(request):
     """Ensure that no new Qt windows are open after a test."""
-    _check_skip_backend('pyvistaqt')
+    _check_skip_backend("pyvistaqt")
     app = _init_mne_qtapp()
     from qtpy import API_NAME
+
     app.processEvents()
     gc.collect()
     n_before = len(app.topLevelWidgets())
@@ -1007,9 +1064,9 @@ def qt_windows_closed(request):
     yield
     app.processEvents()
     gc.collect()
-    if 'allow_unclosed' in marks:
+    if "allow_unclosed" in marks:
         return
-    if 'allow_unclosed_pyside2' in marks and API_NAME.lower() == 'pyside2':
+    if "allow_unclosed_pyside2" in marks and API_NAME.lower() == "pyside2":
         return
     # Don't check when the test fails
     report = request.node.stash[_phase_report_key]
diff --git a/mne/coreg.py b/mne/coreg.py
index 3e21f3ff917..fc9a7a9753f 100644
--- a/mne/coreg.py
+++ b/mne/coreg.py
@@ -21,6 +21,7 @@
 from .io.constants import FIFF
 from .io.meas_info import Info
 from .io._digitization import _get_data_as_dict_from_dig
+
 # keep get_mni_fiducials for backward compat (no burden to keep in this
 # namespace, too)
 from ._freesurfer import (
@@ -34,52 +35,82 @@ from ._freesurfer import (
     read_source_spaces,  # noqa: F401
     write_source_spaces,
 )
-from .surface import (read_surface, write_surface, _normalize_vectors,
-                      complete_surface_info, decimate_surface,
-                      _DistanceQuery)
+from .surface import (
+    read_surface,
+    write_surface,
+    _normalize_vectors,
+    complete_surface_info,
+    decimate_surface,
+    _DistanceQuery,
+)
 from .bem import read_bem_surfaces, write_bem_surfaces
-from .transforms import (rotation, rotation3d, scaling, translation, Transform,
-                         _read_fs_xfm, _write_fs_xfm, invert_transform,
-                         combine_transforms, _quat_to_euler,
-                         _fit_matched_points, apply_trans,
-                         rot_to_quat, _angle_between_quats)
+from .transforms import (
+    rotation,
+    rotation3d,
+    scaling,
+    translation,
+    Transform,
+    _read_fs_xfm,
+    _write_fs_xfm,
+    invert_transform,
+    combine_transforms,
+    _quat_to_euler,
+    _fit_matched_points,
+    apply_trans,
+    rot_to_quat,
+    _angle_between_quats,
+)
 from .channels import make_dig_montage
-from .utils import (get_config, get_subjects_dir, logger, pformat, verbose,
-                    warn, fill_doc, _validate_type,
-                    _check_subject, _check_option, _import_nibabel)
+from .utils import (
+    get_config,
+    get_subjects_dir,
+    logger,
+    pformat,
+    verbose,
+    warn,
+    fill_doc,
+    _validate_type,
+    _check_subject,
+    _check_option,
+    _import_nibabel,
+)
 from .viz._3d import _fiducial_coords

 # some path templates
-trans_fname = os.path.join('{raw_dir}', '{subject}-trans.fif')
-subject_dirname = os.path.join('{subjects_dir}', '{subject}')
-bem_dirname = os.path.join(subject_dirname, 'bem')
-mri_dirname = os.path.join(subject_dirname, 'mri')
-mri_transforms_dirname = os.path.join(subject_dirname, 'mri', 'transforms')
-surf_dirname = os.path.join(subject_dirname, 'surf')
+trans_fname = os.path.join("{raw_dir}", "{subject}-trans.fif")
+subject_dirname = os.path.join("{subjects_dir}", "{subject}")
+bem_dirname = os.path.join(subject_dirname, "bem")
+mri_dirname = os.path.join(subject_dirname, "mri")
+mri_transforms_dirname = os.path.join(subject_dirname, "mri", "transforms")
+surf_dirname = os.path.join(subject_dirname, "surf")
 bem_fname = os.path.join(bem_dirname, "{subject}-{name}.fif")
-head_bem_fname = pformat(bem_fname, name='head')
-head_sparse_fname = pformat(bem_fname, name='head-sparse')
-fid_fname = pformat(bem_fname, name='fiducials')
+head_bem_fname = pformat(bem_fname, name="head")
+head_sparse_fname = pformat(bem_fname, name="head-sparse")
+fid_fname = pformat(bem_fname, name="fiducials")
 fid_fname_general = os.path.join(bem_dirname, "{head}-fiducials.fif")
-src_fname = os.path.join(bem_dirname, '{subject}-{spacing}-src.fif')
-_head_fnames = (os.path.join(bem_dirname, 'outer_skin.surf'),
-                head_sparse_fname,
-                head_bem_fname)
-_high_res_head_fnames = (os.path.join(bem_dirname, '{subject}-head-dense.fif'),
-                         os.path.join(surf_dirname, 'lh.seghead'),
-                         os.path.join(surf_dirname, 'lh.smseghead'))
+src_fname = os.path.join(bem_dirname, "{subject}-{spacing}-src.fif")
+_head_fnames = (
+    os.path.join(bem_dirname, "outer_skin.surf"),
+    head_sparse_fname,
+    head_bem_fname,
+)
+_high_res_head_fnames = (
+    os.path.join(bem_dirname, "{subject}-head-dense.fif"),
+    os.path.join(surf_dirname, "lh.seghead"),
+    os.path.join(surf_dirname, "lh.smseghead"),
+)


 def _map_fid_name_to_idx(name: str) -> int:
     """Map a fiducial name to its index in the DigMontage."""
     name = name.lower()
-    if name == 'lpa':
+    if name == "lpa":
         return 0
-    elif name == 'nasion':
+    elif name == "nasion":
         return 1
     else:
-        assert name == 'rpa'
+        assert name == "rpa"
         return 2
@@ -90,7 +121,7 @@ def _make_writable(fname):

 def _make_writable_recursive(path):
     """Recursively set writable."""
-    if sys.platform.startswith('win'):
+    if sys.platform.startswith("win"):
         return  # can't safely set perms
     for root, dirs, files in os.walk(path, topdown=False):
         for f in dirs + files:
@@ -132,21 +163,19 @@ def coregister_fiducials(info, fiducials, tol=0.01):
         fiducials, coord_frame_to = read_fiducials(fiducials)
     else:
         coord_frame_to = FIFF.FIFFV_COORD_MRI
-    frames_from = {d['coord_frame'] for d in info['dig']}
+    frames_from = {d["coord_frame"] for d in info["dig"]}
     if len(frames_from) > 1:
-        raise ValueError("info contains fiducials from different coordinate "
-                         "frames")
+        raise ValueError("info contains fiducials from different coordinate " "frames")
     else:
         coord_frame_from = frames_from.pop()
-    coords_from = _fiducial_coords(info['dig'])
+    coords_from = _fiducial_coords(info["dig"])
     coords_to = _fiducial_coords(fiducials, coord_frame_to)
     trans = fit_matched_points(coords_from, coords_to, tol=tol)
     return Transform(coord_frame_from, coord_frame_to, trans)


 @verbose
-def create_default_subject(fs_home=None, update=False, subjects_dir=None,
-                           verbose=None):
+def create_default_subject(fs_home=None, update=False, subjects_dir=None, verbose=None):
     """Create an average brain subject for subjects without structural MRI.

     Create a copy of fsaverage from the Freesurfer directory in subjects_dir
@@ -177,37 +206,43 @@ def create_default_subject(fs_home=None, update=False, subjects_dir=None,
     """
     subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True))
     if fs_home is None:
-        fs_home = get_config('FREESURFER_HOME', fs_home)
+        fs_home = get_config("FREESURFER_HOME", fs_home)
         if fs_home is None:
             raise ValueError(
                 "FREESURFER_HOME environment variable not found. Please "
                 "specify the fs_home parameter in your call to "
-                "create_default_subject().")
+                "create_default_subject()."
+            )

     # make sure freesurfer files exist
-    fs_src = os.path.join(fs_home, 'subjects', 'fsaverage')
+    fs_src = os.path.join(fs_home, "subjects", "fsaverage")
     if not os.path.exists(fs_src):
-        raise OSError('fsaverage not found at %r. Is fs_home specified '
-                      'correctly?' % fs_src)
-    for name in ('label', 'mri', 'surf'):
+        raise OSError(
+            "fsaverage not found at %r. Is fs_home specified " "correctly?" % fs_src
+        )
+    for name in ("label", "mri", "surf"):
         dirname = os.path.join(fs_src, name)
         if not os.path.isdir(dirname):
-            raise OSError("Freesurfer fsaverage seems to be incomplete: No "
-                          "directory named %s found in %s" % (name, fs_src))
+            raise OSError(
+                "Freesurfer fsaverage seems to be incomplete: No "
+                "directory named %s found in %s" % (name, fs_src)
+            )

     # make sure destination does not already exist
-    dest = os.path.join(subjects_dir, 'fsaverage')
+    dest = os.path.join(subjects_dir, "fsaverage")
     if dest == fs_src:
         raise OSError(
             "Your subjects_dir points to the freesurfer subjects_dir (%r). "
             "The default subject can not be created in the freesurfer "
             "installation directory; please specify a different "
-            "subjects_dir." % subjects_dir)
+            "subjects_dir." % subjects_dir
+        )
     elif (not update) and os.path.exists(dest):
         raise OSError(
             "Can not create fsaverage because %r already exists in "
             "subjects_dir %r. Delete or rename the existing fsaverage "
-            "subject folder." % ('fsaverage', subjects_dir))
+            "subject folder." % ("fsaverage", subjects_dir)
+        )

     # copy fsaverage from freesurfer
     logger.info("Copying fsaverage subject from freesurfer directory...")
@@ -216,15 +251,16 @@ def create_default_subject(fs_home=None, update=False, subjects_dir=None,
         _make_writable_recursive(dest)

     # copy files from mne
-    source_fname = os.path.join(os.path.dirname(__file__), 'data', 'fsaverage',
-                                'fsaverage-%s.fif')
-    dest_bem = os.path.join(dest, 'bem')
+    source_fname = os.path.join(
+        os.path.dirname(__file__), "data", "fsaverage", "fsaverage-%s.fif"
+    )
+    dest_bem = os.path.join(dest, "bem")
     if not os.path.exists(dest_bem):
         os.mkdir(dest_bem)
     logger.info("Copying auxiliary fsaverage files from mne...")
-    dest_fname = os.path.join(dest_bem, 'fsaverage-%s.fif')
+    dest_fname = os.path.join(dest_bem, "fsaverage-%s.fif")
     _make_writable_recursive(dest_bem)
-    for name in ('fiducials', 'head', 'inner_skull-bem', 'trans'):
+    for name in ("fiducials", "head", "inner_skull-bem", "trans"):
         if not os.path.exists(dest_fname % name):
             shutil.copy(source_fname % name, dest_bem)
@@ -249,10 +285,11 @@ def _decimate_points(pts, res=10):
         The decimated points.
     """
     from scipy.spatial.distance import cdist
+
     pts = np.asarray(pts)

     # find the bin edges for the voxel space
-    xmin, ymin, zmin = pts.min(0) - res / 2.
+    xmin, ymin, zmin = pts.min(0) - res / 2.0
     xmax, ymax, zmax = pts.max(0) + res
     xax = np.arange(xmin, xmax, res)
     yax = np.arange(ymin, ymax, res)
@@ -264,19 +301,18 @@ def _decimate_points(pts, res=10):
     x = xax[xbins]
     y = yax[ybins]
     z = zax[zbins]
-    mids = np.c_[x, y, z] + res / 2.
+    mids = np.c_[x, y, z] + res / 2.0

     # each point belongs to at most one voxel center, so figure those out
     # (cKDTree faster than BallTree for these small problems)
-    tree = _DistanceQuery(mids, method='cKDTree')
+    tree = _DistanceQuery(mids, method="cKDTree")
     _, mid_idx = tree.query(pts)

     # then figure out which to actually use based on proximity
     # (take advantage of sorting the mid_idx to get our mapping of
     # pts to nearest voxel midpoint)
     sort_idx = np.argsort(mid_idx)
-    bounds = np.cumsum(
-        np.concatenate([[0], np.bincount(mid_idx, minlength=len(mids))]))
+    bounds = np.cumsum(np.concatenate([[0], np.bincount(mid_idx, minlength=len(mids))]))
     assert len(bounds) == len(mids) + 1
     out = list()
     for mi, mid in enumerate(mids):
@@ -287,14 +323,13 @@ def _decimate_points(pts, res=10):
         # But it's faster for many points than making a big boolean indexer
         # over and over (esp. since each point can only belong to a single
         # voxel).
-        use_pts = pts[sort_idx[bounds[mi]:bounds[mi + 1]]]
+        use_pts = pts[sort_idx[bounds[mi] : bounds[mi + 1]]]
         if not len(use_pts):
             out.append([np.inf] * 3)
         else:
-            out.append(
-                use_pts[np.argmin(cdist(use_pts, mid[np.newaxis])[:, 0])])
+            out.append(use_pts[np.argmin(cdist(use_pts, mid[np.newaxis])[:, 0])])
     out = np.array(out, float).reshape(-1, 3)
-    out = out[np.abs(out - mids).max(axis=1) < res / 2.]
+    out = out[np.abs(out - mids).max(axis=1) < res / 2.0]
     # """
     return out
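The argsort/bincount bookkeeping in `_decimate_points` above deserves a gloss: after `mid_idx` assigns every point to its nearest voxel midpoint, sorting those indices makes the cumulative `bincount` act as slice bounds, so each voxel's members can be sliced out without building a boolean mask per voxel. A toy version on random points (grid construction simplified relative to the function above):

    import numpy as np
    from scipy.spatial import cKDTree

    rng = np.random.default_rng(0)
    pts = rng.uniform(0, 1, (100, 3))
    res = 0.25
    ax = np.arange(res / 2, 1, res)  # coarse grid of voxel centers
    mids = np.stack(np.meshgrid(ax, ax, ax, indexing="ij"), -1).reshape(-1, 3)
    _, mid_idx = cKDTree(mids).query(pts)  # nearest center per point
    order = np.argsort(mid_idx)
    counts = np.bincount(mid_idx, minlength=len(mids))
    bounds = np.cumsum(np.concatenate([[0], counts]))
    keep = []
    for mi in range(len(mids)):  # one representative per occupied voxel
        sel = order[bounds[mi]:bounds[mi + 1]]
        if len(sel):
            d = np.linalg.norm(pts[sel] - mids[mi], axis=1)
            keep.append(pts[sel[np.argmin(d)]])
    keep = np.asarray(keep)
    assert len(keep) <= len(mids)

This reproduces the selection step only; the real function also rejects representatives that fall outside their own voxel.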
+ out = out[np.abs(out - mids).max(axis=1) < res / 2.0] # """ return out @@ -312,7 +347,7 @@ def _trans_from_params(param_info, params): i += 3 if do_translate: - x, y, z = params[i:i + 3] + x, y, z = params[i : i + 3] trans.insert(0, translation(x, y, z)) i += 3 @@ -320,7 +355,7 @@ def _trans_from_params(param_info, params): s = params[i] trans.append(scaling(s, s, s)) elif do_scale == 3: - x, y, z = params[i:i + 3] + x, y, z = params[i : i + 3] trans.append(scaling(x, y, z)) trans = reduce(np.dot, trans) @@ -331,9 +366,17 @@ def _trans_from_params(param_info, params): # XXX this function should be moved out of coreg as used elsewhere -def fit_matched_points(src_pts, tgt_pts, rotate=True, translate=True, - scale=False, tol=None, x0=None, out='trans', - weights=None): +def fit_matched_points( + src_pts, + tgt_pts, + rotate=True, + translate=True, + scale=False, + tol=None, + x0=None, + out="trans", + weights=None, +): """Find a transform between matched sets of points. This minimizes the squared distance between two matching sets of points. @@ -378,13 +421,21 @@ def fit_matched_points(src_pts, tgt_pts, rotate=True, translate=True, src_pts = np.atleast_2d(src_pts) tgt_pts = np.atleast_2d(tgt_pts) if src_pts.shape != tgt_pts.shape: - raise ValueError("src_pts and tgt_pts must have same shape (got " - "{}, {})".format(src_pts.shape, tgt_pts.shape)) + raise ValueError( + "src_pts and tgt_pts must have same shape (got " + "{}, {})".format(src_pts.shape, tgt_pts.shape) + ) if weights is not None: weights = np.asarray(weights, src_pts.dtype) if weights.ndim != 1 or weights.size not in (src_pts.shape[0], 1): - raise ValueError("weights (shape=%s) must be None or have shape " - "(%s,)" % (weights.shape, src_pts.shape[0],)) + raise ValueError( + "weights (shape=%s) must be None or have shape " + "(%s,)" + % ( + weights.shape, + src_pts.shape[0], + ) + ) weights = weights[:, np.newaxis] param_info = (bool(rotate), bool(translate), int(scale)) @@ -397,15 +448,14 @@ def fit_matched_points(src_pts, tgt_pts, rotate=True, translate=True, tgt_pts = np.asarray(tgt_pts, float) if weights is not None: weights = np.asarray(weights, float) - x, s = _fit_matched_points( - src_pts, tgt_pts, weights, bool(param_info[2])) + x, s = _fit_matched_points(src_pts, tgt_pts, weights, bool(param_info[2])) x[:3] = _quat_to_euler(x[:3]) x = np.concatenate((x, [s])) if param_info[2] else x else: x = _generic_fit(src_pts, tgt_pts, param_info, weights, x0) # re-create the final transformation matrix - if (tol is not None) or (out == 'trans'): + if (tol is not None) or (out == "trans"): trans = _trans_from_params(param_info, x) # assess the error of the solution @@ -416,21 +466,24 @@ def fit_matched_points(src_pts, tgt_pts, rotate=True, translate=True, if np.any(err > tol): raise RuntimeError("Error exceeds tolerance. Error = %r" % err) - if out == 'params': + if out == "params": return x - elif out == 'trans': + elif out == "trans": return trans else: - raise ValueError("Invalid out parameter: %r. Needs to be 'params' or " - "'trans'." % out) + raise ValueError( + "Invalid out parameter: %r. Needs to be 'params' or " "'trans'." 
% out + ) def _generic_fit(src_pts, tgt_pts, param_info, weights, x0): from scipy.optimize import leastsq + if param_info[1]: # translate src_pts = np.hstack((src_pts, np.ones((len(src_pts), 1)))) if param_info == (True, False, 0): + def error(x): rx, ry, rz = x trans = rotation3d(rx, ry, rz) @@ -439,9 +492,11 @@ def error(x): if weights is not None: d *= weights return d.ravel() + if x0 is None: x0 = (0, 0, 0) elif param_info == (True, True, 0): + def error(x): rx, ry, rz, tx, ty, tz = x trans = np.dot(translation(tx, ty, tz), rotation(rx, ry, rz)) @@ -450,44 +505,52 @@ def error(x): if weights is not None: d *= weights return d.ravel() + if x0 is None: x0 = (0, 0, 0, 0, 0, 0) elif param_info == (True, True, 1): + def error(x): rx, ry, rz, tx, ty, tz, s = x - trans = reduce(np.dot, (translation(tx, ty, tz), - rotation(rx, ry, rz), - scaling(s, s, s))) + trans = reduce( + np.dot, + (translation(tx, ty, tz), rotation(rx, ry, rz), scaling(s, s, s)), + ) est = np.dot(src_pts, trans.T)[:, :3] d = tgt_pts - est if weights is not None: d *= weights return d.ravel() + if x0 is None: x0 = (0, 0, 0, 0, 0, 0, 1) elif param_info == (True, True, 3): + def error(x): rx, ry, rz, tx, ty, tz, sx, sy, sz = x - trans = reduce(np.dot, (translation(tx, ty, tz), - rotation(rx, ry, rz), - scaling(sx, sy, sz))) + trans = reduce( + np.dot, + (translation(tx, ty, tz), rotation(rx, ry, rz), scaling(sx, sy, sz)), + ) est = np.dot(src_pts, trans.T)[:, :3] d = tgt_pts - est if weights is not None: d *= weights return d.ravel() + if x0 is None: x0 = (0, 0, 0, 0, 0, 0, 1, 1, 1) else: raise NotImplementedError( "The specified parameter combination is not implemented: " - "rotate=%r, translate=%r, scale=%r" % param_info) + "rotate=%r, translate=%r, scale=%r" % param_info + ) x, _, _, _, _ = leastsq(error, x0, full_output=True) return x -def _find_label_paths(subject='fsaverage', pattern=None, subjects_dir=None): +def _find_label_paths(subject="fsaverage", pattern=None, subjects_dir=None): """Find paths to label files in a subject's label directory. 
Parameters @@ -515,7 +578,7 @@ def _find_label_paths(subject='fsaverage', pattern=None, subjects_dir=None): paths = [] for dirpath, _, filenames in os.walk(lbl_dir): rel_dir = os.path.relpath(dirpath, lbl_dir) - for filename in fnmatch.filter(filenames, '*.label'): + for filename in fnmatch.filter(filenames, "*.label"): path = os.path.join(rel_dir, filename) paths.append(path) else: @@ -548,41 +611,56 @@ def _find_mri_paths(subject, skip_fiducials, subjects_dir): paths = {} # directories to create - paths['dirs'] = [bem_dirname, surf_dirname] + paths["dirs"] = [bem_dirname, surf_dirname] # surf/ files - paths['surf'] = [] - surf_fname = os.path.join(surf_dirname, '{name}') - surf_names = ('inflated', 'white', 'orig', 'orig_avg', 'inflated_avg', - 'inflated_pre', 'pial', 'pial_avg', 'smoothwm', 'white_avg', - 'seghead', 'smseghead') - if os.getenv('_MNE_FEW_SURFACES', '') == 'true': # for testing + paths["surf"] = [] + surf_fname = os.path.join(surf_dirname, "{name}") + surf_names = ( + "inflated", + "white", + "orig", + "orig_avg", + "inflated_avg", + "inflated_pre", + "pial", + "pial_avg", + "smoothwm", + "white_avg", + "seghead", + "smseghead", + ) + if os.getenv("_MNE_FEW_SURFACES", "") == "true": # for testing surf_names = surf_names[:4] for surf_name in surf_names: - for hemi in ('lh.', 'rh.'): + for hemi in ("lh.", "rh."): name = hemi + surf_name - path = surf_fname.format(subjects_dir=subjects_dir, - subject=subject, name=name) + path = surf_fname.format( + subjects_dir=subjects_dir, subject=subject, name=name + ) if os.path.exists(path): - paths['surf'].append(pformat(surf_fname, name=name)) - surf_fname = os.path.join(bem_dirname, '{name}') - surf_names = ('inner_skull.surf', 'outer_skull.surf', 'outer_skin.surf') + paths["surf"].append(pformat(surf_fname, name=name)) + surf_fname = os.path.join(bem_dirname, "{name}") + surf_names = ("inner_skull.surf", "outer_skull.surf", "outer_skin.surf") for surf_name in surf_names: - path = surf_fname.format(subjects_dir=subjects_dir, - subject=subject, name=surf_name) + path = surf_fname.format( + subjects_dir=subjects_dir, subject=subject, name=surf_name + ) if os.path.exists(path): - paths['surf'].append(pformat(surf_fname, name=surf_name)) + paths["surf"].append(pformat(surf_fname, name=surf_name)) del surf_names, surf_name, path, hemi # BEM files - paths['bem'] = bem = [] + paths["bem"] = bem = [] path = head_bem_fname.format(subjects_dir=subjects_dir, subject=subject) if os.path.exists(path): - bem.append('head') - bem_pattern = pformat(bem_fname, subjects_dir=subjects_dir, - subject=subject, name='*-bem') - re_pattern = pformat(bem_fname, subjects_dir=subjects_dir, subject=subject, - name='(.+)').replace('\\', '\\\\') + bem.append("head") + bem_pattern = pformat( + bem_fname, subjects_dir=subjects_dir, subject=subject, name="*-bem" + ) + re_pattern = pformat( + bem_fname, subjects_dir=subjects_dir, subject=subject, name="(.+)" + ).replace("\\", "\\\\") for path in iglob(bem_pattern): match = re.match(re_pattern, path) name = match.group(1) @@ -591,54 +669,57 @@ def _find_mri_paths(subject, skip_fiducials, subjects_dir): # fiducials if skip_fiducials: - paths['fid'] = [] + paths["fid"] = [] else: - paths['fid'] = _find_fiducials_files(subject, subjects_dir) + paths["fid"] = _find_fiducials_files(subject, subjects_dir) # check that we found at least one - if len(paths['fid']) == 0: - raise OSError("No fiducials file found for %s. The fiducials " - "file should be named " - "{subject}/bem/{subject}-fiducials.fif. 
In " - "order to scale an MRI without fiducials set " - "skip_fiducials=True." % subject) + if len(paths["fid"]) == 0: + raise OSError( + "No fiducials file found for %s. The fiducials " + "file should be named " + "{subject}/bem/{subject}-fiducials.fif. In " + "order to scale an MRI without fiducials set " + "skip_fiducials=True." % subject + ) # duplicate files (curvature and some surfaces) - paths['duplicate'] = [] - path = os.path.join(surf_dirname, '{name}') - surf_fname = os.path.join(surf_dirname, '{name}') - surf_dup_names = ('curv', 'sphere', 'sphere.reg', 'sphere.reg.avg') + paths["duplicate"] = [] + path = os.path.join(surf_dirname, "{name}") + surf_fname = os.path.join(surf_dirname, "{name}") + surf_dup_names = ("curv", "sphere", "sphere.reg", "sphere.reg.avg") for surf_dup_name in surf_dup_names: - for hemi in ('lh.', 'rh.'): + for hemi in ("lh.", "rh."): name = hemi + surf_dup_name - path = surf_fname.format(subjects_dir=subjects_dir, - subject=subject, name=name) + path = surf_fname.format( + subjects_dir=subjects_dir, subject=subject, name=name + ) if os.path.exists(path): - paths['duplicate'].append(pformat(surf_fname, name=name)) + paths["duplicate"].append(pformat(surf_fname, name=name)) del surf_dup_name, name, path, hemi # transform files (talairach) - paths['transforms'] = [] - transform_fname = os.path.join(mri_transforms_dirname, 'talairach.xfm') + paths["transforms"] = [] + transform_fname = os.path.join(mri_transforms_dirname, "talairach.xfm") path = transform_fname.format(subjects_dir=subjects_dir, subject=subject) if os.path.exists(path): - paths['transforms'].append(transform_fname) + paths["transforms"].append(transform_fname) del transform_fname, path # find source space files - paths['src'] = src = [] + paths["src"] = src = [] bem_dir = bem_dirname.format(subjects_dir=subjects_dir, subject=subject) - fnames = fnmatch.filter(os.listdir(bem_dir), '*-src.fif') - prefix = subject + '-' + fnames = fnmatch.filter(os.listdir(bem_dir), "*-src.fif") + prefix = subject + "-" for fname in fnames: if fname.startswith(prefix): - fname = "{subject}-%s" % fname[len(prefix):] + fname = "{subject}-%s" % fname[len(prefix) :] path = os.path.join(bem_dirname, fname) src.append(path) # find MRIs mri_dir = mri_dirname.format(subjects_dir=subjects_dir, subject=subject) - fnames = fnmatch.filter(os.listdir(mri_dir), '*.mgz') - paths['mri'] = [os.path.join(mri_dir, f) for f in fnames] + fnames = fnmatch.filter(os.listdir(mri_dir), "*.mgz") + paths["mri"] = [os.path.join(mri_dir, f) for f in fnames] return paths @@ -647,17 +728,18 @@ def _find_fiducials_files(subject, subjects_dir): """Find fiducial files.""" fid = [] # standard fiducials - if os.path.exists(fid_fname.format(subjects_dir=subjects_dir, - subject=subject)): + if os.path.exists(fid_fname.format(subjects_dir=subjects_dir, subject=subject)): fid.append(fid_fname) # fiducials with subject name - pattern = pformat(fid_fname_general, subjects_dir=subjects_dir, - subject=subject, head='*') - regex = pformat(fid_fname_general, subjects_dir=subjects_dir, - subject=subject, head='(.+)').replace('\\', '\\\\') + pattern = pformat( + fid_fname_general, subjects_dir=subjects_dir, subject=subject, head="*" + ) + regex = pformat( + fid_fname_general, subjects_dir=subjects_dir, subject=subject, head="(.+)" + ).replace("\\", "\\\\") for path in iglob(pattern): match = re.match(regex, path) - head = match.group(1).replace(subject, '{subject}') + head = match.group(1).replace(subject, "{subject}") fid.append(pformat(fid_fname_general, 
head=head)) return fid @@ -678,8 +760,10 @@ def _is_mri_subject(subject, subjects_dir=None): Whether ``subject`` is an mri subject. """ subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True)) - return bool(_find_head_bem(subject, subjects_dir) or - _find_head_bem(subject, subjects_dir, high_res=True)) + return bool( + _find_head_bem(subject, subjects_dir) + or _find_head_bem(subject, subjects_dir, high_res=True) + ) def _is_scaled_mri_subject(subject, subjects_dir=None): @@ -720,8 +804,7 @@ def _mri_subject_has_bem(subject, subjects_dir=None): Whether ``subject`` has a bem file. """ subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True)) - pattern = bem_fname.format(subjects_dir=subjects_dir, subject=subject, - name='*-bem') + pattern = bem_fname.format(subjects_dir=subjects_dir, subject=subject, name="*-bem") fnames = glob(pattern) return bool(len(fnames)) @@ -745,23 +828,28 @@ def read_mri_cfg(subject, subjects_dir=None): fname = subjects_dir / subject / "MRI scaling parameters.cfg" if not fname.exists(): - raise OSError("%r does not seem to be a scaled mri subject: %r does " - "not exist." % (subject, fname)) + raise OSError( + "%r does not seem to be a scaled mri subject: %r does " + "not exist." % (subject, fname) + ) logger.info("Reading MRI cfg file %s" % fname) config = configparser.RawConfigParser() config.read(fname) - n_params = config.getint("MRI Scaling", 'n_params') + n_params = config.getint("MRI Scaling", "n_params") if n_params == 1: - scale = config.getfloat("MRI Scaling", 'scale') + scale = config.getfloat("MRI Scaling", "scale") elif n_params == 3: - scale_str = config.get("MRI Scaling", 'scale') + scale_str = config.get("MRI Scaling", "scale") scale = np.array([float(s) for s in scale_str.split()]) else: raise ValueError("Invalid n_params value in MRI cfg: %i" % n_params) - out = {'subject_from': config.get("MRI Scaling", 'subject_from'), - 'n_params': n_params, 'scale': scale} + out = { + "subject_from": config.get("MRI Scaling", "subject_from"), + "n_params": n_params, + "scale": scale, + } return out @@ -787,15 +875,15 @@ def _write_mri_config(fname, subject_from, subject_to, scale): config = configparser.RawConfigParser() config.add_section("MRI Scaling") - config.set("MRI Scaling", 'subject_from', subject_from) - config.set("MRI Scaling", 'subject_to', subject_to) - config.set("MRI Scaling", 'n_params', str(n_params)) + config.set("MRI Scaling", "subject_from", subject_from) + config.set("MRI Scaling", "subject_to", subject_to) + config.set("MRI Scaling", "n_params", str(n_params)) if n_params == 1: - config.set("MRI Scaling", 'scale', str(scale)) + config.set("MRI Scaling", "scale", str(scale)) else: - config.set("MRI Scaling", 'scale', ' '.join([str(s) for s in scale])) - config.set("MRI Scaling", 'version', '1') - with open(fname, 'w') as fid: + config.set("MRI Scaling", "scale", " ".join([str(s) for s in scale])) + config.set("MRI Scaling", "version", "1") + with open(fname, "w") as fid: config.write(fid) @@ -816,27 +904,38 @@ def _scale_params(subject_to, subject_from, scale, subjects_dir): """ subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) if (subject_from is None) != (scale is None): - raise TypeError("Need to provide either both subject_from and scale " - "parameters, or neither.") + raise TypeError( + "Need to provide either both subject_from and scale " + "parameters, or neither." 
+ ) if subject_from is None: cfg = read_mri_cfg(subject_to, subjects_dir) - subject_from = cfg['subject_from'] - n_params = cfg['n_params'] + subject_from = cfg["subject_from"] + n_params = cfg["n_params"] assert n_params in (1, 3) - scale = cfg['scale'] + scale = cfg["scale"] scale = np.atleast_1d(scale) if scale.ndim != 1 or scale.shape[0] not in (1, 3): - raise ValueError("Invalid shape for scale parameter. Need scalar " - "or array of length 3. Got shape %s." - % (scale.shape,)) + raise ValueError( + "Invalid shape for scale parameter. Need scalar " + "or array of length 3. Got shape %s." % (scale.shape,) + ) n_params = len(scale) return str(subjects_dir), subject_from, scale, n_params == 1 @verbose -def scale_bem(subject_to, bem_name, subject_from=None, scale=None, - subjects_dir=None, *, on_defects='raise', verbose=None): +def scale_bem( + subject_to, + bem_name, + subject_from=None, + scale=None, + subjects_dir=None, + *, + on_defects="raise", + verbose=None, +): """Scale a bem file. Parameters @@ -860,29 +959,36 @@ def scale_bem(subject_to, bem_name, subject_from=None, scale=None, .. versionadded:: 1.0 %(verbose)s """ - subjects_dir, subject_from, scale, uniform = \ - _scale_params(subject_to, subject_from, scale, subjects_dir) + subjects_dir, subject_from, scale, uniform = _scale_params( + subject_to, subject_from, scale, subjects_dir + ) - src = bem_fname.format(subjects_dir=subjects_dir, subject=subject_from, - name=bem_name) - dst = bem_fname.format(subjects_dir=subjects_dir, subject=subject_to, - name=bem_name) + src = bem_fname.format( + subjects_dir=subjects_dir, subject=subject_from, name=bem_name + ) + dst = bem_fname.format(subjects_dir=subjects_dir, subject=subject_to, name=bem_name) if os.path.exists(dst): raise OSError("File already exists: %s" % dst) surfs = read_bem_surfaces(src, on_defects=on_defects) for surf in surfs: - surf['rr'] *= scale + surf["rr"] *= scale if not uniform: - assert len(surf['nn']) > 0 - surf['nn'] /= scale - _normalize_vectors(surf['nn']) + assert len(surf["nn"]) > 0 + surf["nn"] /= scale + _normalize_vectors(surf["nn"]) write_bem_surfaces(dst, surfs) -def scale_labels(subject_to, pattern=None, overwrite=False, subject_from=None, - scale=None, subjects_dir=None): +def scale_labels( + subject_to, + pattern=None, + overwrite=False, + subject_from=None, + scale=None, + subjects_dir=None, +): r"""Scale labels to match a brain that was previously created by scaling. Parameters @@ -907,7 +1013,8 @@ def scale_labels(subject_to, pattern=None, overwrite=False, subject_from=None, Override the ``SUBJECTS_DIR`` environment variable. 
""" subjects_dir, subject_from, scale, _ = _scale_params( - subject_to, subject_from, scale, subjects_dir) + subject_to, subject_from, scale, subjects_dir + ) # find labels paths = _find_label_paths(subject_from, pattern, subjects_dir) @@ -930,15 +1037,31 @@ def scale_labels(subject_to, pattern=None, overwrite=False, subject_from=None, src = src_root / fname l_old = read_label(src) pos = l_old.pos * scale - l_new = Label(l_old.vertices, pos, l_old.values, l_old.hemi, - l_old.comment, subject=subject_to) + l_new = Label( + l_old.vertices, + pos, + l_old.values, + l_old.hemi, + l_old.comment, + subject=subject_to, + ) l_new.save(dst) @verbose -def scale_mri(subject_from, subject_to, scale, overwrite=False, - subjects_dir=None, skip_fiducials=False, labels=True, - annot=False, *, on_defects='raise', verbose=None): +def scale_mri( + subject_from, + subject_to, + scale, + overwrite=False, + subjects_dir=None, + skip_fiducials=False, + labels=True, + annot=False, + *, + on_defects="raise", + verbose=None, +): """Create a scaled copy of an MRI subject. Parameters @@ -984,99 +1107,119 @@ def scale_mri(subject_from, subject_to, scale, overwrite=False, if np.isclose(scale[1], scale[0]) and np.isclose(scale[2], scale[0]): scale = scale[0] # speed up scaling conditionals using a singleton elif scale.shape != (1,): - raise ValueError('scale must have shape (3,) or (1,), got %s' - % (scale.shape,)) + raise ValueError("scale must have shape (3,) or (1,), got %s" % (scale.shape,)) # make sure we have an empty target directory - dest = subject_dirname.format(subject=subject_to, - subjects_dir=subjects_dir) + dest = subject_dirname.format(subject=subject_to, subjects_dir=subjects_dir) if os.path.exists(dest): if not overwrite: - raise OSError("Subject directory for %s already exists: %r" - % (subject_to, dest)) + raise OSError( + "Subject directory for %s already exists: %r" % (subject_to, dest) + ) shutil.rmtree(dest) - logger.debug('create empty directory structure') - for dirname in paths['dirs']: + logger.debug("create empty directory structure") + for dirname in paths["dirs"]: dir_ = dirname.format(subject=subject_to, subjects_dir=subjects_dir) os.makedirs(dir_) - logger.debug('save MRI scaling parameters') - fname = os.path.join(dest, 'MRI scaling parameters.cfg') + logger.debug("save MRI scaling parameters") + fname = os.path.join(dest, "MRI scaling parameters.cfg") _write_mri_config(fname, subject_from, subject_to, scale) - logger.debug('surf files [in mm]') - for fname in paths['surf']: + logger.debug("surf files [in mm]") + for fname in paths["surf"]: src = fname.format(subject=subject_from, subjects_dir=subjects_dir) src = os.path.realpath(src) dest = fname.format(subject=subject_to, subjects_dir=subjects_dir) pts, tri = read_surface(src) write_surface(dest, pts * scale, tri) - logger.debug('BEM files [in m]') - for bem_name in paths['bem']: - scale_bem(subject_to, bem_name, subject_from, scale, subjects_dir, - on_defects=on_defects, verbose=False) + logger.debug("BEM files [in m]") + for bem_name in paths["bem"]: + scale_bem( + subject_to, + bem_name, + subject_from, + scale, + subjects_dir, + on_defects=on_defects, + verbose=False, + ) - logger.debug('fiducials [in m]') - for fname in paths['fid']: + logger.debug("fiducials [in m]") + for fname in paths["fid"]: src = fname.format(subject=subject_from, subjects_dir=subjects_dir) src = os.path.realpath(src) pts, cframe = read_fiducials(src, verbose=False) for pt in pts: - pt['r'] = pt['r'] * scale + pt["r"] = pt["r"] * scale dest = 
fname.format(subject=subject_to, subjects_dir=subjects_dir) write_fiducials(dest, pts, cframe, overwrite=True, verbose=False) - logger.debug('MRIs [nibabel]') - os.mkdir(mri_dirname.format(subjects_dir=subjects_dir, - subject=subject_to)) - for fname in paths['mri']: + logger.debug("MRIs [nibabel]") + os.mkdir(mri_dirname.format(subjects_dir=subjects_dir, subject=subject_to)) + for fname in paths["mri"]: mri_name = os.path.basename(fname) _scale_mri(subject_to, mri_name, subject_from, scale, subjects_dir) - logger.debug('Transforms') - for mri_name in paths['mri']: - if mri_name.endswith('T1.mgz'): - os.mkdir(mri_transforms_dirname.format(subjects_dir=subjects_dir, - subject=subject_to)) - for fname in paths['transforms']: + logger.debug("Transforms") + for mri_name in paths["mri"]: + if mri_name.endswith("T1.mgz"): + os.mkdir( + mri_transforms_dirname.format( + subjects_dir=subjects_dir, subject=subject_to + ) + ) + for fname in paths["transforms"]: xfm_name = os.path.basename(fname) - _scale_xfm(subject_to, xfm_name, mri_name, - subject_from, scale, subjects_dir) + _scale_xfm( + subject_to, xfm_name, mri_name, subject_from, scale, subjects_dir + ) break - logger.debug('duplicate files') - for fname in paths['duplicate']: + logger.debug("duplicate files") + for fname in paths["duplicate"]: src = fname.format(subject=subject_from, subjects_dir=subjects_dir) dest = fname.format(subject=subject_to, subjects_dir=subjects_dir) shutil.copyfile(src, dest) - logger.debug('source spaces') - for fname in paths['src']: + logger.debug("source spaces") + for fname in paths["src"]: src_name = os.path.basename(fname) - scale_source_space(subject_to, src_name, subject_from, scale, - subjects_dir, verbose=False) + scale_source_space( + subject_to, src_name, subject_from, scale, subjects_dir, verbose=False + ) - logger.debug('labels [in m]') - os.mkdir(os.path.join(subjects_dir, subject_to, 'label')) + logger.debug("labels [in m]") + os.mkdir(os.path.join(subjects_dir, subject_to, "label")) if labels: - scale_labels(subject_to, subject_from=subject_from, scale=scale, - subjects_dir=subjects_dir) + scale_labels( + subject_to, + subject_from=subject_from, + scale=scale, + subjects_dir=subjects_dir, + ) - logger.debug('copy *.annot files') + logger.debug("copy *.annot files") # they don't contain scale-dependent information if annot: - src_pattern = os.path.join(subjects_dir, subject_from, 'label', - '*.annot') - dst_dir = os.path.join(subjects_dir, subject_to, 'label') + src_pattern = os.path.join(subjects_dir, subject_from, "label", "*.annot") + dst_dir = os.path.join(subjects_dir, subject_to, "label") for src_file in iglob(src_pattern): shutil.copy(src_file, dst_dir) @verbose -def scale_source_space(subject_to, src_name, subject_from=None, scale=None, - subjects_dir=None, n_jobs=None, verbose=None): +def scale_source_space( + subject_to, + src_name, + subject_from=None, + scale=None, + subjects_dir=None, + n_jobs=None, + verbose=None, +): """Scale a source space for an mri created with scale_mri(). Parameters @@ -1110,8 +1253,9 @@ def scale_source_space(subject_to, src_name, subject_from=None, scale=None, are updated so that source estimates can be plotted on the original MRI volume. 
""" - subjects_dir, subject_from, scale, uniform = \ - _scale_params(subject_to, subject_from, scale, subjects_dir) + subjects_dir, subject_from, scale, uniform = _scale_params( + subject_to, subject_from, scale, subjects_dir + ) # if n_params==1 scale is a scalar; if n_params==3 scale is a (3,) array # find the source space file names @@ -1121,45 +1265,46 @@ def scale_source_space(subject_to, src_name, subject_from=None, scale=None, else: match = re.match(r"(oct|ico|vol)-?(\d+)$", src_name) if match: - spacing = '-'.join(match.groups()) + spacing = "-".join(match.groups()) src_pattern = src_fname else: spacing = None src_pattern = os.path.join(bem_dirname, src_name) - src = src_pattern.format(subjects_dir=subjects_dir, subject=subject_from, - spacing=spacing) - dst = src_pattern.format(subjects_dir=subjects_dir, subject=subject_to, - spacing=spacing) + src = src_pattern.format( + subjects_dir=subjects_dir, subject=subject_from, spacing=spacing + ) + dst = src_pattern.format( + subjects_dir=subjects_dir, subject=subject_to, spacing=spacing + ) # read and scale the source space [in m] sss = read_source_spaces(src) - logger.info("scaling source space %s: %s -> %s", spacing, subject_from, - subject_to) + logger.info("scaling source space %s: %s -> %s", spacing, subject_from, subject_to) logger.info("Scale factor: %s", scale) add_dist = False for ss in sss: - ss['subject_his_id'] = subject_to - ss['rr'] *= scale + ss["subject_his_id"] = subject_to + ss["rr"] *= scale # additional tags for volume source spaces - for key in ('vox_mri_t', 'src_mri_t'): + for key in ("vox_mri_t", "src_mri_t"): # maintain transform to original MRI volume ss['mri_volume_name'] if key in ss: - ss[key]['trans'][:3] *= scale[:, np.newaxis] + ss[key]["trans"][:3] *= scale[:, np.newaxis] # distances and patch info if uniform: - if ss['dist'] is not None: - ss['dist'] *= scale[0] + if ss["dist"] is not None: + ss["dist"] *= scale[0] # Sometimes this is read-only due to how it's read - ss['nearest_dist'] = ss['nearest_dist'] * scale - ss['dist_limit'] = ss['dist_limit'] * scale + ss["nearest_dist"] = ss["nearest_dist"] * scale + ss["dist_limit"] = ss["dist_limit"] * scale else: # non-uniform scaling - ss['nn'] /= scale - _normalize_vectors(ss['nn']) - if ss['dist'] is not None: + ss["nn"] /= scale + _normalize_vectors(ss["nn"]) + if ss["dist"] is not None: add_dist = True - dist_limit = float(np.abs(sss[0]['dist_limit'])) - elif ss['nearest'] is not None: + dist_limit = float(np.abs(sss[0]["dist_limit"])) + elif ss["nearest"] is not None: add_dist = True dist_limit = 0 @@ -1173,12 +1318,15 @@ def scale_source_space(subject_to, src_name, subject_from=None, scale=None, def _scale_mri(subject_to, mri_fname, subject_from, scale, subjects_dir): """Scale an MRI by setting its affine.""" subjects_dir, subject_from, scale, _ = _scale_params( - subject_to, subject_from, scale, subjects_dir) - nibabel = _import_nibabel('scale an MRI') - fname_from = op.join(mri_dirname.format( - subjects_dir=subjects_dir, subject=subject_from), mri_fname) - fname_to = op.join(mri_dirname.format( - subjects_dir=subjects_dir, subject=subject_to), mri_fname) + subject_to, subject_from, scale, subjects_dir + ) + nibabel = _import_nibabel("scale an MRI") + fname_from = op.join( + mri_dirname.format(subjects_dir=subjects_dir, subject=subject_from), mri_fname + ) + fname_to = op.join( + mri_dirname.format(subjects_dir=subjects_dir, subject=subject_to), mri_fname + ) img = nibabel.load(fname_from) zooms = np.array(img.header.get_zooms()) zooms[[0, 2, 1]] *= 
scale @@ -1189,21 +1337,23 @@ def _scale_mri(subject_to, mri_fname, subject_from, scale, subjects_dir): nibabel.save(img, fname_to) -def _scale_xfm(subject_to, xfm_fname, mri_name, subject_from, scale, - subjects_dir): +def _scale_xfm(subject_to, xfm_fname, mri_name, subject_from, scale, subjects_dir): """Scale a transform.""" subjects_dir, subject_from, scale, _ = _scale_params( - subject_to, subject_from, scale, subjects_dir) + subject_to, subject_from, scale, subjects_dir + ) # The nibabel warning should already be there in MRI step, if applicable, # as we only get here if T1.mgz is present (and thus a scaling was # attempted) so we can silently return here. fname_from = os.path.join( - mri_transforms_dirname.format( - subjects_dir=subjects_dir, subject=subject_from), xfm_fname) + mri_transforms_dirname.format(subjects_dir=subjects_dir, subject=subject_from), + xfm_fname, + ) fname_to = op.join( - mri_transforms_dirname.format( - subjects_dir=subjects_dir, subject=subject_to), xfm_fname) + mri_transforms_dirname.format(subjects_dir=subjects_dir, subject=subject_to), + xfm_fname, + ) assert op.isfile(fname_from), fname_from assert op.isdir(op.dirname(fname_to)), op.dirname(fname_to) # The "talairach.xfm" file stores the ras_mni transform. @@ -1228,23 +1378,25 @@ def _scale_xfm(subject_to, xfm_fname, mri_name, subject_from, scale, # prepare the scale (S) transform scale = np.atleast_1d(scale) scale = np.tile(scale, 3) if len(scale) == 1 else scale - S = Transform('mri', 'mri', scaling(*scale)) # F_mri->T_mri + S = Transform("mri", "mri", scaling(*scale)) # F_mri->T_mri # # Get the necessary transforms of the "from" subject # xfm, kind = _read_fs_xfm(fname_from) - assert kind == 'MNI Transform File', kind - _, _, F_mri_ras, _, _ = _read_mri_info(mri_name, units='mm') - F_ras_mni = Transform('ras', 'mni_tal', xfm) + assert kind == "MNI Transform File", kind + _, _, F_mri_ras, _, _ = _read_mri_info(mri_name, units="mm") + F_ras_mni = Transform("ras", "mni_tal", xfm) del xfm # # Get the necessary transforms of the "to" subject # - mri_name = op.join(mri_dirname.format( - subjects_dir=subjects_dir, subject=subject_to), op.basename(mri_name)) - _, _, T_mri_ras, _, _ = _read_mri_info(mri_name, units='mm') + mri_name = op.join( + mri_dirname.format(subjects_dir=subjects_dir, subject=subject_to), + op.basename(mri_name), + ) + _, _, T_mri_ras, _, _ = _read_mri_info(mri_name, units="mm") T_ras_mri = invert_transform(T_mri_ras) del mri_name, T_mri_ras @@ -1253,32 +1405,35 @@ def _scale_xfm(subject_to, xfm_fname, mri_name, subject_from, scale, # T_ras_mni = F_ras_mni @ F_mri_ras @ S⁻¹ @ T_ras_mri # # By moving right to left through the equation. 
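For readers tracing the hunk that follows: it only re-wraps the nested combine_transforms() call, but the algebra is the comment's T_ras_mni = F_ras_mni @ F_mri_ras @ S⁻¹ @ T_ras_mri, applied right to left. Here is a minimal standalone sketch of that composition using the same mne.transforms primitives; the identity matrices and the scale factors are stand-ins for the real subject transforms, which _scale_xfm reads via _read_fs_xfm and _read_mri_info:

    import numpy as np
    from mne.transforms import (Transform, combine_transforms,
                                invert_transform, scaling)

    # Stand-in inputs; in _scale_xfm these come from the "from" subject's
    # talairach.xfm and the "from"/"to" subjects' T1 headers.
    S = Transform("mri", "mri", scaling(1.1, 1.0, 0.9))  # F_mri -> T_mri
    T_ras_mri = Transform("ras", "mri", np.eye(4))
    F_mri_ras = Transform("mri", "ras", np.eye(4))
    F_ras_mni = Transform("ras", "mni_tal", np.eye(4))

    # Right to left: first undo the scaling, then map through the "from"
    # subject's MRI->RAS and RAS->MNI transforms.
    T_ras_mni = combine_transforms(
        combine_transforms(
            combine_transforms(T_ras_mri, invert_transform(S), "ras", "mri"),
            F_mri_ras, "ras", "ras"),
        F_ras_mni, "ras", "mni_tal")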
- T_ras_mni = \ + T_ras_mni = combine_transforms( combine_transforms( - combine_transforms( - combine_transforms( - T_ras_mri, invert_transform(S), 'ras', 'mri'), - F_mri_ras, 'ras', 'ras'), - F_ras_mni, 'ras', 'mni_tal') - _write_fs_xfm(fname_to, T_ras_mni['trans'], kind) + combine_transforms(T_ras_mri, invert_transform(S), "ras", "mri"), + F_mri_ras, + "ras", + "ras", + ), + F_ras_mni, + "ras", + "mni_tal", + ) + _write_fs_xfm(fname_to, T_ras_mni["trans"], kind) def _read_surface(filename, *, on_defects): bem = dict() if filename is not None and op.exists(filename): - if filename.endswith('.fif'): - bem = read_bem_surfaces( - filename, on_defects=on_defects, verbose=False - )[0] + if filename.endswith(".fif"): + bem = read_bem_surfaces(filename, on_defects=on_defects, verbose=False)[0] else: try: bem = read_surface(filename, return_dict=True)[2] - bem['rr'] *= 1e-3 + bem["rr"] *= 1e-3 complete_surface_info(bem, copy=False) except Exception: raise ValueError( "Error loading surface from %s (see " - "Terminal for details)." % filename) + "Terminal for details)." % filename + ) return bem @@ -1320,20 +1475,20 @@ class Coregistration: to create a surrogate MRI subject with the proper scale factors. """ - def __init__(self, info, subject, subjects_dir=None, fiducials='auto', *, - on_defects='raise'): - _validate_type(info, (Info, None), 'info') + def __init__( + self, info, subject, subjects_dir=None, fiducials="auto", *, on_defects="raise" + ): + _validate_type(info, (Info, None), "info") self._info = info self._subject = _check_subject(subject, subject) - self._subjects_dir = str( - get_subjects_dir(subjects_dir, raise_error=True) - ) + self._subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True)) self._scale_mode = None self._on_defects = on_defects self._rot_trans = None - self._default_parameters = \ - np.array([0., 0., 0., 0., 0., 0., 1., 1., 1.]) + self._default_parameters = np.array( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0] + ) self._rotation = self._default_parameters[:3] self._translation = self._default_parameters[3:6] @@ -1342,14 +1497,14 @@ def __init__(self, info, subject, subjects_dir=None, fiducials='auto', *, self._icp_angle = 0.2 self._icp_distance = 0.2 self._icp_scale = 0.2 - self._icp_fid_matches = ('nearest', 'matched') + self._icp_fid_matches = ("nearest", "matched") self._icp_fid_match = self._icp_fid_matches[0] - self._lpa_weight = 1. - self._nasion_weight = 10. - self._rpa_weight = 1. - self._hsp_weight = 1. - self._eeg_weight = 1. - self._hpi_weight = 1. 
+ self._lpa_weight = 1.0 + self._nasion_weight = 10.0 + self._rpa_weight = 1.0 + self._hsp_weight = 1.0 + self._eeg_weight = 1.0 + self._hpi_weight = 1.0 self._extra_points_filter = None self._setup_digs() @@ -1371,77 +1526,86 @@ def _setup_digs(self): ) else: self._dig_dict = _get_data_as_dict_from_dig( - dig=self._info['dig'], - exclude_ref_channel=False + dig=self._info["dig"], exclude_ref_channel=False ) # adjustments: # set weights to 0 for None input # convert fids to float arrays - for k, w_atr in zip(['nasion', 'lpa', 'rpa', 'hsp', 'hpi'], - ['_nasion_weight', '_lpa_weight', - '_rpa_weight', '_hsp_weight', '_hpi_weight']): + for k, w_atr in zip( + ["nasion", "lpa", "rpa", "hsp", "hpi"], + [ + "_nasion_weight", + "_lpa_weight", + "_rpa_weight", + "_hsp_weight", + "_hpi_weight", + ], + ): if self._dig_dict[k] is None: self._dig_dict[k] = np.zeros((0, 3)) setattr(self, w_atr, 0) - elif k in ['rpa', 'nasion', 'lpa']: + elif k in ["rpa", "nasion", "lpa"]: self._dig_dict[k] = np.array([self._dig_dict[k]], float) def _setup_bem(self): # find high-res head model (if possible) - high_res_path = _find_head_bem(self._subject, self._subjects_dir, - high_res=True) - low_res_path = _find_head_bem(self._subject, self._subjects_dir, - high_res=False) + high_res_path = _find_head_bem(self._subject, self._subjects_dir, high_res=True) + low_res_path = _find_head_bem(self._subject, self._subjects_dir, high_res=False) if high_res_path is None and low_res_path is None: - raise RuntimeError("No standard head model was " - f"found for subject {self._subject}") + raise RuntimeError( + "No standard head model was " f"found for subject {self._subject}" + ) if high_res_path is not None: self._bem_high_res = _read_surface( high_res_path, on_defects=self._on_defects ) - logger.info(f'Using high resolution head model in {high_res_path}') + logger.info(f"Using high resolution head model in {high_res_path}") else: self._bem_high_res = _read_surface( low_res_path, on_defects=self._on_defects ) - logger.info(f'Using low resolution head model in {low_res_path}') + logger.info(f"Using low resolution head model in {low_res_path}") if low_res_path is None: # This should be very rare! 
- warn('No low-resolution head found, decimating high resolution ' - 'mesh (%d vertices): %s' % (len(self._bem_high_res['rr']), - high_res_path,)) + warn( + "No low-resolution head found, decimating high resolution " + "mesh (%d vertices): %s" + % ( + len(self._bem_high_res["rr"]), + high_res_path, + ) + ) # Create one from the high res one, which we know we have - rr, tris = decimate_surface(self._bem_high_res['rr'], - self._bem_high_res['tris'], - n_triangles=5120) + rr, tris = decimate_surface( + self._bem_high_res["rr"], self._bem_high_res["tris"], n_triangles=5120 + ) # directly set the attributes of bem_low_res self._bem_low_res = complete_surface_info( - dict(rr=rr, tris=tris), copy=False, verbose=False) - else: - self._bem_low_res = _read_surface( - low_res_path, on_defects=self._on_defects + dict(rr=rr, tris=tris), copy=False, verbose=False ) + else: + self._bem_low_res = _read_surface(low_res_path, on_defects=self._on_defects) def _setup_fiducials(self, fids): _validate_type(fids, (str, dict, list)) # find fiducials file fid_accurate = None - if fids == 'auto': - fid_files = _find_fiducials_files(self._subject, - self._subjects_dir) + if fids == "auto": + fid_files = _find_fiducials_files(self._subject, self._subjects_dir) if len(fid_files) > 0: # Read fiducials from disk fid_filename = fid_files[0].format( - subjects_dir=self._subjects_dir, subject=self._subject) - logger.info(f'Using fiducials from: {fid_filename}.') + subjects_dir=self._subjects_dir, subject=self._subject + ) + logger.info(f"Using fiducials from: {fid_filename}.") fids, _ = read_fiducials(fid_filename) fid_accurate = True self._fid_filename = fid_filename else: - fids = 'estimated' + fids = "estimated" - if fids == 'estimated': - logger.info('Estimating fiducials from fsaverage.') + if fids == "estimated": + logger.info("Estimating fiducials from fsaverage.") fid_accurate = False fids = get_mni_fiducials(self._subject, self._subjects_dir) @@ -1450,8 +1614,9 @@ def _setup_fiducials(self, fids): fid_coords = _fiducial_coords(fids) else: assert isinstance(fids, dict) - fid_coords = np.array([fids['lpa'], fids['nasion'], fids['rpa']], - dtype=float) + fid_coords = np.array( + [fids["lpa"], fids["nasion"], fids["rpa"]], dtype=float + ) self._fid_points = fid_coords self._fid_accurate = fid_accurate @@ -1464,12 +1629,11 @@ def _reset_fiducials(self): lpa=self._fid_points[0], nasion=self._fid_points[1], rpa=self._fid_points[2], - coord_frame='mri' + coord_frame="mri", ) self.fiducials = dig_montage - def _update_params(self, rot=None, tra=None, sca=None, - force_update=False): + def _update_params(self, rot=None, tra=None, sca=None, force_update=False): if force_update and tra is None: tra = self._translation rot_changed = False @@ -1485,18 +1649,19 @@ def _update_params(self, rot=None, tra=None, sca=None, self._last_translation = self._translation.copy() self._translation = tra self._head_mri_t = rotation(*self._rotation).T - self._head_mri_t[:3, 3] = \ - -np.dot(self._head_mri_t[:3, :3], tra) - self._transformed_dig_hpi = \ - apply_trans(self._head_mri_t, self._dig_dict['hpi']) - self._transformed_dig_eeg = \ - apply_trans( - self._head_mri_t, self._dig_dict['dig_ch_pos_location']) - self._transformed_dig_extra = \ - apply_trans(self._head_mri_t, - self._filtered_extra_points) - self._transformed_orig_dig_extra = \ - apply_trans(self._head_mri_t, self._dig_dict['hsp']) + self._head_mri_t[:3, 3] = -np.dot(self._head_mri_t[:3, :3], tra) + self._transformed_dig_hpi = apply_trans( + self._head_mri_t, 
self._dig_dict["hpi"] + ) + self._transformed_dig_eeg = apply_trans( + self._head_mri_t, self._dig_dict["dig_ch_pos_location"] + ) + self._transformed_dig_extra = apply_trans( + self._head_mri_t, self._filtered_extra_points + ) + self._transformed_orig_dig_extra = apply_trans( + self._head_mri_t, self._dig_dict["hsp"] + ) self._mri_head_t = rotation(*self._rotation) self._mri_head_t[:3, 3] = np.array(tra) if tra_changed or sca is not None: @@ -1506,27 +1671,32 @@ def _update_params(self, rot=None, tra=None, sca=None, self._scale = sca self._mri_trans = np.eye(4) self._mri_trans[:, :3] *= sca - self._transformed_high_res_mri_points = \ - apply_trans(self._mri_trans, - self._processed_high_res_mri_points) + self._transformed_high_res_mri_points = apply_trans( + self._mri_trans, self._processed_high_res_mri_points + ) self._update_nearest_calc() if tra_changed: - self._nearest_transformed_high_res_mri_idx_orig_hsp = \ + self._nearest_transformed_high_res_mri_idx_orig_hsp = ( self._nearest_calc.query(self._transformed_orig_dig_extra)[1] - self._nearest_transformed_high_res_mri_idx_hpi = \ - self._nearest_calc.query(self._transformed_dig_hpi)[1] - self._nearest_transformed_high_res_mri_idx_eeg = \ - self._nearest_calc.query(self._transformed_dig_eeg)[1] - self._nearest_transformed_high_res_mri_idx_rpa = \ - self._nearest_calc.query( - apply_trans(self._head_mri_t, self._dig_dict['rpa']))[1] - self._nearest_transformed_high_res_mri_idx_nasion = \ - self._nearest_calc.query( - apply_trans(self._head_mri_t, self._dig_dict['nasion']))[1] - self._nearest_transformed_high_res_mri_idx_lpa = \ + ) + self._nearest_transformed_high_res_mri_idx_hpi = self._nearest_calc.query( + self._transformed_dig_hpi + )[1] + self._nearest_transformed_high_res_mri_idx_eeg = self._nearest_calc.query( + self._transformed_dig_eeg + )[1] + self._nearest_transformed_high_res_mri_idx_rpa = self._nearest_calc.query( + apply_trans(self._head_mri_t, self._dig_dict["rpa"]) + )[1] + self._nearest_transformed_high_res_mri_idx_nasion = ( self._nearest_calc.query( - apply_trans(self._head_mri_t, self._dig_dict['lpa']))[1] + apply_trans(self._head_mri_t, self._dig_dict["nasion"]) + )[1] + ) + self._nearest_transformed_high_res_mri_idx_lpa = self._nearest_calc.query( + apply_trans(self._head_mri_t, self._dig_dict["lpa"]) + )[1] def set_scale_mode(self, scale_mode): """Select how to fit the scale parameters. 
@@ -1616,14 +1786,15 @@ def set_scale(self, sca): def _update_nearest_calc(self): self._nearest_calc = _DistanceQuery( - self._processed_high_res_mri_points * self._scale) + self._processed_high_res_mri_points * self._scale + ) @property def _filtered_extra_points(self): if self._extra_points_filter is None: - return self._dig_dict['hsp'] + return self._dig_dict["hsp"] else: - return self._dig_dict['hsp'][self._extra_points_filter] + return self._dig_dict["hsp"][self._extra_points_filter] @property def _parameters(self): @@ -1631,79 +1802,89 @@ def _parameters(self): @property def _last_parameters(self): - return np.concatenate((self._last_rotation, - self._last_translation, self._last_scale)) + return np.concatenate( + (self._last_rotation, self._last_translation, self._last_scale) + ) @property def _changes(self): move = np.linalg.norm(self._last_translation - self._translation) * 1e3 - angle = np.rad2deg(_angle_between_quats( - rot_to_quat(rotation(*self._rotation)[:3, :3]), - rot_to_quat(rotation(*self._last_rotation)[:3, :3]))) + angle = np.rad2deg( + _angle_between_quats( + rot_to_quat(rotation(*self._rotation)[:3, :3]), + rot_to_quat(rotation(*self._last_rotation)[:3, :3]), + ) + ) percs = 100 * (self._scale - self._last_scale) / self._last_scale return move, angle, percs @property def _nearest_transformed_high_res_mri_idx_hsp(self): return self._nearest_calc.query( - apply_trans(self._head_mri_t, self._filtered_extra_points))[1] + apply_trans(self._head_mri_t, self._filtered_extra_points) + )[1] @property def _has_hsp_data(self): - return (self._has_mri_data and - len(self._nearest_transformed_high_res_mri_idx_hsp) > 0) + return ( + self._has_mri_data + and len(self._nearest_transformed_high_res_mri_idx_hsp) > 0 + ) @property def _has_hpi_data(self): - return (self._has_mri_data and - len(self._nearest_transformed_high_res_mri_idx_hpi) > 0) + return ( + self._has_mri_data + and len(self._nearest_transformed_high_res_mri_idx_hpi) > 0 + ) @property def _has_eeg_data(self): - return (self._has_mri_data and - len(self._nearest_transformed_high_res_mri_idx_eeg) > 0) + return ( + self._has_mri_data + and len(self._nearest_transformed_high_res_mri_idx_eeg) > 0 + ) @property def _has_lpa_data(self): - mri_point = self.fiducials.dig[_map_fid_name_to_idx('lpa')] - assert mri_point['ident'] == FIFF.FIFFV_POINT_LPA - has_mri_data = np.any(mri_point['r']) - has_head_data = np.any(self._dig_dict['lpa']) + mri_point = self.fiducials.dig[_map_fid_name_to_idx("lpa")] + assert mri_point["ident"] == FIFF.FIFFV_POINT_LPA + has_mri_data = np.any(mri_point["r"]) + has_head_data = np.any(self._dig_dict["lpa"]) return has_mri_data and has_head_data @property def _has_nasion_data(self): - mri_point = self.fiducials.dig[_map_fid_name_to_idx('nasion')] - assert mri_point['ident'] == FIFF.FIFFV_POINT_NASION - has_mri_data = np.any(mri_point['r']) - has_head_data = np.any(self._dig_dict['nasion']) + mri_point = self.fiducials.dig[_map_fid_name_to_idx("nasion")] + assert mri_point["ident"] == FIFF.FIFFV_POINT_NASION + has_mri_data = np.any(mri_point["r"]) + has_head_data = np.any(self._dig_dict["nasion"]) return has_mri_data and has_head_data @property def _has_rpa_data(self): - mri_point = self.fiducials.dig[_map_fid_name_to_idx('rpa')] - assert mri_point['ident'] == FIFF.FIFFV_POINT_RPA - has_mri_data = np.any(mri_point['r']) - has_head_data = np.any(self._dig_dict['rpa']) + mri_point = self.fiducials.dig[_map_fid_name_to_idx("rpa")] + assert mri_point["ident"] == FIFF.FIFFV_POINT_RPA + has_mri_data = 
np.any(mri_point["r"]) + has_head_data = np.any(self._dig_dict["rpa"]) return has_mri_data and has_head_data @property def _processed_high_res_mri_points(self): - return self._get_processed_mri_points('high') + return self._get_processed_mri_points("high") @property def _processed_low_res_mri_points(self): - return self._get_processed_mri_points('low') + return self._get_processed_mri_points("low") def _get_processed_mri_points(self, res): - bem = self._bem_low_res if res == 'low' else self._bem_high_res - points = bem['rr'].copy() + bem = self._bem_low_res if res == "low" else self._bem_high_res + points = bem["rr"].copy() if self._grow_hair: - assert len(bem['nn']) # should be guaranteed by _read_surface - scaled_hair_dist = (1e-3 * self._grow_hair / - np.array(self._scale)) + assert len(bem["nn"]) # should be guaranteed by _read_surface + scaled_hair_dist = 1e-3 * self._grow_hair / np.array(self._scale) hair = points[:, 2] > points[:, 1] - points[hair] += bem['nn'][hair] * scaled_hair_dist + points[hair] += bem["nn"][hair] * scaled_hair_dist return points @property @@ -1712,20 +1893,24 @@ def _has_mri_data(self): @property def _has_dig_data(self): - return (self._has_mri_data and - len(self._nearest_transformed_high_res_mri_idx_hsp) > 0) + return ( + self._has_mri_data + and len(self._nearest_transformed_high_res_mri_idx_hsp) > 0 + ) @property def _orig_hsp_point_distance(self): mri_points = self._transformed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_orig_hsp] + self._nearest_transformed_high_res_mri_idx_orig_hsp + ] hsp_points = self._transformed_orig_dig_extra return np.linalg.norm(mri_points - hsp_points, axis=-1) def _log_dig_mri_distance(self, prefix): errs_nearest = self.compute_dig_mri_distances() - logger.info(f'{prefix} median distance: ' - f'{np.median(errs_nearest * 1000):6.2f} mm') + logger.info( + f"{prefix} median distance: " f"{np.median(errs_nearest * 1000):6.2f} mm" + ) @property def scale(self): @@ -1739,8 +1924,9 @@ def scale(self): return self._scale.copy() @verbose - def fit_fiducials(self, lpa_weight=1., nasion_weight=10., rpa_weight=1., - verbose=None): + def fit_fiducials( + self, lpa_weight=1.0, nasion_weight=10.0, rpa_weight=1.0, verbose=None + ): """Find rotation and translation to fit all 3 fiducials. Parameters @@ -1758,34 +1944,41 @@ def fit_fiducials(self, lpa_weight=1., nasion_weight=10., rpa_weight=1., self : Coregistration The modified Coregistration object. 
""" - logger.info('Aligning using fiducials') - self._log_dig_mri_distance('Start') + logger.info("Aligning using fiducials") + self._log_dig_mri_distance("Start") n_scale_params = self._n_scale_params if n_scale_params == 3: # enforce 1 even for 3-axis here (3 points is not enough) - logger.info("Enforcing 1 scaling parameter for fit " - "with fiducials.") + logger.info("Enforcing 1 scaling parameter for fit " "with fiducials.") n_scale_params = 1 self._lpa_weight = lpa_weight self._nasion_weight = nasion_weight self._rpa_weight = rpa_weight - head_pts = np.vstack((self._dig_dict['lpa'], - self._dig_dict['nasion'], - self._dig_dict['rpa'])) + head_pts = np.vstack( + (self._dig_dict["lpa"], self._dig_dict["nasion"], self._dig_dict["rpa"]) + ) mri_pts = np.vstack( - (self.fiducials.dig[0]['r'], # LPA - self.fiducials.dig[1]['r'], # Nasion - self.fiducials.dig[2]['r']) # RPA + ( + self.fiducials.dig[0]["r"], # LPA + self.fiducials.dig[1]["r"], # Nasion + self.fiducials.dig[2]["r"], + ) # RPA ) weights = [lpa_weight, nasion_weight, rpa_weight] if n_scale_params == 0: mri_pts *= self._scale # not done in fit_matched_points x0 = self._parameters - x0 = x0[:6 + n_scale_params] - est = fit_matched_points(mri_pts, head_pts, x0=x0, out='params', - scale=n_scale_params, weights=weights) + x0 = x0[: 6 + n_scale_params] + est = fit_matched_points( + mri_pts, + head_pts, + x0=x0, + out="params", + scale=n_scale_params, + weights=weights, + ) if n_scale_params == 0: self._update_params(rot=est[:3], tra=est[3:6]) else: @@ -1793,7 +1986,7 @@ def fit_fiducials(self, lpa_weight=1., nasion_weight=10., rpa_weight=1., est = np.concatenate([est, [est[-1]] * 2]) assert est.size == 9 self._update_params(rot=est[:3], tra=est[3:6], sca=est[6:9]) - self._log_dig_mri_distance('End ') + self._log_dig_mri_distance("End ") return self def _setup_icp(self, n_scale_params): @@ -1802,34 +1995,47 @@ def _setup_icp(self, n_scale_params): weights = list() if self._has_dig_data and self._hsp_weight > 0: # should be true head_pts.append(self._filtered_extra_points) - mri_pts.append(self._processed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_hsp]) + mri_pts.append( + self._processed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_hsp + ] + ) weights.append(np.full(len(head_pts[-1]), self._hsp_weight)) - for key in ('lpa', 'nasion', 'rpa'): - if getattr(self, f'_has_{key}_data'): + for key in ("lpa", "nasion", "rpa"): + if getattr(self, f"_has_{key}_data"): head_pts.append(self._dig_dict[key]) - if self._icp_fid_match == 'matched': + if self._icp_fid_match == "matched": idx = _map_fid_name_to_idx(name=key) - p = self.fiducials.dig[idx]['r'].reshape(1, -1) + p = self.fiducials.dig[idx]["r"].reshape(1, -1) mri_pts.append(p) else: - assert self._icp_fid_match == 'nearest' - mri_pts.append(self._processed_high_res_mri_points[ - getattr( - self, - '_nearest_transformed_high_res_mri_idx_%s' - % (key,))]) - weights.append(np.full(len(mri_pts[-1]), - getattr(self, '_%s_weight' % key))) + assert self._icp_fid_match == "nearest" + mri_pts.append( + self._processed_high_res_mri_points[ + getattr( + self, + "_nearest_transformed_high_res_mri_idx_%s" % (key,), + ) + ] + ) + weights.append( + np.full(len(mri_pts[-1]), getattr(self, "_%s_weight" % key)) + ) if self._has_eeg_data and self._eeg_weight > 0: - head_pts.append(self._dig_dict['dig_ch_pos_location']) - mri_pts.append(self._processed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_eeg]) + 
head_pts.append(self._dig_dict["dig_ch_pos_location"]) + mri_pts.append( + self._processed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_eeg + ] + ) weights.append(np.full(len(mri_pts[-1]), self._eeg_weight)) if self._has_hpi_data and self._hpi_weight > 0: - head_pts.append(self._dig_dict['hpi']) - mri_pts.append(self._processed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_hpi]) + head_pts.append(self._dig_dict["hpi"]) + mri_pts.append( + self._processed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_hpi + ] + ) weights.append(np.full(len(mri_pts[-1]), self._hpi_weight)) head_pts = np.concatenate(head_pts) mri_pts = np.concatenate(mri_pts) @@ -1853,14 +2059,23 @@ def set_fid_match(self, match): self : Coregistration The modified Coregistration object. """ - _check_option('match', match, self._icp_fid_matches) + _check_option("match", match, self._icp_fid_matches) self._icp_fid_match = match return self @verbose - def fit_icp(self, n_iterations=20, lpa_weight=1., nasion_weight=10., - rpa_weight=1., hsp_weight=1., eeg_weight=1., hpi_weight=1., - callback=None, verbose=None): + def fit_icp( + self, + n_iterations=20, + lpa_weight=1.0, + nasion_weight=10.0, + rpa_weight=1.0, + hsp_weight=1.0, + eeg_weight=1.0, + hpi_weight=1.0, + callback=None, + verbose=None, + ): """Find MRI scaling, translation, and rotation to match HSP. Parameters @@ -1890,8 +2105,8 @@ def fit_icp(self, n_iterations=20, lpa_weight=1., nasion_weight=10., self : Coregistration The modified Coregistration object. """ - logger.info('Aligning using ICP') - self._log_dig_mri_distance('Start ') + logger.info("Aligning using ICP") + self._log_dig_mri_distance("Start ") n_scale_params = self._n_scale_params self._lpa_weight = lpa_weight self._nasion_weight = nasion_weight @@ -1902,13 +2117,19 @@ def fit_icp(self, n_iterations=20, lpa_weight=1., nasion_weight=10., # Initial guess (current state) est = self._parameters - est = est[:[6, 7, None, 9][n_scale_params]] + est = est[: [6, 7, None, 9][n_scale_params]] # Do the fits, assigning and evaluating at each step for iteration in range(n_iterations): head_pts, mri_pts, weights = self._setup_icp(n_scale_params) - est = fit_matched_points(mri_pts, head_pts, scale=n_scale_params, - x0=est, out='params', weights=weights) + est = fit_matched_points( + mri_pts, + head_pts, + scale=n_scale_params, + x0=est, + out="params", + weights=weights, + ) if n_scale_params == 0: self._update_params(rot=est[:3], tra=est[3:6]) elif n_scale_params == 1: @@ -1917,20 +2138,23 @@ def fit_icp(self, n_iterations=20, lpa_weight=1., nasion_weight=10., else: self._update_params(rot=est[:3], tra=est[3:6], sca=est[6:9]) angle, move, scale = self._changes - self._log_dig_mri_distance(f' ICP {iteration + 1:2d} ') + self._log_dig_mri_distance(f" ICP {iteration + 1:2d} ") if callback is not None: callback(iteration, n_iterations) - if angle <= self._icp_angle and move <= self._icp_distance and \ - all(scale <= self._icp_scale): + if ( + angle <= self._icp_angle + and move <= self._icp_distance + and all(scale <= self._icp_scale) + ): break - self._log_dig_mri_distance('End ') + self._log_dig_mri_distance("End ") return self @property def _n_scale_params(self): if self._scale_mode is None: n_scale_params = 0 - elif self._scale_mode == 'uniform': + elif self._scale_mode == "uniform": n_scale_params = 1 else: n_scale_params = 3 @@ -1957,8 +2181,12 @@ def omit_head_shape_points(self, distance): # find the new filter mask = self._orig_hsp_point_distance <= 
distance n_excluded = np.sum(~mask) - logger.info("Coregistration: Excluding %i head shape points with " - "distance >= %.3f m.", n_excluded, distance) + logger.info( + "Coregistration: Excluding %i head shape points with " + "distance >= %.3f m.", + n_excluded, + distance, + ) # set the filter self._extra_points_filter = mask self._update_params(force_update=True) @@ -1985,7 +2213,7 @@ def compute_dig_mri_distances(self): @property def trans(self): """The head->mri :class:`~mne.transforms.Transform`.""" - return Transform('head', 'mri', self._head_mri_t) + return Transform("head", "mri", self._head_mri_t) def reset(self): """Reset all the parameters affecting the coregistration. @@ -1995,7 +2223,7 @@ def reset(self): self : Coregistration The modified Coregistration object. """ - self._grow_hair = 0. + self._grow_hair = 0.0 self.set_rotation(self._default_parameters[:3]) self.set_translation(self._default_parameters[3:6]) self.set_scale(self._default_parameters[6:9]) @@ -2005,15 +2233,13 @@ def reset(self): def _get_fiducials_distance(self): distance = dict() - for key in ('lpa', 'nasion', 'rpa'): + for key in ("lpa", "nasion", "rpa"): idx = _map_fid_name_to_idx(name=key) - fid = self.fiducials.dig[idx]['r'].reshape(1, -1) + fid = self.fiducials.dig[idx]["r"].reshape(1, -1) transformed_mri = apply_trans(self._mri_trans, fid) - transformed_hsp = apply_trans( - self._head_mri_t, self._dig_dict[key]) - distance[key] = np.linalg.norm( - np.ravel(transformed_mri - transformed_hsp)) + transformed_hsp = apply_trans(self._head_mri_t, self._dig_dict[key]) + distance[key] = np.linalg.norm(np.ravel(transformed_mri - transformed_hsp)) return np.array(list(distance.values())) * 1e3 def _get_fiducials_distance_str(self): @@ -2024,18 +2250,27 @@ def _get_point_distance(self): mri_points = list() hsp_points = list() if self._hsp_weight > 0 and self._has_hsp_data: - mri_points.append(self._transformed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_hsp]) + mri_points.append( + self._transformed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_hsp + ] + ) hsp_points.append(self._transformed_dig_extra) assert len(mri_points[-1]) == len(hsp_points[-1]) if self._eeg_weight > 0 and self._has_eeg_data: - mri_points.append(self._transformed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_eeg]) + mri_points.append( + self._transformed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_eeg + ] + ) hsp_points.append(self._transformed_dig_eeg) assert len(mri_points[-1]) == len(hsp_points[-1]) if self._hpi_weight > 0 and self._has_hpi_data: - mri_points.append(self._transformed_high_res_mri_points[ - self._nearest_transformed_high_res_mri_idx_hpi]) + mri_points.append( + self._transformed_high_res_mri_points[ + self._nearest_transformed_high_res_mri_idx_hpi + ] + ) hsp_points.append(self._transformed_dig_hpi) assert len(mri_points[-1]) == len(hsp_points[-1]) if all(len(h) == 0 for h in hsp_points): @@ -2051,10 +2286,14 @@ def _get_point_distance_str(self): dists = 1e3 * point_distance av_dist = np.mean(dists) std_dist = np.std(dists) - kinds = [kind for kind, check in - (('HSP', self._hsp_weight > 0 and self._has_hsp_data), - ('EEG', self._eeg_weight > 0 and self._has_eeg_data), - ('HPI', self._hpi_weight > 0 and self._has_hpi_data)) - if check] - kinds = '+'.join(kinds) + kinds = [ + kind + for kind, check in ( + ("HSP", self._hsp_weight > 0 and self._has_hsp_data), + ("EEG", self._eeg_weight > 0 and self._has_eeg_data), + ("HPI", self._hpi_weight > 
0 and self._has_hpi_data), + ) + if check + ] + kinds = "+".join(kinds) return f"{len(dists)} {kinds}: {av_dist:.1f} ± {std_dist:.1f} mm" diff --git a/mne/cov.py b/mne/cov.py index 43c993c6c91..15fd043d022 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -10,55 +10,100 @@ import numpy as np -from .defaults import (_INTERPOLATION_DEFAULT, _EXTRAPOLATE_DEFAULT, - _BORDER_DEFAULT, DEFAULTS) +from .defaults import ( + _INTERPOLATION_DEFAULT, + _EXTRAPOLATE_DEFAULT, + _BORDER_DEFAULT, + DEFAULTS, +) from .io.write import start_and_end_file -from .io.proj import (make_projector, _proj_equal, activate_proj, - _check_projs, _needs_eeg_average_ref_proj, - _has_eeg_average_ref_proj, _read_proj, _write_proj) +from .io.proj import ( + make_projector, + _proj_equal, + activate_proj, + _check_projs, + _needs_eeg_average_ref_proj, + _has_eeg_average_ref_proj, + _read_proj, + _write_proj, +) from .io import fiff_open, RawArray -from .io.pick import (pick_types, pick_channels_cov, pick_channels, pick_info, - _picks_by_type, _pick_data_channels, _picks_to_idx, - _DATA_CH_TYPES_SPLIT) +from .io.pick import ( + pick_types, + pick_channels_cov, + pick_channels, + pick_info, + _picks_by_type, + _pick_data_channels, + _picks_to_idx, + _DATA_CH_TYPES_SPLIT, +) from .io.constants import FIFF from .io.meas_info import _read_bad_channels, create_info, _write_bad_channels from .io.tag import find_tag from .io.tree import dir_tree_find -from .io.write import (start_block, end_block, write_int, write_double, - write_float_matrix, write_string, _safe_name_list, - write_name_list_sanitized) +from .io.write import ( + start_block, + end_block, + write_int, + write_double, + write_float_matrix, + write_string, + _safe_name_list, + write_name_list_sanitized, +) from .defaults import _handle_default from .epochs import Epochs from .event import make_fixed_length_events from .evoked import EvokedArray from .rank import compute_rank -from .utils import (check_fname, logger, verbose, check_version, _time_mask, - warn, copy_function_doc_to_method_doc, _pl, - _undo_scaling_cov, _scaled_array, _validate_type, - _check_option, eigh, fill_doc, _on_missing, - _check_on_missing, _check_fname, _verbose_safe_false) +from .utils import ( + check_fname, + logger, + verbose, + check_version, + _time_mask, + warn, + copy_function_doc_to_method_doc, + _pl, + _undo_scaling_cov, + _scaled_array, + _validate_type, + _check_option, + eigh, + fill_doc, + _on_missing, + _check_on_missing, + _check_fname, + _verbose_safe_false, +) from . import viz -from .fixes import (BaseEstimator, EmpiricalCovariance, _logdet, - empirical_covariance, log_likelihood) +from .fixes import ( + BaseEstimator, + EmpiricalCovariance, + _logdet, + empirical_covariance, + log_likelihood, +) def _check_covs_algebra(cov1, cov2): if cov1.ch_names != cov2.ch_names: - raise ValueError('Both Covariance do not have the same list of ' - 'channels.') - projs1 = [str(c) for c in cov1['projs']] - projs2 = [str(c) for c in cov1['projs']] + raise ValueError("Both Covariance do not have the same list of " "channels.") + projs1 = [str(c) for c in cov1["projs"]] + projs2 = [str(c) for c in cov1["projs"]] if projs1 != projs2: - raise ValueError('Both Covariance do not have the same list of ' - 'SSP projections.') + raise ValueError( + "Both Covariance do not have the same list of " "SSP projections." 
+ ) def _get_tslice(epochs, tmin, tmax): """Get the slice.""" - mask = _time_mask(epochs.times, tmin, tmax, sfreq=epochs.info['sfreq']) + mask = _time_mask(epochs.times, tmin, tmax, sfreq=epochs.info["sfreq"]) tstart = np.where(mask)[0][0] if tmin is not None else None tend = np.where(mask)[0][-1] + 1 if tmax is not None else None tslice = slice(tstart, tend, None) @@ -116,33 +161,54 @@ class Covariance(dict): """ @verbose - def __init__(self, data, names, bads, projs, nfree, eig=None, eigvec=None, - method=None, loglik=None, *, verbose=None): + def __init__( + self, + data, + names, + bads, + projs, + nfree, + eig=None, + eigvec=None, + method=None, + loglik=None, + *, + verbose=None, + ): """Init of covariance.""" - diag = (data.ndim == 1) + diag = data.ndim == 1 projs = _check_projs(projs) - self.update(data=data, dim=len(data), names=names, bads=bads, - nfree=nfree, eig=eig, eigvec=eigvec, diag=diag, - projs=projs, kind=FIFF.FIFFV_MNE_NOISE_COV) + self.update( + data=data, + dim=len(data), + names=names, + bads=bads, + nfree=nfree, + eig=eig, + eigvec=eigvec, + diag=diag, + projs=projs, + kind=FIFF.FIFFV_MNE_NOISE_COV, + ) if method is not None: - self['method'] = method + self["method"] = method if loglik is not None: - self['loglik'] = loglik + self["loglik"] = loglik @property def data(self): """Numpy array of Noise covariance matrix.""" - return self['data'] + return self["data"] @property def ch_names(self): """Channel names.""" - return self['names'] + return self["names"] @property def nfree(self): """Number of degrees of freedom.""" - return self['nfree'] + return self["nfree"] @verbose def save(self, fname, *, overwrite=False, verbose=None): @@ -157,8 +223,9 @@ def save(self, fname, *, overwrite=False, verbose=None): .. versionadded:: 1.0 %(verbose)s """ - check_fname(fname, 'covariance', ('-cov.fif', '-cov.fif.gz', - '_cov.fif', '_cov.fif.gz')) + check_fname( + fname, "covariance", ("-cov.fif", "-cov.fif.gz", "_cov.fif", "_cov.fif.gz") + ) fname = _check_fname(fname=fname, overwrite=overwrite) with start_and_end_file(fname) as fid: _write_cov(fid, self) @@ -188,35 +255,35 @@ def as_diag(self): This function operates in place. 
""" - if self['diag']: + if self["diag"]: return self - self['diag'] = True - self['data'] = np.diag(self['data']) - self['eig'] = None - self['eigvec'] = None + self["diag"] = True + self["data"] = np.diag(self["data"]) + self["eig"] = None + self["eigvec"] = None return self def _as_square(self): # This is a hack but it works because np.diag() behaves nicely - if self['diag']: - self['diag'] = False + if self["diag"]: + self["diag"] = False self.as_diag() - self['diag'] = False + self["diag"] = False return self def _get_square(self): - if self['diag'] != (self.data.ndim == 1): + if self["diag"] != (self.data.ndim == 1): raise RuntimeError( - 'Covariance attributes inconsistent, got data with ' - 'dimensionality %d but diag=%s' - % (self.data.ndim, self['diag'])) - return np.diag(self.data) if self['diag'] else self.data.copy() + "Covariance attributes inconsistent, got data with " + "dimensionality %d but diag=%s" % (self.data.ndim, self["diag"]) + ) + return np.diag(self.data) if self["diag"] else self.data.copy() def __repr__(self): # noqa: D105 if self.data.ndim == 2: - s = 'size : %s x %s' % self.data.shape + s = "size : %s x %s" % self.data.shape else: # ndim == 1 - s = 'diagonal : %s' % self.data.size + s = "diagonal : %s" % self.data.size s += ", n_samples : %s" % self.nfree s += ", data : %s" % self.data return "" % s @@ -225,43 +292,74 @@ def __add__(self, cov): """Add Covariance taking into account number of degrees of freedom.""" _check_covs_algebra(self, cov) this_cov = cov.copy() - this_cov['data'] = (((this_cov['data'] * this_cov['nfree']) + - (self['data'] * self['nfree'])) / - (self['nfree'] + this_cov['nfree'])) - this_cov['nfree'] += self['nfree'] + this_cov["data"] = ( + (this_cov["data"] * this_cov["nfree"]) + (self["data"] * self["nfree"]) + ) / (self["nfree"] + this_cov["nfree"]) + this_cov["nfree"] += self["nfree"] - this_cov['bads'] = list(set(this_cov['bads']).union(self['bads'])) + this_cov["bads"] = list(set(this_cov["bads"]).union(self["bads"])) return this_cov def __iadd__(self, cov): """Add Covariance taking into account number of degrees of freedom.""" _check_covs_algebra(self, cov) - self['data'][:] = (((self['data'] * self['nfree']) + - (cov['data'] * cov['nfree'])) / - (self['nfree'] + cov['nfree'])) - self['nfree'] += cov['nfree'] + self["data"][:] = ( + (self["data"] * self["nfree"]) + (cov["data"] * cov["nfree"]) + ) / (self["nfree"] + cov["nfree"]) + self["nfree"] += cov["nfree"] - self['bads'] = list(set(self['bads']).union(cov['bads'])) + self["bads"] = list(set(self["bads"]).union(cov["bads"])) return self @verbose @copy_function_doc_to_method_doc(viz.misc.plot_cov) - def plot(self, info, exclude=[], colorbar=True, proj=False, show_svd=True, - show=True, verbose=None): - return viz.misc.plot_cov(self, info, exclude, colorbar, proj, show_svd, - show, verbose) + def plot( + self, + info, + exclude=[], + colorbar=True, + proj=False, + show_svd=True, + show=True, + verbose=None, + ): + return viz.misc.plot_cov( + self, info, exclude, colorbar, proj, show_svd, show, verbose + ) @verbose def plot_topomap( - self, info, ch_type=None, *, scalings=None, proj=False, - noise_cov=None, sensors=True, show_names=False, mask=None, - mask_params=None, contours=6, outlines='head', sphere=None, - image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, - size=1, cmap=None, vlim=(None, None), cnorm=None, colorbar=True, - cbar_fmt='%3.1f', units=None, axes=None, show=True, verbose=None): + self, + info, + ch_type=None, 
+ *, + scalings=None, + proj=False, + noise_cov=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap=None, + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + show=True, + verbose=None, + ): """Plot a topomap of the covariance diagonal. Parameters @@ -319,22 +417,40 @@ def plot_topomap( # entries is the same as multiplying twice evoked = whiten_evoked(whiten_evoked(evoked, noise_cov), noise_cov) if units is None: - units = 'AU' + units = "AU" if scalings is None: - scalings = 1. + scalings = 1.0 if units is None: - units = {k: f'({v})²' for k, v in DEFAULTS['units'].items()} + units = {k: f"({v})²" for k, v in DEFAULTS["units"].items()} if scalings is None: - scalings = {k: v * v for k, v in DEFAULTS['scalings'].items()} + scalings = {k: v * v for k, v in DEFAULTS["scalings"].items()} return evoked.plot_topomap( - times=[0], ch_type=ch_type, vlim=vlim, cmap=cmap, - sensors=sensors, cnorm=cnorm, colorbar=colorbar, scalings=scalings, - units=units, res=res, size=size, cbar_fmt=cbar_fmt, - proj=proj, show=show, show_names=show_names, - mask=mask, mask_params=mask_params, outlines=outlines, - contours=contours, image_interp=image_interp, axes=axes, - extrapolate=extrapolate, sphere=sphere, border=border, - time_format='') + times=[0], + ch_type=ch_type, + vlim=vlim, + cmap=cmap, + sensors=sensors, + cnorm=cnorm, + colorbar=colorbar, + scalings=scalings, + units=units, + res=res, + size=size, + cbar_fmt=cbar_fmt, + proj=proj, + show=show, + show_names=show_names, + mask=mask, + mask_params=mask_params, + outlines=outlines, + contours=contours, + image_interp=image_interp, + axes=axes, + extrapolate=extrapolate, + sphere=sphere, + border=border, + time_format="", + ) @verbose def pick_channels(self, ch_names, ordered=None, *, verbose=None): @@ -358,13 +474,15 @@ def pick_channels(self, ch_names, ordered=None, *, verbose=None): .. versionadded:: 0.20.0 """ - return pick_channels_cov(self, ch_names, exclude=[], ordered=ordered, - copy=False) + return pick_channels_cov( + self, ch_names, exclude=[], ordered=ordered, copy=False + ) ############################################################################### # IO + @verbose def read_cov(fname, verbose=None): """Read a noise covariance from a FIF file. @@ -385,18 +503,21 @@ def read_cov(fname, verbose=None): -------- write_cov, compute_covariance, compute_raw_covariance """ - check_fname(fname, 'covariance', ('-cov.fif', '-cov.fif.gz', - '_cov.fif', '_cov.fif.gz')) - fname = _check_fname(fname=fname, must_exist=True, overwrite='read') + check_fname( + fname, "covariance", ("-cov.fif", "-cov.fif.gz", "_cov.fif", "_cov.fif.gz") + ) + fname = _check_fname(fname=fname, must_exist=True, overwrite="read") f, tree, _ = fiff_open(fname) with f as fid: - return Covariance(**_read_cov(fid, tree, FIFF.FIFFV_MNE_NOISE_COV, - limited=True)) + return Covariance( + **_read_cov(fid, tree, FIFF.FIFFV_MNE_NOISE_COV, limited=True) + ) ############################################################################### # Estimate from data + @verbose def make_ad_hoc_cov(info, std=None, *, verbose=None): """Create an ad hoc noise covariance. @@ -423,33 +544,51 @@ def make_ad_hoc_cov(info, std=None, *, verbose=None): .. 
versionadded:: 0.9.0 """ picks = pick_types(info, meg=True, eeg=True, exclude=()) - std = _handle_default('noise_std', std) + std = _handle_default("noise_std", std) data = np.zeros(len(picks)) - for meg, eeg, val in zip(('grad', 'mag', False), (False, False, True), - (std['grad'], std['mag'], std['eeg'])): + for meg, eeg, val in zip( + ("grad", "mag", False), + (False, False, True), + (std["grad"], std["mag"], std["eeg"]), + ): these_picks = pick_types(info, meg=meg, eeg=eeg) data[np.searchsorted(picks, these_picks)] = val * val - ch_names = [info['ch_names'][pick] for pick in picks] - return Covariance(data, ch_names, info['bads'], info['projs'], nfree=0) + ch_names = [info["ch_names"][pick] for pick in picks] + return Covariance(data, ch_names, info["bads"], info["projs"], nfree=0) def _check_n_samples(n_samples, n_chan): """Check to see if there are enough samples for reliable cov calc.""" n_samples_min = 10 * (n_chan + 1) // 2 if n_samples <= 0: - raise ValueError('No samples found to compute the covariance matrix') + raise ValueError("No samples found to compute the covariance matrix") if n_samples < n_samples_min: - warn('Too few samples (required : %d got : %d), covariance ' - 'estimate may be unreliable' % (n_samples_min, n_samples)) + warn( + "Too few samples (required : %d got : %d), covariance " + "estimate may be unreliable" % (n_samples_min, n_samples) + ) @verbose -def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None, - flat=None, picks=None, method='empirical', - method_params=None, cv=3, scalings=None, - n_jobs=None, return_estimators=False, - reject_by_annotation=True, rank=None, verbose=None): +def compute_raw_covariance( + raw, + tmin=0, + tmax=None, + tstep=0.2, + reject=None, + flat=None, + picks=None, + method="empirical", + method_params=None, + cv=3, + scalings=None, + n_jobs=None, + return_estimators=False, + reject_by_annotation=True, + rank=None, + verbose=None, +): """Estimate noise covariance matrix from a continuous segment of raw data. It is typically useful to estimate a noise covariance from empty room @@ -557,31 +696,40 @@ def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None, baseline correction) subtract the mean across time *for each epoch* (instead of across epochs) for each channel. """ - tmin = 0. if tmin is None else float(tmin) - dt = 1. / raw.info['sfreq'] + tmin = 0.0 if tmin is None else float(tmin) + dt = 1.0 / raw.info["sfreq"] tmax = raw.times[-1] + dt if tmax is None else float(tmax) tstep = tmax - tmin if tstep is None else float(tstep) tstep_m1 = tstep - dt # inclusive! events = make_fixed_length_events(raw, 1, tmin, tmax, tstep) - logger.info('Using up to %s segment%s' % (len(events), _pl(events))) + logger.info("Using up to %s segment%s" % (len(events), _pl(events))) # don't exclude any bad channels, inverses expect all channels present if picks is None: # Need to include all channels e.g. 
if eog rejection is to be used - picks = np.arange(raw.info['nchan']) - pick_mask = np.in1d( - picks, _pick_data_channels(raw.info, with_ref_meg=False)) + picks = np.arange(raw.info["nchan"]) + pick_mask = np.in1d(picks, _pick_data_channels(raw.info, with_ref_meg=False)) else: pick_mask = slice(None) picks = _picks_to_idx(raw.info, picks) - epochs = Epochs(raw, events, 1, 0, tstep_m1, baseline=None, - picks=picks, reject=reject, flat=flat, - verbose=_verbose_safe_false(), - preload=False, proj=False, - reject_by_annotation=reject_by_annotation) + epochs = Epochs( + raw, + events, + 1, + 0, + tstep_m1, + baseline=None, + picks=picks, + reject=reject, + flat=flat, + verbose=_verbose_safe_false(), + preload=False, + proj=False, + reject_by_annotation=reject_by_annotation, + ) if method is None: - method = 'empirical' - if isinstance(method, str) and method == 'empirical': + method = "empirical" + if isinstance(method, str) and method == "empirical": # potentially *much* more memory efficient to do it the iterative way picks = picks[pick_mask] data = 0 @@ -595,13 +743,12 @@ def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None, n_samples += raw_segment.shape[1] _check_n_samples(n_samples, len(picks)) data -= mu[:, None] * (mu[None, :] / n_samples) - data /= (n_samples - 1.0) + data /= n_samples - 1.0 logger.info("Number of samples used : %d" % n_samples) - logger.info('[done]') - ch_names = [raw.info['ch_names'][k] for k in picks] - bads = [b for b in raw.info['bads'] if b in ch_names] - return Covariance(data, ch_names, bads, raw.info['projs'], - nfree=n_samples - 1) + logger.info("[done]") + ch_names = [raw.info["ch_names"][k] for k in picks] + bads = [b for b in raw.info["bads"] if b in ch_names] + return Covariance(data, ch_names, bads, raw.info["projs"], nfree=n_samples - 1) del picks, pick_mask # This makes it equivalent to what we used to do (and do above for @@ -611,85 +758,130 @@ def compute_raw_covariance(raw, tmin=0, tmax=None, tstep=0.2, reject=None, epochs._data -= ch_means[np.newaxis, :, np.newaxis] # fake this value so there are no complaints from compute_covariance epochs.baseline = (None, None) - return compute_covariance(epochs, keep_sample_mean=True, method=method, - method_params=method_params, cv=cv, - scalings=scalings, n_jobs=n_jobs, - return_estimators=return_estimators, - rank=rank) - - -def _check_method_params(method, method_params, keep_sample_mean=True, - name='method', allow_auto=True, rank=None): + return compute_covariance( + epochs, + keep_sample_mean=True, + method=method, + method_params=method_params, + cv=cv, + scalings=scalings, + n_jobs=n_jobs, + return_estimators=return_estimators, + rank=rank, + ) + + +def _check_method_params( + method, + method_params, + keep_sample_mean=True, + name="method", + allow_auto=True, + rank=None, +): """Check that method and method_params are usable.""" - accepted_methods = ('auto', 'empirical', 'diagonal_fixed', 'ledoit_wolf', - 'oas', 'shrunk', 'pca', 'factor_analysis', 'shrinkage') + accepted_methods = ( + "auto", + "empirical", + "diagonal_fixed", + "ledoit_wolf", + "oas", + "shrunk", + "pca", + "factor_analysis", + "shrinkage", + ) _method_params = { - 'empirical': {'store_precision': False, 'assume_centered': True}, - 'diagonal_fixed': {'store_precision': False, 'assume_centered': True}, - 'ledoit_wolf': {'store_precision': False, 'assume_centered': True}, - 'oas': {'store_precision': False, 'assume_centered': True}, - 'shrinkage': {'shrinkage': 0.1, 'store_precision': False, - 'assume_centered': 
True}, - 'shrunk': {'shrinkage': np.logspace(-4, 0, 30), - 'store_precision': False, 'assume_centered': True}, - 'pca': {'iter_n_components': None}, - 'factor_analysis': {'iter_n_components': None} + "empirical": {"store_precision": False, "assume_centered": True}, + "diagonal_fixed": {"store_precision": False, "assume_centered": True}, + "ledoit_wolf": {"store_precision": False, "assume_centered": True}, + "oas": {"store_precision": False, "assume_centered": True}, + "shrinkage": { + "shrinkage": 0.1, + "store_precision": False, + "assume_centered": True, + }, + "shrunk": { + "shrinkage": np.logspace(-4, 0, 30), + "store_precision": False, + "assume_centered": True, + }, + "pca": {"iter_n_components": None}, + "factor_analysis": {"iter_n_components": None}, } for ch_type in _DATA_CH_TYPES_SPLIT: - _method_params['diagonal_fixed'][ch_type] = 0.1 + _method_params["diagonal_fixed"][ch_type] = 0.1 if isinstance(method_params, dict): for key, values in method_params.items(): if key not in _method_params: - raise ValueError('key (%s) must be "%s"' % - (key, '" or "'.join(_method_params))) + raise ValueError( + 'key (%s) must be "%s"' % (key, '" or "'.join(_method_params)) + ) _method_params[key].update(method_params[key]) - shrinkage = method_params.get('shrinkage', {}).get('shrinkage', 0.1) + shrinkage = method_params.get("shrinkage", {}).get("shrinkage", 0.1) if not 0 <= shrinkage <= 1: - raise ValueError('shrinkage must be between 0 and 1, got %s' - % (shrinkage,)) + raise ValueError("shrinkage must be between 0 and 1, got %s" % (shrinkage,)) was_auto = False if method is None: - method = ['empirical'] - elif method == 'auto' and allow_auto: + method = ["empirical"] + elif method == "auto" and allow_auto: was_auto = True - method = ['shrunk', 'diagonal_fixed', 'empirical', 'factor_analysis'] + method = ["shrunk", "diagonal_fixed", "empirical", "factor_analysis"] if not isinstance(method, (list, tuple)): method = [method] if not all(k in accepted_methods for k in method): raise ValueError( - 'Invalid {name} ({method}). Accepted values (individually or ' + "Invalid {name} ({method}). 
Accepted values (individually or " 'in a list) are any of "{accepted_methods}" or None.'.format( - name=name, method=method, accepted_methods=accepted_methods)) - if not (isinstance(rank, str) and rank == 'full'): + name=name, method=method, accepted_methods=accepted_methods + ) + ) + if not (isinstance(rank, str) and rank == "full"): if was_auto: - method.pop(method.index('factor_analysis')) + method.pop(method.index("factor_analysis")) for method_ in method: - if method_ in ('pca', 'factor_analysis'): - raise ValueError('%s can so far only be used with rank="full",' - ' got rank=%r' % (method_, rank)) + if method_ in ("pca", "factor_analysis"): + raise ValueError( + '%s can so far only be used with rank="full",' + " got rank=%r" % (method_, rank) + ) if not keep_sample_mean: - if len(method) != 1 or 'empirical' not in method: - raise ValueError('`keep_sample_mean=False` is only supported' - 'with %s="empirical"' % (name,)) + if len(method) != 1 or "empirical" not in method: + raise ValueError( + "`keep_sample_mean=False` is only supported" + 'with %s="empirical"' % (name,) + ) for p, v in _method_params.items(): - if v.get('assume_centered', None) is False: - raise ValueError('`assume_centered` must be True' - ' if `keep_sample_mean` is False') + if v.get("assume_centered", None) is False: + raise ValueError( + "`assume_centered` must be True" " if `keep_sample_mean` is False" + ) return method, _method_params @verbose -def compute_covariance(epochs, keep_sample_mean=True, tmin=None, tmax=None, - projs=None, method='empirical', method_params=None, - cv=3, scalings=None, n_jobs=None, - return_estimators=False, on_mismatch='raise', - rank=None, verbose=None): +def compute_covariance( + epochs, + keep_sample_mean=True, + tmin=None, + tmax=None, + projs=None, + method="empirical", + method_params=None, + cv=3, + scalings=None, + n_jobs=None, + return_estimators=False, + on_mismatch="raise", + rank=None, + verbose=None, +): """Estimate noise covariance matrix from epochs. 
The noise covariance is typically estimated on pre-stimulus periods @@ -859,7 +1051,8 @@ def compute_covariance(epochs, keep_sample_mean=True, tmin=None, tmax=None, # scale to natural unit for best stability with MEG/EEG scalings = _check_scalings_user(scalings) method, _method_params = _check_method_params( - method, method_params, keep_sample_mean, rank=rank) + method, method_params, keep_sample_mean, rank=rank + ) del method_params # for multi condition support epochs is required to refer to a list of @@ -878,43 +1071,49 @@ def _unpack_epochs(epochs): epochs = sum([_unpack_epochs(epoch) for epoch in epochs], []) # check for baseline correction - if any(epochs_t.baseline is None and epochs_t.info['highpass'] < 0.5 and - keep_sample_mean for epochs_t in epochs): - warn('Epochs are not baseline corrected, covariance ' - 'matrix may be inaccurate') - - orig = epochs[0].info['dev_head_t'] - _check_on_missing(on_mismatch, 'on_mismatch') + if any( + epochs_t.baseline is None + and epochs_t.info["highpass"] < 0.5 + and keep_sample_mean + for epochs_t in epochs + ): + warn( + "Epochs are not baseline corrected, covariance " "matrix may be inaccurate" + ) + + orig = epochs[0].info["dev_head_t"] + _check_on_missing(on_mismatch, "on_mismatch") for ei, epoch in enumerate(epochs): epoch.info._check_consistency() - if (orig is None) != (epoch.info['dev_head_t'] is None) or \ - (orig is not None and not - np.allclose(orig['trans'], - epoch.info['dev_head_t']['trans'])): - msg = ('MEG<->Head transform mismatch between epochs[0]:\n%s\n\n' - 'and epochs[%s]:\n%s' - % (orig, ei, epoch.info['dev_head_t'])) - _on_missing(on_mismatch, msg, 'on_mismatch') - - bads = epochs[0].info['bads'] + if (orig is None) != (epoch.info["dev_head_t"] is None) or ( + orig is not None + and not np.allclose(orig["trans"], epoch.info["dev_head_t"]["trans"]) + ): + msg = ( + "MEG<->Head transform mismatch between epochs[0]:\n%s\n\n" + "and epochs[%s]:\n%s" % (orig, ei, epoch.info["dev_head_t"]) + ) + _on_missing(on_mismatch, msg, "on_mismatch") + + bads = epochs[0].info["bads"] if projs is None: - projs = epochs[0].info['projs'] + projs = epochs[0].info["projs"] # make sure Epochs are compatible for epochs_t in epochs[1:]: if epochs_t.proj != epochs[0].proj: - raise ValueError('Epochs must agree on the use of projections') - for proj_a, proj_b in zip(epochs_t.info['projs'], projs): + raise ValueError("Epochs must agree on the use of projections") + for proj_a, proj_b in zip(epochs_t.info["projs"], projs): if not _proj_equal(proj_a, proj_b): - raise ValueError('Epochs must have same projectors') + raise ValueError("Epochs must have same projectors") projs = _check_projs(projs) ch_names = epochs[0].ch_names # make sure Epochs are compatible for epochs_t in epochs[1:]: - if epochs_t.info['bads'] != bads: - raise ValueError('Epochs must have same bad channels') + if epochs_t.info["bads"] != bads: + raise ValueError("Epochs must have same bad channels") if epochs_t.ch_names != ch_names: - raise ValueError('Epochs must have same channel names') + raise ValueError("Epochs must have same channel names") picks_list = _picks_by_type(epochs[0].info) picks_meeg = np.concatenate([b for _, b in picks_list]) picks_meeg = np.sort(picks_meeg) @@ -929,7 +1128,6 @@ def _unpack_epochs(epochs): n_epochs = np.zeros(n_epoch_types, dtype=np.int64) for ii, epochs_t in enumerate(epochs): - tslice = _get_tslice(epochs_t, tmin, tmax) for e in epochs_t: e = e[picks_meeg, tslice] @@ -940,8 +1138,10 @@ def _unpack_epochs(epochs): n_samples_epoch = n_samples 
// n_epochs norm_const = np.sum(n_samples_epoch * (n_epochs - 1)) - data_mean = [1.0 / n_epoch * np.dot(mean, mean.T) for n_epoch, mean - in zip(n_epochs, data_mean)] + data_mean = [ + 1.0 / n_epoch * np.dot(mean, mean.T) + for n_epoch, mean in zip(n_epochs, data_mean) + ] info = pick_info(info, picks_meeg) tslice = _get_tslice(epochs[0], tmin, tmax) @@ -960,14 +1160,22 @@ def _unpack_epochs(epochs): epochs = epochs.T # sklearn | C-order cov_data = _compute_covariance_auto( - epochs, method=method, method_params=_method_params, info=info, - cv=cv, n_jobs=n_jobs, stop_early=True, picks_list=picks_list, - scalings=scalings, rank=rank) + epochs, + method=method, + method_params=_method_params, + info=info, + cv=cv, + n_jobs=n_jobs, + stop_early=True, + picks_list=picks_list, + scalings=scalings, + rank=rank, + ) if keep_sample_mean is False: - cov = cov_data['empirical']['data'] + cov = cov_data["empirical"]["data"] # undo scaling - cov *= (n_samples_tot - 1) + cov *= n_samples_tot - 1 # ... apply pre-computed class-wise normalization for mean_cov in data_mean: cov -= mean_cov @@ -975,28 +1183,29 @@ def _unpack_epochs(epochs): covs = list() for this_method, data in cov_data.items(): - cov = Covariance(data.pop('data'), ch_names, info['bads'], projs, - nfree=n_samples_tot - 1) + cov = Covariance( + data.pop("data"), ch_names, info["bads"], projs, nfree=n_samples_tot - 1 + ) # add extra info cov.update(method=this_method, **data) covs.append(cov) - logger.info('Number of samples used : %d' % n_samples_tot) - covs.sort(key=lambda c: c['loglik'], reverse=True) + logger.info("Number of samples used : %d" % n_samples_tot) + covs.sort(key=lambda c: c["loglik"], reverse=True) if len(covs) > 1: - msg = ['log-likelihood on unseen data (descending order):'] + msg = ["log-likelihood on unseen data (descending order):"] for c in covs: - msg.append('%s: %0.3f' % (c['method'], c['loglik'])) - logger.info('\n '.join(msg)) + msg.append("%s: %0.3f" % (c["method"], c["loglik"])) + logger.info("\n ".join(msg)) if return_estimators: out = covs else: out = covs[0] - logger.info('selecting best estimator: {}'.format(out['method'])) + logger.info("selecting best estimator: {}".format(out["method"])) else: out = covs[0] - logger.info('[done]') + logger.info("[done]") return out @@ -1004,11 +1213,12 @@ def _unpack_epochs(epochs): def _check_scalings_user(scalings): if isinstance(scalings, dict): for k, v in scalings.items(): - _check_option('the keys in `scalings`', k, ['mag', 'grad', 'eeg']) + _check_option("the keys in `scalings`", k, ["mag", "grad", "eeg"]) elif scalings is not None and not isinstance(scalings, np.ndarray): - raise TypeError('scalings must be a dict, ndarray, or None, got %s' - % type(scalings)) - scalings = _handle_default('scalings', scalings) + raise TypeError( + "scalings must be a dict, ndarray, or None, got %s" % type(scalings) + ) + scalings = _handle_default("scalings", scalings) return scalings @@ -1021,33 +1231,49 @@ def _eigvec_subspace(eig, eigvec, mask): return eig, eigvec -def _compute_covariance_auto(data, method, info, method_params, cv, - scalings, n_jobs, stop_early, picks_list, rank): +def _compute_covariance_auto( + data, + method, + info, + method_params, + cv, + scalings, + n_jobs, + stop_early, + picks_list, + rank, +): """Compute covariance auto mode.""" # rescale to improve numerical stability orig_rank = rank rank = compute_rank( RawArray(data.T, info, copy=None, verbose=_verbose_safe_false()), - rank, scalings, info) + rank, + scalings, + info, + ) with 
_scaled_array(data.T, picks_list, scalings): C = np.dot(data.T, data) - _, eigvec, mask = _smart_eigh(C, info, rank, proj_subspace=True, - do_compute_rank=False) + _, eigvec, mask = _smart_eigh( + C, info, rank, proj_subspace=True, do_compute_rank=False + ) eigvec = eigvec[mask] data = np.dot(data, eigvec.T) used = np.where(mask)[0] - sub_picks_list = [(key, np.searchsorted(used, picks)) - for key, picks in picks_list] + sub_picks_list = [ + (key, np.searchsorted(used, picks)) for key, picks in picks_list + ] sub_info = pick_info(info, used) if len(used) != len(mask) else info - logger.info('Reducing data rank from %s -> %s' - % (len(mask), eigvec.shape[0])) + logger.info("Reducing data rank from %s -> %s" % (len(mask), eigvec.shape[0])) estimator_cov_info = list() - msg = 'Estimating covariance using %s' + msg = "Estimating covariance using %s" - ok_sklearn = check_version('sklearn') - if not ok_sklearn and (len(method) != 1 or method[0] != 'empirical'): - raise ValueError('scikit-learn is not installed, `method` must be ' - '`empirical`, got %s' % (method,)) + ok_sklearn = check_version("sklearn") + if not ok_sklearn and (len(method) != 1 or method[0] != "empirical"): + raise ValueError( + "scikit-learn is not installed, `method` must be " + "`empirical`, got %s" % (method,) + ) for method_ in method: data_ = data.copy() @@ -1056,20 +1282,21 @@ def _compute_covariance_auto(data, method, info, method_params, cv, mp = method_params[method_] _info = {} - if method_ == 'empirical': + if method_ == "empirical": est = EmpiricalCovariance(**mp) est.fit(data_) estimator_cov_info.append((est, est.covariance_, _info)) del est - elif method_ == 'diagonal_fixed': + elif method_ == "diagonal_fixed": est = _RegCovariance(info=sub_info, **mp) est.fit(data_) estimator_cov_info.append((est, est.covariance_, _info)) del est - elif method_ == 'ledoit_wolf': + elif method_ == "ledoit_wolf": from sklearn.covariance import LedoitWolf + shrinkages = [] lw = LedoitWolf(**mp) @@ -1081,8 +1308,9 @@ def _compute_covariance_auto(data, method, info, method_params, cv, estimator_cov_info.append((sc, sc.covariance_, _info)) del lw, sc - elif method_ == 'oas': + elif method_ == "oas": from sklearn.covariance import OAS + shrinkages = [] oas = OAS(**mp) @@ -1094,58 +1322,65 @@ def _compute_covariance_auto(data, method, info, method_params, cv, estimator_cov_info.append((sc, sc.covariance_, _info)) del oas, sc - elif method_ == 'shrinkage': + elif method_ == "shrinkage": sc = _ShrunkCovariance(**mp) sc.fit(data_) estimator_cov_info.append((sc, sc.covariance_, _info)) del sc - elif method_ == 'shrunk': + elif method_ == "shrunk": from sklearn.model_selection import GridSearchCV from sklearn.covariance import ShrunkCovariance - shrinkage = mp.pop('shrinkage') - tuned_parameters = [{'shrinkage': shrinkage}] + + shrinkage = mp.pop("shrinkage") + tuned_parameters = [{"shrinkage": shrinkage}] shrinkages = [] - gs = GridSearchCV(ShrunkCovariance(**mp), - tuned_parameters, cv=cv) + gs = GridSearchCV(ShrunkCovariance(**mp), tuned_parameters, cv=cv) for ch_type, picks in sub_picks_list: gs.fit(data_[:, picks]) - shrinkages.append((ch_type, gs.best_estimator_.shrinkage, - picks)) + shrinkages.append((ch_type, gs.best_estimator_.shrinkage, picks)) shrinkages = [c[0] for c in zip(shrinkages)] sc = _ShrunkCovariance(shrinkage=shrinkages, **mp) sc.fit(data_) estimator_cov_info.append((sc, sc.covariance_, _info)) del shrinkage, sc - elif method_ == 'pca': - assert orig_rank == 'full' + elif method_ == "pca": + assert orig_rank == "full" 
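# Each method_ branch in this chain wraps a scikit-learn-style estimator and
# feeds it into the same cross-validated competition. From the public API the
# loop is driven roughly like this (a minimal sketch; `epochs` stands for any
# baseline-corrected mne.Epochs object):
#
#     import mne
#     covs = mne.compute_covariance(
#         epochs, tmin=None, tmax=0.0,
#         method=["empirical", "shrunk", "diagonal_fixed"],
#         return_estimators=True,  # return all fits, ranked by held-out loglik
#     )
#     best_cov = covs[0]  # estimators come back sorted, best first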
pca, _info = _auto_low_rank_model( - data_, method_, n_jobs=n_jobs, method_params=mp, cv=cv, - stop_early=stop_early) + data_, + method_, + n_jobs=n_jobs, + method_params=mp, + cv=cv, + stop_early=stop_early, + ) pca.fit(data_) estimator_cov_info.append((pca, pca.get_covariance(), _info)) del pca - elif method_ == 'factor_analysis': - assert orig_rank == 'full' + elif method_ == "factor_analysis": + assert orig_rank == "full" fa, _info = _auto_low_rank_model( - data_, method_, n_jobs=n_jobs, method_params=mp, cv=cv, - stop_early=stop_early) + data_, + method_, + n_jobs=n_jobs, + method_params=mp, + cv=cv, + stop_early=stop_early, + ) fa.fit(data_) estimator_cov_info.append((fa, fa.get_covariance(), _info)) del fa else: - raise ValueError('Oh no! Your estimator does not have' - ' a .fit method') - logger.info('Done.') + raise ValueError("Oh no! Your estimator does not have" " a .fit method") + logger.info("Done.") if len(method) > 1: - logger.info('Using cross-validation to select the best estimator.') + logger.info("Using cross-validation to select the best estimator.") out = dict() - for ei, (estimator, cov, runtime_info) in \ - enumerate(estimator_cov_info): + for ei, (estimator, cov, runtime_info) in enumerate(estimator_cov_info): if len(method) > 1: loglik = _cross_val(data, estimator, cv, n_jobs) else: @@ -1169,8 +1404,8 @@ def _gaussian_loglik_scorer(est, X, y=None): # compute empirical covariance of the test set precision = est.get_precision() n_samples, n_features = X.shape - log_like = -.5 * (X * (np.dot(X, precision))).sum(axis=1) - log_like -= .5 * (n_features * log(2. * np.pi) - _logdet(precision)) + log_like = -0.5 * (X * (np.dot(X, precision))).sum(axis=1) + log_like -= 0.5 * (n_features * log(2.0 * np.pi) - _logdet(precision)) out = np.mean(log_like) return out @@ -1178,22 +1413,28 @@ def _gaussian_loglik_scorer(est, X, y=None): def _cross_val(data, est, cv, n_jobs): """Compute cross validation.""" from sklearn.model_selection import cross_val_score - return np.mean(cross_val_score(est, data, cv=cv, n_jobs=n_jobs, - scoring=_gaussian_loglik_scorer)) + return np.mean( + cross_val_score( + est, data, cv=cv, n_jobs=n_jobs, scoring=_gaussian_loglik_scorer + ) + ) -def _auto_low_rank_model(data, mode, n_jobs, method_params, cv, - stop_early=True, verbose=None): + +def _auto_low_rank_model( + data, mode, n_jobs, method_params, cv, stop_early=True, verbose=None +): """Compute latent variable models.""" method_params = deepcopy(method_params) - iter_n_components = method_params.pop('iter_n_components') + iter_n_components = method_params.pop("iter_n_components") if iter_n_components is None: iter_n_components = np.arange(5, data.shape[1], 5) from sklearn.decomposition import PCA, FactorAnalysis - if mode == 'factor_analysis': + + if mode == "factor_analysis": est = FactorAnalysis else: - assert mode == 'pca' + assert mode == "pca" est = PCA est = est(**method_params) est.n_components = 1 @@ -1203,8 +1444,10 @@ def _auto_low_rank_model(data, mode, n_jobs, method_params, cv, # make sure we don't empty the thing if it's a generator max_n = max(list(deepcopy(iter_n_components))) if max_n > data.shape[1]: - warn('You are trying to estimate %i components on matrix ' - 'with %i features.' % (max_n, data.shape[1])) + warn( + "You are trying to estimate %i components on matrix " + "with %i features." 
% (max_n, data.shape[1]) + ) for ii, n in enumerate(iter_n_components): est.n_components = n @@ -1213,30 +1456,34 @@ def _auto_low_rank_model(data, mode, n_jobs, method_params, cv, except ValueError: score = np.inf if np.isinf(score) or score > 0: - logger.info('... infinite values encountered. stopping estimation') + logger.info("... infinite values encountered. stopping estimation") break - logger.info('... rank: %i - loglik: %0.3f' % (n, score)) + logger.info("... rank: %i - loglik: %0.3f" % (n, score)) if score != -np.inf: scores[ii] = score - if (ii >= 3 and np.all(np.diff(scores[ii - 3:ii]) < 0) and stop_early): + if ii >= 3 and np.all(np.diff(scores[ii - 3 : ii]) < 0) and stop_early: # early stop search when loglik has been going down 3 times - logger.info('early stopping parameter search.') + logger.info("early stopping parameter search.") break # happens if rank is too low right from the beginning if np.isnan(scores).all(): - raise RuntimeError('Oh no! Could not estimate covariance because all ' - 'scores were NaN. Please contact the MNE-Python ' - 'developers.') + raise RuntimeError( + "Oh no! Could not estimate covariance because all " + "scores were NaN. Please contact the MNE-Python " + "developers." + ) i_score = np.nanargmax(scores) best = est.n_components = iter_n_components[i_score] - logger.info('... best model at rank = %i' % best) - runtime_info = {'ranks': np.array(iter_n_components), - 'scores': scores, - 'best': best, - 'cv': cv} + logger.info("... best model at rank = %i" % best) + runtime_info = { + "ranks": np.array(iter_n_components), + "scores": scores, + "best": best, + "cv": cv, + } return est, runtime_info @@ -1247,11 +1494,25 @@ def _auto_low_rank_model(data, mode, n_jobs, method_params, cv, class _RegCovariance(BaseEstimator): """Aux class.""" - def __init__(self, info, grad=0.1, mag=0.1, eeg=0.1, seeg=0.1, - ecog=0.1, hbo=0.1, hbr=0.1, fnirs_cw_amplitude=0.1, - fnirs_fd_ac_amplitude=0.1, fnirs_fd_phase=0.1, fnirs_od=0.1, - csd=0.1, dbs=0.1, store_precision=False, - assume_centered=False): + def __init__( + self, + info, + grad=0.1, + mag=0.1, + eeg=0.1, + seeg=0.1, + ecog=0.1, + hbo=0.1, + hbr=0.1, + fnirs_cw_amplitude=0.1, + fnirs_fd_ac_amplitude=0.1, + fnirs_fd_phase=0.1, + fnirs_od=0.1, + csd=0.1, + dbs=0.1, + store_precision=False, + assume_centered=False, + ): self.info = info # For sklearn compat, these cannot (easily?)
be combined into # a single dictionary @@ -1274,20 +1535,33 @@ def __init__(self, info, grad=0.1, mag=0.1, eeg=0.1, seeg=0.1, def fit(self, X): """Fit covariance model with classical diagonal regularization.""" self.estimator_ = EmpiricalCovariance( - store_precision=self.store_precision, - assume_centered=self.assume_centered) + store_precision=self.store_precision, assume_centered=self.assume_centered + ) self.covariance_ = self.estimator_.fit(X).covariance_ self.covariance_ = 0.5 * (self.covariance_ + self.covariance_.T) cov_ = Covariance( - data=self.covariance_, names=self.info['ch_names'], - bads=self.info['bads'], projs=self.info['projs'], - nfree=len(self.covariance_)) + data=self.covariance_, + names=self.info["ch_names"], + bads=self.info["bads"], + projs=self.info["projs"], + nfree=len(self.covariance_), + ) cov_ = regularize( - cov_, self.info, proj=False, exclude='bads', - grad=self.grad, mag=self.mag, eeg=self.eeg, - ecog=self.ecog, seeg=self.seeg, dbs=self.dbs, - hbo=self.hbo, hbr=self.hbr, rank='full') + cov_, + self.info, + proj=False, + exclude="bads", + grad=self.grad, + mag=self.mag, + eeg=self.eeg, + ecog=self.ecog, + seeg=self.seeg, + dbs=self.dbs, + hbo=self.hbo, + hbr=self.hbr, + rank="full", + ) self.estimator_.covariance_ = self.covariance_ = cov_.data return self @@ -1303,9 +1577,7 @@ def get_precision(self): class _ShrunkCovariance(BaseEstimator): """Aux class.""" - def __init__(self, store_precision, assume_centered, - shrinkage=0.1): - + def __init__(self, store_precision, assume_centered, shrinkage=0.1): self.store_precision = store_precision self.assume_centered = assume_centered self.shrinkage = shrinkage @@ -1313,14 +1585,15 @@ def __init__(self, store_precision, assume_centered, def fit(self, X): """Fit covariance model with oracle shrinkage regularization.""" from sklearn.covariance import shrunk_covariance + self.estimator_ = EmpiricalCovariance( - store_precision=self.store_precision, - assume_centered=self.assume_centered) + store_precision=self.store_precision, assume_centered=self.assume_centered + ) cov = self.estimator_.fit(X).covariance_ if not isinstance(self.shrinkage, (list, tuple)): - shrinkage = [('all', self.shrinkage, np.arange(len(cov)))] + shrinkage = [("all", self.shrinkage, np.arange(len(cov)))] else: shrinkage = self.shrinkage @@ -1328,7 +1601,7 @@ def fit(self, X): for a, b in itt.combinations(shrinkage, 2): picks_i, picks_j = a[2], b[2] ch_ = a[0], b[0] - if 'eeg' in ch_: + if "eeg" in ch_: zero_cross_cov[np.ix_(picks_i, picks_j)] = True zero_cross_cov[np.ix_(picks_j, picks_i)] = True @@ -1337,14 +1610,13 @@ def fit(self, X): # Apply shrinkage to blocks for ch_type, c, picks in shrinkage: sub_cov = cov[np.ix_(picks, picks)] - cov[np.ix_(picks, picks)] = shrunk_covariance(sub_cov, - shrinkage=c) + cov[np.ix_(picks, picks)] = shrunk_covariance(sub_cov, shrinkage=c) # Apply shrinkage to cross-cov for a, b in itt.combinations(shrinkage, 2): shrinkage_i, shrinkage_j = a[1], b[1] picks_i, picks_j = a[2], b[2] - c_ij = np.sqrt((1. - shrinkage_i) * (1. 
- shrinkage_j)) + c_ij = np.sqrt((1.0 - shrinkage_i) * (1.0 - shrinkage_j)) cov[np.ix_(picks_i, picks_j)] *= c_ij cov[np.ix_(picks_j, picks_i)] *= c_ij @@ -1358,10 +1630,11 @@ def fit(self, X): def score(self, X_test, y=None): """Delegate to modified EmpiricalCovariance instance.""" # compute empirical covariance of the test set - test_cov = empirical_covariance(X_test - self.estimator_.location_, - assume_centered=True) + test_cov = empirical_covariance( + X_test - self.estimator_.location_, assume_centered=True + ) if np.any(self.zero_cross_cov_): - test_cov[self.zero_cross_cov_] = 0. + test_cov[self.zero_cross_cov_] = 0.0 res = log_likelihood(test_cov, self.estimator_.get_precision()) return res @@ -1373,6 +1646,7 @@ def get_precision(self): ############################################################################### # Writing + @verbose def write_cov(fname, cov, *, overwrite=False, verbose=None): """Write a noise covariance matrix. @@ -1399,6 +1673,7 @@ def write_cov(fname, cov, *, overwrite=False, verbose=None): ############################################################################### # Prepare for inverse modeling + def _unpack_epochs(epochs): """Aux Function.""" if len(epochs.event_id) > 1: @@ -1418,8 +1693,10 @@ def _get_ch_whitener(A, pca, ch_type, rank): eig[:-rank] = 0.0 mask[:-rank] = False - logger.info(' Setting small %s eigenvalues to zero (%s)' - % (ch_type, 'using PCA' if pca else 'without PCA')) + logger.info( + " Setting small %s eigenvalues to zero (%s)" + % (ch_type, "using PCA" if pca else "without PCA") + ) if pca: # No PCA case. # This line will reduce the actual number of variables in data # and leadfield to the true rank. @@ -1428,8 +1705,15 @@ def _get_ch_whitener(A, pca, ch_type, rank): @verbose -def prepare_noise_cov(noise_cov, info, ch_names=None, rank=None, - scalings=None, on_rank_mismatch='ignore', verbose=None): +def prepare_noise_cov( + noise_cov, + info, + ch_names=None, + rank=None, + scalings=None, + on_rank_mismatch="ignore", + verbose=None, +): """Prepare noise covariance matrix. 
Parameters @@ -1461,7 +1745,7 @@ def prepare_noise_cov(noise_cov, info, ch_names=None, rank=None, # reorder C and info to match ch_names order noise_cov_idx = list() missing = list() - ch_names = info['ch_names'] if ch_names is None else ch_names + ch_names = info["ch_names"] if ch_names is None else ch_names for c in ch_names: # this could be try/except ValueError, but it is not the preferred way if c in noise_cov.ch_names: @@ -1469,51 +1753,71 @@ def prepare_noise_cov(noise_cov, info, ch_names=None, rank=None, else: missing.append(c) if len(missing): - raise RuntimeError('Not all channels present in noise covariance:\n%s' - % missing) + raise RuntimeError( + "Not all channels present in noise covariance:\n%s" % missing + ) C = noise_cov._get_square()[np.ix_(noise_cov_idx, noise_cov_idx)] - info = pick_info( - info, pick_channels(info['ch_names'], ch_names, ordered=False)) - projs = info['projs'] + noise_cov['projs'] + info = pick_info(info, pick_channels(info["ch_names"], ch_names, ordered=False)) + projs = info["projs"] + noise_cov["projs"] noise_cov = Covariance( - data=C, names=ch_names, bads=list(noise_cov['bads']), - projs=deepcopy(noise_cov['projs']), nfree=noise_cov['nfree'], - method=noise_cov.get('method', None), - loglik=noise_cov.get('loglik', None)) - - eig, eigvec, _ = _smart_eigh(noise_cov, info, rank, scalings, projs, - ch_names, on_rank_mismatch=on_rank_mismatch) + data=C, + names=ch_names, + bads=list(noise_cov["bads"]), + projs=deepcopy(noise_cov["projs"]), + nfree=noise_cov["nfree"], + method=noise_cov.get("method", None), + loglik=noise_cov.get("loglik", None), + ) + + eig, eigvec, _ = _smart_eigh( + noise_cov, + info, + rank, + scalings, + projs, + ch_names, + on_rank_mismatch=on_rank_mismatch, + ) noise_cov.update(eig=eig, eigvec=eigvec) return noise_cov @verbose -def _smart_eigh(C, info, rank, scalings=None, projs=None, - ch_names=None, proj_subspace=False, do_compute_rank=True, - on_rank_mismatch='ignore', verbose=None): +def _smart_eigh( + C, + info, + rank, + scalings=None, + projs=None, + ch_names=None, + proj_subspace=False, + do_compute_rank=True, + on_rank_mismatch="ignore", + verbose=None, +): """Compute eigh of C taking into account rank and ch_type scalings.""" - scalings = _handle_default('scalings_cov_rank', scalings) - projs = info['projs'] if projs is None else projs - ch_names = info['ch_names'] if ch_names is None else ch_names - if info['ch_names'] != ch_names: - info = pick_info(info, [info['ch_names'].index(c) for c in ch_names]) - assert info['ch_names'] == ch_names + scalings = _handle_default("scalings_cov_rank", scalings) + projs = info["projs"] if projs is None else projs + ch_names = info["ch_names"] if ch_names is None else ch_names + if info["ch_names"] != ch_names: + info = pick_info(info, [info["ch_names"].index(c) for c in ch_names]) + assert info["ch_names"] == ch_names n_chan = len(ch_names) # Create the projection operator proj, ncomp, _ = make_projector(projs, ch_names) if isinstance(C, Covariance): - C = C['data'] + C = C["data"] if ncomp > 0: - logger.info(' Created an SSP operator (subspace dimension = %d)' - % ncomp) + logger.info(" Created an SSP operator (subspace dimension = %d)" % ncomp) C = np.dot(proj, np.dot(C, proj.T)) noise_cov = Covariance(C, ch_names, [], projs, 0) if do_compute_rank: # if necessary rank = compute_rank( - noise_cov, rank, scalings, info, on_rank_mismatch=on_rank_mismatch) + noise_cov, rank, scalings, info, on_rank_mismatch=on_rank_mismatch + ) assert C.ndim == 2 and C.shape[0] == C.shape[1] # time 
saving short-circuit @@ -1524,14 +1828,15 @@ def _smart_eigh(C, info, rank, scalings=None, projs=None, eig = np.zeros(n_chan, dtype) eigvec = np.zeros((n_chan, n_chan), dtype) mask = np.zeros(n_chan, bool) - for ch_type, picks in _picks_by_type(info, meg_combined=True, - ref_meg=False, exclude=[]): + for ch_type, picks in _picks_by_type( + info, meg_combined=True, ref_meg=False, exclude=[] + ): if len(picks) == 0: continue this_C = C[np.ix_(picks, picks)] - if ch_type not in rank and ch_type in ('mag', 'grad'): - this_rank = rank['meg'] # if there is only one or the other + if ch_type not in rank and ch_type in ("mag", "grad"): + this_rank = rank["meg"] # if there is only one or the other else: this_rank = rank[ch_type] @@ -1541,21 +1846,43 @@ def _smart_eigh(C, info, rank, scalings=None, projs=None, e, ev = _eigvec_subspace(e, ev, m) eig[picks], eigvec[np.ix_(picks, picks)], mask[picks] = e, ev, m # XXX : also handle ref for sEEG and ECoG - if ch_type == 'eeg' and _needs_eeg_average_ref_proj(info) and not \ _has_eeg_average_ref_proj(info, projs=projs): - warn('No average EEG reference present in info["projs"], ' - 'covariance may be adversely affected. Consider recomputing ' - 'covariance using with an average eeg reference projector ' - 'added.') + if ( + ch_type == "eeg" + and _needs_eeg_average_ref_proj(info) + and not _has_eeg_average_ref_proj(info, projs=projs) + ): + warn( + 'No average EEG reference present in info["projs"], ' + "covariance may be adversely affected. Consider recomputing " + "covariance with an average eeg reference projector " + "added." + ) return eig, eigvec, mask @verbose -def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', - proj=True, seeg=0.1, ecog=0.1, hbo=0.1, hbr=0.1, - fnirs_cw_amplitude=0.1, fnirs_fd_ac_amplitude=0.1, - fnirs_fd_phase=0.1, fnirs_od=0.1, csd=0.1, dbs=0.1, - rank=None, scalings=None, verbose=None): +def regularize( + cov, + info, + mag=0.1, + grad=0.1, + eeg=0.1, + exclude="bads", + proj=True, + seeg=0.1, + ecog=0.1, + hbo=0.1, + hbr=0.1, + fnirs_cw_amplitude=0.1, + fnirs_fd_ac_amplitude=0.1, + fnirs_fd_phase=0.1, + fnirs_od=0.1, + csd=0.1, + dbs=0.1, + rank=None, + scalings=None, + verbose=None, +): """Regularize noise covariance matrix.
This method works by adding a constant to the diagonal for each @@ -1629,37 +1956,54 @@ def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', mne.compute_covariance """ # noqa: E501 from scipy import linalg + cov = cov.copy() info._check_consistency() - scalings = _handle_default('scalings_cov_rank', scalings) - regs = dict(eeg=eeg, seeg=seeg, dbs=dbs, ecog=ecog, hbo=hbo, hbr=hbr, - fnirs_cw_amplitude=fnirs_cw_amplitude, - fnirs_fd_ac_amplitude=fnirs_fd_ac_amplitude, - fnirs_fd_phase=fnirs_fd_phase, fnirs_od=fnirs_od, csd=csd) + scalings = _handle_default("scalings_cov_rank", scalings) + regs = dict( + eeg=eeg, + seeg=seeg, + dbs=dbs, + ecog=ecog, + hbo=hbo, + hbr=hbr, + fnirs_cw_amplitude=fnirs_cw_amplitude, + fnirs_fd_ac_amplitude=fnirs_fd_ac_amplitude, + fnirs_fd_phase=fnirs_fd_phase, + fnirs_od=fnirs_od, + csd=csd, + ) if exclude is None: raise ValueError('exclude must be a list of strings or "bads"') - if exclude == 'bads': - exclude = info['bads'] + cov['bads'] + if exclude == "bads": + exclude = info["bads"] + cov["bads"] picks_dict = {ch_type: [] for ch_type in _DATA_CH_TYPES_SPLIT} - meg_combined = 'auto' if rank != 'full' else False - picks_dict.update(dict(_picks_by_type( - info, meg_combined=meg_combined, exclude=exclude, ref_meg=False))) - if len(picks_dict.get('meg', [])) > 0 and rank != 'full': # combined + meg_combined = "auto" if rank != "full" else False + picks_dict.update( + dict( + _picks_by_type( + info, meg_combined=meg_combined, exclude=exclude, ref_meg=False + ) + ) + ) + if len(picks_dict.get("meg", [])) > 0 and rank != "full": # combined if mag != grad: - raise ValueError('On data where magnetometers and gradiometers ' - 'are dependent (e.g., SSSed data), mag (%s) must ' - 'equal grad (%s)' % (mag, grad)) - logger.info('Regularizing MEG channels jointly') - regs['meg'] = mag + raise ValueError( + "On data where magnetometers and gradiometers " + "are dependent (e.g., SSSed data), mag (%s) must " + "equal grad (%s)" % (mag, grad) + ) + logger.info("Regularizing MEG channels jointly") + regs["meg"] = mag else: regs.update(mag=mag, grad=grad) - if rank != 'full': + if rank != "full": rank = compute_rank(cov, rank, scalings, info) - info_ch_names = info['ch_names'] + info_ch_names = info["ch_names"] ch_names_by_type = dict() for ch_type, picks_type in picks_dict.items(): ch_names_by_type[ch_type] = [info_ch_names[i] for i in picks_type] @@ -1667,7 +2011,8 @@ def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', # This actually removes bad channels from the cov, which is not backward # compatible, so let's leave all channels in cov_good = pick_channels_cov( - cov, include=info_ch_names, exclude=exclude, ordered=False) + cov, include=info_ch_names, exclude=exclude, ordered=False + ) ch_names = cov_good.ch_names # Now get the indices for each channel type in the cov @@ -1678,14 +2023,14 @@ def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', idx_cov[ch_type].append(i) break else: - raise Exception('channel %s is unknown type' % ch) + raise Exception("channel %s is unknown type" % ch) - C = cov_good['data'] + C = cov_good["data"] assert len(C) == sum(map(len, idx_cov.values())) if proj: - projs = info['projs'] + cov_good['projs'] + projs = info["projs"] + cov_good["projs"] projs = activate_proj(projs) for ch_type in idx_cov: @@ -1702,16 +2047,18 @@ def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', this_C = C[np.ix_(idx, idx)] U = np.eye(this_C.shape[0]) this_ch_names = [ch_names[k] for k in idx] - if 
rank == 'full': + if rank == "full": if proj: P, ncomp, _ = make_projector(projs, this_ch_names) if ncomp > 0: # This adjustment ends up being redundant if rank is None: U = linalg.svd(P)[0][:, :-ncomp] - logger.info(' Created an SSP operator for %s ' - '(dimension = %d)' % (desc, ncomp)) + logger.info( + " Created an SSP operator for %s " + "(dimension = %d)" % (desc, ncomp) + ) else: - this_picks = pick_channels(info['ch_names'], this_ch_names) + this_picks = pick_channels(info["ch_names"], this_ch_names) this_info = pick_info(info, this_picks) # Here we could use proj_subspace=True, but this should not matter # since this is already in a loop over channel types @@ -1720,20 +2067,18 @@ def regularize(cov, info, mag=0.1, grad=0.1, eeg=0.1, exclude='bads', this_C = np.dot(U.T, np.dot(this_C, U)) sigma = np.mean(np.diag(this_C)) - this_C.flat[::len(this_C) + 1] += reg * sigma # modify diag inplace + this_C.flat[:: len(this_C) + 1] += reg * sigma # modify diag inplace this_C = np.dot(U, np.dot(this_C, U.T)) C[np.ix_(idx, idx)] = this_C # Put data back in correct locations - idx = pick_channels( - cov.ch_names, info_ch_names, exclude=exclude, ordered=False) - cov['data'][np.ix_(idx, idx)] = C + idx = pick_channels(cov.ch_names, info_ch_names, exclude=exclude, ordered=False) + cov["data"][np.ix_(idx, idx)] = C return cov -def _regularized_covariance(data, reg=None, method_params=None, info=None, - rank=None): +def _regularized_covariance(data, reg=None, method_params=None, info=None, rank=None): """Compute a regularized covariance from data using sklearn. This is a convenience wrapper for mne.decoding functions, which @@ -1744,36 +2089,55 @@ def _regularized_covariance(data, reg=None, method_params=None, info=None, cov : ndarray, shape (n_channels, n_channels) The covariance matrix. 
""" - _validate_type(reg, (str, 'numeric', None)) + _validate_type(reg, (str, "numeric", None)) if reg is None: - reg = 'empirical' + reg = "empirical" elif not isinstance(reg, str): reg = float(reg) if method_params is not None: - raise ValueError('If reg is a float, method_params must be None ' - '(got %s)' % (type(method_params),)) - method_params = dict(shrinkage=dict( - shrinkage=reg, assume_centered=True, store_precision=False)) - reg = 'shrinkage' + raise ValueError( + "If reg is a float, method_params must be None " + "(got %s)" % (type(method_params),) + ) + method_params = dict( + shrinkage=dict(shrinkage=reg, assume_centered=True, store_precision=False) + ) + reg = "shrinkage" method, method_params = _check_method_params( - reg, method_params, name='reg', allow_auto=False, rank=rank) + reg, method_params, name="reg", allow_auto=False, rank=rank + ) # use mag instead of eeg here to avoid the cov EEG projection warning - info = create_info(data.shape[-2], 1000., 'mag') if info is None else info + info = create_info(data.shape[-2], 1000.0, "mag") if info is None else info picks_list = _picks_by_type(info) - scalings = _handle_default('scalings_cov_rank', None) + scalings = _handle_default("scalings_cov_rank", None) cov = _compute_covariance_auto( - data.T, method=method, method_params=method_params, - info=info, cv=None, n_jobs=None, stop_early=True, - picks_list=picks_list, scalings=scalings, - rank=rank)[reg]['data'] + data.T, + method=method, + method_params=method_params, + info=info, + cv=None, + n_jobs=None, + stop_early=True, + picks_list=picks_list, + scalings=scalings, + rank=rank, + )[reg]["data"] return cov @verbose -def compute_whitener(noise_cov, info=None, picks=None, rank=None, - scalings=None, return_rank=False, pca=False, - return_colorer=False, on_rank_mismatch='warn', - verbose=None): +def compute_whitener( + noise_cov, + info=None, + picks=None, + rank=None, + scalings=None, + return_rank=False, + pca=False, + return_colorer=False, + on_rank_mismatch="warn", + verbose=None, +): """Compute whitening matrix. Parameters @@ -1824,53 +2188,56 @@ def compute_whitener(noise_cov, info=None, picks=None, rank=None, colorer : ndarray, shape (n_channels, n_channels) or (n_channels, n_nonzero) The coloring matrix. 
""" # noqa: E501 - _validate_type(pca, (str, bool), 'space') - _valid_pcas = (True, 'white', False) + _validate_type(pca, (str, bool), "space") + _valid_pcas = (True, "white", False) if pca not in _valid_pcas: - raise ValueError('space must be one of %s, got %s' - % (_valid_pcas, pca)) + raise ValueError("space must be one of %s, got %s" % (_valid_pcas, pca)) if info is None: - if 'eig' not in noise_cov: - raise ValueError('info can only be None if the noise cov has ' - 'already been prepared with prepare_noise_cov') - ch_names = deepcopy(noise_cov['names']) + if "eig" not in noise_cov: + raise ValueError( + "info can only be None if the noise cov has " + "already been prepared with prepare_noise_cov" + ) + ch_names = deepcopy(noise_cov["names"]) else: picks = _picks_to_idx(info, picks, with_ref_meg=False) - ch_names = [info['ch_names'][k] for k in picks] + ch_names = [info["ch_names"][k] for k in picks] del picks noise_cov = prepare_noise_cov( - noise_cov, info, ch_names, rank, scalings, - on_rank_mismatch=on_rank_mismatch) + noise_cov, info, ch_names, rank, scalings, on_rank_mismatch=on_rank_mismatch + ) n_chan = len(ch_names) - assert n_chan == len(noise_cov['eig']) + assert n_chan == len(noise_cov["eig"]) # Omit the zeroes due to projection - eig = noise_cov['eig'].copy() - nzero = (eig > 0) - eig[~nzero] = 0. # get rid of numerical noise (negative) ones + eig = noise_cov["eig"].copy() + nzero = eig > 0 + eig[~nzero] = 0.0 # get rid of numerical noise (negative) ones - if noise_cov['eigvec'].dtype.kind == 'c': + if noise_cov["eigvec"].dtype.kind == "c": dtype = np.complex128 else: dtype = np.float64 W = np.zeros((n_chan, 1), dtype) W[nzero, 0] = 1.0 / np.sqrt(eig[nzero]) # Rows of eigvec are the eigenvectors - W = W * noise_cov['eigvec'] # C ** -0.5 - C = np.sqrt(eig) * noise_cov['eigvec'].conj().T # C ** 0.5 + W = W * noise_cov["eigvec"] # C ** -0.5 + C = np.sqrt(eig) * noise_cov["eigvec"].conj().T # C ** 0.5 n_nzero = nzero.sum() - logger.info(' Created the whitener using a noise covariance matrix ' - 'with rank %d (%d small eigenvalues omitted)' - % (n_nzero, noise_cov['dim'] - n_nzero)) + logger.info( + " Created the whitener using a noise covariance matrix " + "with rank %d (%d small eigenvalues omitted)" + % (n_nzero, noise_cov["dim"] - n_nzero) + ) # Do the requested projection if pca is True: W = W[nzero] C = C[:, nzero] elif pca is False: - W = np.dot(noise_cov['eigvec'].conj().T, W) - C = np.dot(C, noise_cov['eigvec']) + W = np.dot(noise_cov["eigvec"].conj().T, W) + C = np.dot(C, noise_cov["eigvec"]) # Triage return out = W, ch_names @@ -1882,8 +2249,9 @@ def compute_whitener(noise_cov, info=None, picks=None, rank=None, @verbose -def whiten_evoked(evoked, noise_cov, picks=None, diag=None, rank=None, - scalings=None, verbose=None): +def whiten_evoked( + evoked, noise_cov, picks=None, diag=None, rank=None, scalings=None, verbose=None +): """Whiten evoked data using given noise covariance. 
Parameters @@ -1919,8 +2287,9 @@ def whiten_evoked(evoked, noise_cov, picks=None, diag=None, rank=None, if diag: noise_cov = noise_cov.as_diag() - W, _ = compute_whitener(noise_cov, evoked.info, picks=picks, - rank=rank, scalings=scalings) + W, _ = compute_whitener( + noise_cov, evoked.info, picks=picks, rank=rank, scalings=scalings + ) evoked.data[picks] = np.sqrt(evoked.nave) * np.dot(W, evoked.data[picks]) return evoked @@ -1931,9 +2300,10 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): """Read a noise covariance matrix.""" # Find all covariance matrices from scipy import sparse + covs = dir_tree_find(node, FIFF.FIFFB_MNE_COV) if len(covs) == 0: - raise ValueError('No covariance matrices found') + raise ValueError("No covariance matrices found") # Is any of the covariance matrices a noise covariance for p in range(len(covs)): @@ -1945,7 +2315,7 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): # Find all the necessary data tag = find_tag(fid, this, FIFF.FIFF_MNE_COV_DIM) if tag is None: - raise ValueError('Covariance matrix dimension not found') + raise ValueError("Covariance matrix dimension not found") dim = int(tag.data.item()) tag = find_tag(fid, this, FIFF.FIFF_MNE_COV_NFREE) @@ -1970,22 +2340,25 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): if tag is None: names = [] else: - names = _safe_name_list(tag.data, 'read', 'names') + names = _safe_name_list(tag.data, "read", "names") if len(names) != dim: - raise ValueError('Number of names does not match ' - 'covariance matrix dimension') + raise ValueError( + "Number of names does not match " "covariance matrix dimension" + ) tag = find_tag(fid, this, FIFF.FIFF_MNE_COV) if tag is None: tag = find_tag(fid, this, FIFF.FIFF_MNE_COV_DIAG) if tag is None: - raise ValueError('No covariance matrix data found') + raise ValueError("No covariance matrix data found") else: # Diagonal is stored data = tag.data diag = True - logger.info(' %d x %d diagonal covariance (kind = ' - '%d) found.' % (dim, dim, cov_kind)) + logger.info( + " %d x %d diagonal covariance (kind = " + "%d) found." % (dim, dim, cov_kind) + ) else: if not sparse.issparse(tag.data): @@ -1994,15 +2367,19 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): data = np.zeros((dim, dim)) data[np.tril(np.ones((dim, dim))) > 0] = vals data = data + data.T - data.flat[::dim + 1] /= 2.0 + data.flat[:: dim + 1] /= 2.0 diag = False - logger.info(' %d x %d full covariance (kind = %d) ' - 'found.' % (dim, dim, cov_kind)) + logger.info( + " %d x %d full covariance (kind = %d) " + "found." % (dim, dim, cov_kind) + ) else: diag = False data = tag.data - logger.info(' %d x %d sparse covariance (kind = %d)' - ' found.' % (dim, dim, cov_kind)) + logger.info( + " %d x %d sparse covariance (kind = %d)" + " found." 
% (dim, dim, cov_kind) + ) # Read the possibly precomputed decomposition tag1 = find_tag(fid, this, FIFF.FIFF_MNE_COV_EIGENVALUES) @@ -2023,20 +2400,28 @@ def _read_cov(fid, node, cov_kind, limited=False, verbose=None): # Put it together assert dim == len(data) assert data.ndim == (1 if diag else 2) - cov = dict(kind=cov_kind, diag=diag, dim=dim, names=names, - data=data, projs=projs, bads=bads, nfree=nfree, eig=eig, - eigvec=eigvec) + cov = dict( + kind=cov_kind, + diag=diag, + dim=dim, + names=names, + data=data, + projs=projs, + bads=bads, + nfree=nfree, + eig=eig, + eigvec=eigvec, + ) if score is not None: - cov['loglik'] = score + cov["loglik"] = score if method is not None: - cov['method'] = method + cov["method"] = method if limited: - del cov['kind'], cov['dim'], cov['diag'] + del cov["kind"], cov["dim"], cov["diag"] return cov - logger.info(' Did not find the desired covariance matrix (kind = %d)' - % cov_kind) + logger.info(" Did not find the desired covariance matrix (kind = %d)" % cov_kind) return None @@ -2046,55 +2431,55 @@ def _write_cov(fid, cov): start_block(fid, FIFF.FIFFB_MNE_COV) # Dimensions etc. - write_int(fid, FIFF.FIFF_MNE_COV_KIND, cov['kind']) - write_int(fid, FIFF.FIFF_MNE_COV_DIM, cov['dim']) - if cov['nfree'] > 0: - write_int(fid, FIFF.FIFF_MNE_COV_NFREE, cov['nfree']) + write_int(fid, FIFF.FIFF_MNE_COV_KIND, cov["kind"]) + write_int(fid, FIFF.FIFF_MNE_COV_DIM, cov["dim"]) + if cov["nfree"] > 0: + write_int(fid, FIFF.FIFF_MNE_COV_NFREE, cov["nfree"]) # Channel names - if cov['names'] is not None and len(cov['names']) > 0: + if cov["names"] is not None and len(cov["names"]) > 0: write_name_list_sanitized( - fid, FIFF.FIFF_MNE_ROW_NAMES, cov['names'], 'cov["names"]') + fid, FIFF.FIFF_MNE_ROW_NAMES, cov["names"], 'cov["names"]' + ) # Data - if cov['diag']: - write_double(fid, FIFF.FIFF_MNE_COV_DIAG, cov['data']) + if cov["diag"]: + write_double(fid, FIFF.FIFF_MNE_COV_DIAG, cov["data"]) else: # Store only lower part of covariance matrix - dim = cov['dim'] + dim = cov["dim"] mask = np.tril(np.ones((dim, dim), dtype=bool)) > 0 - vals = cov['data'][mask].ravel() + vals = cov["data"][mask].ravel() write_double(fid, FIFF.FIFF_MNE_COV, vals) # Eigenvalues and vectors if present - if cov['eig'] is not None and cov['eigvec'] is not None: - write_float_matrix(fid, FIFF.FIFF_MNE_COV_EIGENVECTORS, cov['eigvec']) - write_double(fid, FIFF.FIFF_MNE_COV_EIGENVALUES, cov['eig']) + if cov["eig"] is not None and cov["eigvec"] is not None: + write_float_matrix(fid, FIFF.FIFF_MNE_COV_EIGENVECTORS, cov["eigvec"]) + write_double(fid, FIFF.FIFF_MNE_COV_EIGENVALUES, cov["eig"]) # Projection operator - if cov['projs'] is not None and len(cov['projs']) > 0: - _write_proj(fid, cov['projs']) + if cov["projs"] is not None and len(cov["projs"]) > 0: + _write_proj(fid, cov["projs"]) # Bad channels - _write_bad_channels(fid, cov['bads'], None) + _write_bad_channels(fid, cov["bads"], None) # estimator method - if 'method' in cov: - write_string(fid, FIFF.FIFF_MNE_COV_METHOD, cov['method']) + if "method" in cov: + write_string(fid, FIFF.FIFF_MNE_COV_METHOD, cov["method"]) # negative log-likelihood score - if 'loglik' in cov: - write_double( - fid, FIFF.FIFF_MNE_COV_SCORE, np.array(cov['loglik'])) + if "loglik" in cov: + write_double(fid, FIFF.FIFF_MNE_COV_SCORE, np.array(cov["loglik"])) # Done! 
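# For reference, a covariance written by this helper round-trips through the
# public API like so (a minimal sketch; "sample-cov.fif" is a hypothetical
# file name following the conventional -cov.fif suffix):
#
#     import mne
#     mne.write_cov("sample-cov.fif", cov, overwrite=True)
#     cov_2 = mne.read_cov("sample-cov.fif")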
end_block(fid, FIFF.FIFFB_MNE_COV) @verbose -def _ensure_cov(cov, name='cov', *, verbose=None): - _validate_type(cov, ('path-like', Covariance), name) - logger.info('Noise covariance : %s' % (cov,)) +def _ensure_cov(cov, name="cov", *, verbose=None): + _validate_type(cov, ("path-like", Covariance), name) + logger.info("Noise covariance : %s" % (cov,)) if not isinstance(cov, Covariance): cov = read_cov(cov, verbose=_verbose_safe_false()) return cov diff --git a/mne/cuda.py b/mne/cuda.py index 15a2be2bab7..2b2dab64836 100644 --- a/mne/cuda.py +++ b/mne/cuda.py @@ -4,14 +4,22 @@ import numpy as np -from .utils import (sizeof_fmt, logger, get_config, warn, _explain_exception, - verbose, fill_doc, _check_option) +from .utils import ( + sizeof_fmt, + logger, + get_config, + warn, + _explain_exception, + verbose, + fill_doc, + _check_option, +) _cuda_capable = False -def get_cuda_memory(kind='available'): +def get_cuda_memory(kind="available"): """Get the amount of free memory for CUDA operations. Parameters @@ -25,10 +33,11 @@ def get_cuda_memory(kind='available'): The amount of available or total memory as a human-readable string. """ if not _cuda_capable: - warn('CUDA not enabled, returning zero for memory') + warn("CUDA not enabled, returning zero for memory") mem = 0 else: import cupy + mem = cupy.cuda.runtime.memGetInfo()[dict(available=0, total=1)[kind]] return sizeof_fmt(mem) @@ -55,29 +64,30 @@ def init_cuda(ignore_config=False, verbose=None): global _cuda_capable if _cuda_capable: return - if not ignore_config and (get_config('MNE_USE_CUDA', 'false').lower() != - 'true'): - logger.info('CUDA not enabled in config, skipping initialization') + if not ignore_config and (get_config("MNE_USE_CUDA", "false").lower() != "true"): + logger.info("CUDA not enabled in config, skipping initialization") return # Triage possible errors for informative messaging _cuda_capable = False try: import cupy # noqa except ImportError: - warn('module cupy not found, CUDA not enabled') + warn("module cupy not found, CUDA not enabled") return - device_id = int(get_config('MNE_CUDA_DEVICE', '0')) + device_id = int(get_config("MNE_CUDA_DEVICE", "0")) try: # Initialize CUDA _set_cuda_device(device_id, verbose) except Exception: - warn('so CUDA device could be initialized, likely a hardware error, ' - 'CUDA not enabled%s' % _explain_exception()) + warn( + "no CUDA device could be initialized, likely a hardware error, " + "CUDA not enabled%s" % _explain_exception() + ) return _cuda_capable = True # Figure out limit for CUDA FFT calculations - logger.info('Enabling CUDA with %s available memory' % get_cuda_memory()) + logger.info("Enabling CUDA with %s available memory" % get_cuda_memory()) @verbose @@ -92,28 +102,31 @@ def set_cuda_device(device_id, verbose=None): """ if _cuda_capable: _set_cuda_device(device_id, verbose) - elif get_config('MNE_USE_CUDA', 'false').lower() == 'true': + elif get_config("MNE_USE_CUDA", "false").lower() == "true": init_cuda() _set_cuda_device(device_id, verbose) else: - warn('Could not set CUDA device because CUDA is not enabled; either ' - 'run mne.cuda.init_cuda() first, or set the MNE_USE_CUDA config ' - 'variable to "true".') + warn( + "Could not set CUDA device because CUDA is not enabled; either " + "run mne.cuda.init_cuda() first, or set the MNE_USE_CUDA config " + 'variable to "true".'
+ ) @verbose def _set_cuda_device(device_id, verbose=None): """Set the CUDA device.""" import cupy + cupy.cuda.Device(device_id).use() - logger.info('Now using CUDA device {}'.format(device_id)) + logger.info("Now using CUDA device {}".format(device_id)) ############################################################################### # Repeated FFT multiplication -def _setup_cuda_fft_multiply_repeated(n_jobs, h, n_fft, - kind='FFT FIR filtering'): + +def _setup_cuda_fft_multiply_repeated(n_jobs, h, n_fft, kind="FFT FIR filtering"): """Set up repeated CUDA FFT multiplication with a given filter. Parameters @@ -154,28 +167,31 @@ def _setup_cuda_fft_multiply_repeated(n_jobs, h, n_fft, This function is designed to be used with fft_multiply_repeated(). """ from scipy.fft import rfft, irfft - cuda_dict = dict(n_fft=n_fft, rfft=rfft, irfft=irfft, - h_fft=rfft(h, n=n_fft)) + + cuda_dict = dict(n_fft=n_fft, rfft=rfft, irfft=irfft, h_fft=rfft(h, n=n_fft)) if isinstance(n_jobs, str): - _check_option('n_jobs', n_jobs, ('cuda',)) + _check_option("n_jobs", n_jobs, ("cuda",)) n_jobs = 1 init_cuda() if _cuda_capable: import cupy + try: # do the IFFT normalization now so we don't have to later - h_fft = cupy.array(cuda_dict['h_fft']) - logger.info('Using CUDA for %s' % kind) + h_fft = cupy.array(cuda_dict["h_fft"]) + logger.info("Using CUDA for %s" % kind) except Exception as exp: - logger.info('CUDA not used, could not instantiate memory ' - '(arrays may be too large: "%s"), falling back to ' - 'n_jobs=None' % str(exp)) - cuda_dict.update(h_fft=h_fft, - rfft=_cuda_upload_rfft, - irfft=_cuda_irfft_get) + logger.info( + "CUDA not used, could not instantiate memory " + '(arrays may be too large: "%s"), falling back to ' + "n_jobs=None" % str(exp) + ) + cuda_dict.update(h_fft=h_fft, rfft=_cuda_upload_rfft, irfft=_cuda_irfft_get) else: - logger.info('CUDA not used, CUDA could not be initialized, ' - 'falling back to n_jobs=None') + logger.info( + "CUDA not used, CUDA could not be initialized, " + "falling back to n_jobs=None" + ) return n_jobs, cuda_dict @@ -199,15 +215,16 @@ def _fft_multiply_repeated(x, cuda_dict): Filtered version of x. """ # do the fourier-domain operations - x_fft = cuda_dict['rfft'](x, cuda_dict['n_fft']) - x_fft *= cuda_dict['h_fft'] - x = cuda_dict['irfft'](x_fft, cuda_dict['n_fft']) + x_fft = cuda_dict["rfft"](x, cuda_dict["n_fft"]) + x_fft *= cuda_dict["h_fft"] + x = cuda_dict["irfft"](x_fft, cuda_dict["n_fft"]) return x ############################################################################### # FFT Resampling + def _setup_cuda_fft_resample(n_jobs, W, new_len): """Set up CUDA FFT resampling. Parameters @@ -248,52 +265,59 @@ def _setup_cuda_fft_resample(n_jobs, W, new_len): This function is designed to be used with fft_resample(). """ from scipy.fft import rfft, irfft + cuda_dict = dict(use_cuda=False, rfft=rfft, irfft=irfft) rfft_len_x = len(W) // 2 + 1 # fold the window onto itself (should be symmetric) and truncate W = W.copy() - W[1:rfft_len_x] = (W[1:rfft_len_x] + W[::-1][:rfft_len_x - 1]) / 2.
+ W[1:rfft_len_x] = (W[1:rfft_len_x] + W[::-1][: rfft_len_x - 1]) / 2.0 W = W[:rfft_len_x] if isinstance(n_jobs, str): - _check_option('n_jobs', n_jobs, ('cuda',)) + _check_option("n_jobs", n_jobs, ("cuda",)) n_jobs = 1 init_cuda() if _cuda_capable: try: import cupy + # do the IFFT normalization now so we don't have to later W = cupy.array(W) - logger.info('Using CUDA for FFT resampling') + logger.info("Using CUDA for FFT resampling") except Exception: - logger.info('CUDA not used, could not instantiate memory ' - '(arrays may be too large), falling back to ' - 'n_jobs=None') + logger.info( + "CUDA not used, could not instantiate memory " + "(arrays may be too large), falling back to " + "n_jobs=None" + ) else: - cuda_dict.update(use_cuda=True, - rfft=_cuda_upload_rfft, - irfft=_cuda_irfft_get) + cuda_dict.update( + use_cuda=True, rfft=_cuda_upload_rfft, irfft=_cuda_irfft_get + ) else: - logger.info('CUDA not used, CUDA could not be initialized, ' - 'falling back to n_jobs=None') - cuda_dict['W'] = W + logger.info( + "CUDA not used, CUDA could not be initialized, " + "falling back to n_jobs=None" + ) + cuda_dict["W"] = W return n_jobs, cuda_dict def _cuda_upload_rfft(x, n, axis=-1): """Upload and compute rfft.""" import cupy + return cupy.fft.rfft(cupy.array(x), n=n, axis=axis) def _cuda_irfft_get(x, n, axis=-1): """Compute irfft and get.""" import cupy + return cupy.fft.irfft(x, n=n, axis=axis).get() @fill_doc -def _fft_resample(x, new_len, npads, to_removes, cuda_dict=None, - pad='reflect_limited'): +def _fft_resample(x, new_len, npads, to_removes, cuda_dict=None, pad="reflect_limited"): """Do FFT resampling with a filter function (possibly using CUDA). Parameters @@ -327,16 +351,16 @@ def _fft_resample(x, new_len, npads, to_removes, cuda_dict=None, old_len = len(x) shorter = new_len < old_len use_len = new_len if shorter else old_len - x_fft = cuda_dict['rfft'](x, None) + x_fft = cuda_dict["rfft"](x, None) if use_len % 2 == 0: nyq = use_len // 2 - x_fft[nyq:nyq + 1] *= 2 if shorter else 0.5 - x_fft *= cuda_dict['W'] - y = cuda_dict['irfft'](x_fft, new_len) + x_fft[nyq : nyq + 1] *= 2 if shorter else 0.5 + x_fft *= cuda_dict["W"] + y = cuda_dict["irfft"](x_fft, new_len) # now let's trim it back to the correct size (if there was padding) if (to_removes > 0).any(): - y = y[to_removes[0]:y.shape[0] - to_removes[1]] + y = y[to_removes[0] : y.shape[0] - to_removes[1]] return y @@ -344,20 +368,28 @@ def _fft_resample(x, new_len, npads, to_removes, cuda_dict=None, ############################################################################### # Misc + # this has to go in mne.cuda instead of mne.filter to avoid import errors -def _smart_pad(x, n_pad, pad='reflect_limited'): +def _smart_pad(x, n_pad, pad="reflect_limited"): """Pad vector x.""" n_pad = np.asarray(n_pad) assert n_pad.shape == (2,) if (n_pad == 0).all(): return x elif (n_pad < 0).any(): - raise RuntimeError('n_pad must be non-negative') - if pad == 'reflect_limited': + raise RuntimeError("n_pad must be non-negative") + if pad == "reflect_limited": # need to pad with zeros if len(x) <= npad l_z_pad = np.zeros(max(n_pad[0] - len(x) + 1, 0), dtype=x.dtype) r_z_pad = np.zeros(max(n_pad[1] - len(x) + 1, 0), dtype=x.dtype) - return np.concatenate([l_z_pad, 2 * x[0] - x[n_pad[0]:0:-1], x, - 2 * x[-1] - x[-2:-n_pad[1] - 2:-1], r_z_pad]) + return np.concatenate( + [ + l_z_pad, + 2 * x[0] - x[n_pad[0] : 0 : -1], + x, + 2 * x[-1] - x[-2 : -n_pad[1] - 2 : -1], + r_z_pad, + ] + ) else: return np.pad(x, (tuple(n_pad),), pad) diff --git 
a/mne/datasets/__init__.py b/mne/datasets/__init__.py index ec24f450fd0..1549fa21f8f 100644 --- a/mne/datasets/__init__.py +++ b/mne/datasets/__init__.py @@ -29,19 +29,47 @@ from . import eyelink from . import ucl_opm_auditory from ._fetch import fetch_dataset -from .utils import (_download_all_example_data, fetch_hcp_mmp_parcellation, - fetch_aparc_sub_parcellation, has_dataset) +from .utils import ( + _download_all_example_data, + fetch_hcp_mmp_parcellation, + fetch_aparc_sub_parcellation, + has_dataset, +) from ._fsaverage.base import fetch_fsaverage from ._infant.base import fetch_infant_template from ._phantom.base import fetch_phantom __all__ = [ - '_download_all_example_data', '_fake', 'brainstorm', 'eegbci', - 'fetch_aparc_sub_parcellation', 'fetch_fsaverage', 'fetch_infant_template', - 'fetch_hcp_mmp_parcellation', 'fieldtrip_cmc', 'hf_sef', 'kiloword', - 'misc', 'mtrf', 'multimodal', 'opm', 'phantom_4dbti', 'sample', - 'sleep_physionet', 'somato', 'spm_face', 'ssvep', 'testing', - 'visual_92_categories', 'limo', 'erp_core', 'epilepsy_ecog', - 'fetch_dataset', 'fetch_phantom', 'has_dataset', 'refmeg_noise', - 'fnirs_motor', 'eyelink' + "_download_all_example_data", + "_fake", + "brainstorm", + "eegbci", + "fetch_aparc_sub_parcellation", + "fetch_fsaverage", + "fetch_infant_template", + "fetch_hcp_mmp_parcellation", + "fieldtrip_cmc", + "hf_sef", + "kiloword", + "misc", + "mtrf", + "multimodal", + "opm", + "phantom_4dbti", + "sample", + "sleep_physionet", + "somato", + "spm_face", + "ssvep", + "testing", + "visual_92_categories", + "limo", + "erp_core", + "epilepsy_ecog", + "fetch_dataset", + "fetch_phantom", + "has_dataset", + "refmeg_noise", + "fnirs_motor", + "eyelink", ] diff --git a/mne/datasets/_fake/_fake.py b/mne/datasets/_fake/_fake.py index 61ef7678862..475b7aeb640 100644 --- a/mne/datasets/_fake/_fake.py +++ b/mne/datasets/_fake/_fake.py @@ -4,25 +4,28 @@ # License: BSD Style. 
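# The fetchers in mne.datasets all expose the same data_path() pattern; from
# the user's side it looks like this (a minimal sketch using the sample
# dataset; the first call downloads, later calls reuse the local cache):
#
#     import mne
#     data_path = mne.datasets.sample.data_path()
#     raw_fname = data_path / "MEG" / "sample" / "sample_audvis_raw.fif"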
from ...utils import verbose -from ..utils import (_data_path_doc, _download_mne_dataset, - _get_version, _version_doc) +from ..utils import _data_path_doc, _download_mne_dataset, _get_version, _version_doc @verbose -def data_path(path=None, force_update=False, update_path=False, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=False, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='fake', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="fake", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='fake', - conf='MNE_DATASETS_FAKE_PATH') +data_path.__doc__ = _data_path_doc.format(name="fake", conf="MNE_DATASETS_FAKE_PATH") def get_version(): # noqa: D103 - return _get_version('fake') + return _get_version("fake") -get_version.__doc__ = _version_doc.format(name='fake') +get_version.__doc__ = _version_doc.format(name="fake") diff --git a/mne/datasets/_fetch.py b/mne/datasets/_fetch.py index 578c1cf82ed..37802f817b8 100644 --- a/mne/datasets/_fetch.py +++ b/mne/datasets/_fetch.py @@ -17,8 +17,13 @@ TESTING_VERSIONED, MISC_VERSIONED, ) -from .utils import (_dataset_version, _do_path_update, _get_path, - _log_time_size, _downloader_params) +from .utils import ( + _dataset_version, + _do_path_update, + _get_path, + _log_time_size, + _downloader_params, +) from ..fixes import _compare_version @@ -131,6 +136,7 @@ def fetch_dataset( pass a list of dicts. """ # noqa E501 import pooch + t0 = time.time() if auth is not None: @@ -153,7 +159,7 @@ def fetch_dataset( names = [params["dataset_name"] for params in dataset_params] name = names[0] dataset_dict = dataset_params[0] - config_key = dataset_dict.get('config_key', None) + config_key = dataset_dict.get("config_key", None) folder_name = dataset_dict["folder_name"] # get download path for specific dataset @@ -175,8 +181,9 @@ def fetch_dataset( # get the version of the dataset and then check if the version is outdated data_version = _dataset_version(final_path, name) - outdated = (want_version is not None and - _compare_version(want_version, '>', data_version)) + outdated = want_version is not None and _compare_version( + want_version, ">", data_version + ) if outdated: logger.info( @@ -188,16 +195,13 @@ def fetch_dataset( # return empty string if outdated dataset and we don't want to download if (not force_update) and outdated and not download: logger.info( - 'Dataset out of date but force_update=False and download=False, ' - 'returning empty data_path') + "Dataset out of date but force_update=False and download=False, " + "returning empty data_path" + ) return (empty, data_version) if return_version else empty # reasons to bail early (hf_sef has separate code for this): - if ( - (not force_update) - and (not outdated) - and (not name.startswith("hf_sef_")) - ): + if (not force_update) and (not outdated) and (not name.startswith("hf_sef_")): # ...if target folder exists (otherwise pooch downloads every # time because we don't save the archive files after unpacking, so # pooch can't check its checksum) @@ -215,8 +219,7 @@ def fetch_dataset( else: # If they don't have stdin, just accept the license # https://github.com/mne-tools/mne-python/issues/8513#issuecomment-726823724 # noqa: E501 - answer = _safe_input( - "%sAgree (y/[n])? 
" % _bst_license_text, use="y") + answer = _safe_input("%sAgree (y/[n])? " % _bst_license_text, use="y") if answer.lower() != "y": raise RuntimeError( "You must agree to the license to use this " "dataset" @@ -262,10 +265,11 @@ def fetch_dataset( ) except ValueError as err: err = str(err) - if 'hash of downloaded file' in str(err): + if "hash of downloaded file" in str(err): raise ValueError( - f'{err} Consider using force_update=True to force ' - 'the dataset to be downloaded again.') from None + f"{err} Consider using force_update=True to force " + "the dataset to be downloaded again." + ) from None else: raise fname = use_path / archive_name @@ -291,7 +295,7 @@ def fetch_dataset( data_version = _dataset_version(path, name) # 0.7 < 0.7.git should be False, therefore strip if check_version and ( - _compare_version(data_version, '<', mne_version.strip(".git")) + _compare_version(data_version, "<", mne_version.strip(".git")) ): warn( "The {name} dataset (version {current}) is older than " diff --git a/mne/datasets/_fsaverage/base.py b/mne/datasets/_fsaverage/base.py index d4a8f3d82c0..daa01dc64c2 100644 --- a/mne/datasets/_fsaverage/base.py +++ b/mne/datasets/_fsaverage/base.py @@ -65,19 +65,19 @@ def fetch_fsaverage(subjects_dir=None, *, verbose=None): # subjects_dir = _set_montage_coreg_path(subjects_dir) subjects_dir = op.abspath(op.expanduser(subjects_dir)) - fs_dir = op.join(subjects_dir, 'fsaverage') + fs_dir = op.join(subjects_dir, "fsaverage") os.makedirs(fs_dir, exist_ok=True) _manifest_check_download( - manifest_path=op.join(FSAVERAGE_MANIFEST_PATH, 'root.txt'), + manifest_path=op.join(FSAVERAGE_MANIFEST_PATH, "root.txt"), destination=op.join(subjects_dir), - url='https://osf.io/3bxqt/download?version=2', - hash_='5133fe92b7b8f03ae19219d5f46e4177', + url="https://osf.io/3bxqt/download?version=2", + hash_="5133fe92b7b8f03ae19219d5f46e4177", ) _manifest_check_download( - manifest_path=op.join(FSAVERAGE_MANIFEST_PATH, 'bem.txt'), - destination=op.join(subjects_dir, 'fsaverage'), - url='https://osf.io/7ve8g/download?version=4', - hash_='b31509cdcf7908af6a83dc5ee8f49fb1', + manifest_path=op.join(FSAVERAGE_MANIFEST_PATH, "bem.txt"), + destination=op.join(subjects_dir, "fsaverage"), + url="https://osf.io/7ve8g/download?version=4", + hash_="b31509cdcf7908af6a83dc5ee8f49fb1", ) return fs_dir @@ -85,8 +85,8 @@ def fetch_fsaverage(subjects_dir=None, *, verbose=None): def _get_create_subjects_dir(subjects_dir): subjects_dir = get_subjects_dir(subjects_dir, raise_error=False) if subjects_dir is None: - subjects_dir = _get_path(None, 'MNE_DATA', 'montage coregistration') - subjects_dir = op.join(subjects_dir, 'MNE-fsaverage-data') + subjects_dir = _get_path(None, "MNE_DATA", "montage coregistration") + subjects_dir = op.join(subjects_dir, "MNE-fsaverage-data") os.makedirs(subjects_dir, exist_ok=True) else: subjects_dir = str(subjects_dir) @@ -128,5 +128,5 @@ def _set_montage_coreg_path(subjects_dir=None): subjects_dir = _get_create_subjects_dir(subjects_dir) old_subjects_dir = get_subjects_dir(None, raise_error=False) if old_subjects_dir is None: - set_config('SUBJECTS_DIR', subjects_dir) + set_config("SUBJECTS_DIR", subjects_dir) return subjects_dir diff --git a/mne/datasets/_infant/base.py b/mne/datasets/_infant/base.py index c327c4835e0..196faa7bfc2 100644 --- a/mne/datasets/_infant/base.py +++ b/mne/datasets/_infant/base.py @@ -7,9 +7,9 @@ from ..utils import _manifest_check_download from ...utils import verbose, get_subjects_dir, _check_option, _validate_type -_AGES = '2wk 1mo 2mo 3mo 
4.5mo 6mo 7.5mo 9mo 10.5mo 12mo 15mo 18mo 2yr' +_AGES = "2wk 1mo 2mo 3mo 4.5mo 6mo 7.5mo 9mo 10.5mo 12mo 15mo 18mo 2yr" # https://github.com/christian-oreilly/infant_template_paper/releases -_ORIGINAL_URL = 'https://github.com/christian-oreilly/infant_template_paper/releases/download/v0.1-alpha/{subject}.zip' # noqa: E501 +_ORIGINAL_URL = "https://github.com/christian-oreilly/infant_template_paper/releases/download/v0.1-alpha/{subject}.zip" # noqa: E501 # Formatted the same way as md5sum *.zip on Ubuntu: _ORIGINAL_HASHES = """ 851737d5f8f246883f2aef9819c6ec29 ANTS10-5Months3T.zip @@ -71,23 +71,24 @@ def fetch_infant_template(age, subjects_dir=None, *, verbose=None): # ... names = sorted(name for name in zip.namelist() if not zipfile.Path(zip, name).is_dir()) # noqa: E501 # ... with open(f'{name}.txt', 'w') as fid: # ... fid.write('\n'.join(names)) - _validate_type(age, str, 'age') - _check_option('age', age, _AGES.split()) + _validate_type(age, str, "age") + _check_option("age", age, _AGES.split()) subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) - unit = dict(wk='Weeks', mo='Months', yr='Years')[age[-2:]] - first = age[:-2].split('.')[0] - dash = '-5' if '.5' in age else '-0' - subject = f'ANTS{first}{dash}{unit}3T' + unit = dict(wk="Weeks", mo="Months", yr="Years")[age[-2:]] + first = age[:-2].split(".")[0] + dash = "-5" if ".5" in age else "-0" + subject = f"ANTS{first}{dash}{unit}3T" # Actually get and create the files subj_dir = subjects_dir / subject os.makedirs(subj_dir, exist_ok=True) # .zip -> hash mapping - orig_hashes = dict(line.strip().split()[::-1] - for line in _ORIGINAL_HASHES.strip().splitlines()) + orig_hashes = dict( + line.strip().split()[::-1] for line in _ORIGINAL_HASHES.strip().splitlines() + ) _manifest_check_download( - manifest_path=op.join(_MANIFEST_PATH, f'{subject}.txt'), + manifest_path=op.join(_MANIFEST_PATH, f"{subject}.txt"), destination=subj_dir, url=_ORIGINAL_URL.format(subject=subject), - hash_=orig_hashes[f'{subject}.zip'], + hash_=orig_hashes[f"{subject}.zip"], ) return subject diff --git a/mne/datasets/_phantom/base.py b/mne/datasets/_phantom/base.py index 8785e3018ec..3d8af0e68ac 100644 --- a/mne/datasets/_phantom/base.py +++ b/mne/datasets/_phantom/base.py @@ -43,19 +43,21 @@ def fetch_phantom(kind, subjects_dir=None, *, verbose=None): .. versionadded:: 0.24 """ phantoms = dict( - otaniemi=dict(url='https://osf.io/j5czy/download?version=1', - hash='42d17db5b1db3e30327ffb4cf2649de8'), + otaniemi=dict( + url="https://osf.io/j5czy/download?version=1", + hash="42d17db5b1db3e30327ffb4cf2649de8", + ), ) - _validate_type(kind, str, 'kind') - _check_option('kind', kind, list(phantoms)) + _validate_type(kind, str, "kind") + _check_option("kind", kind, list(phantoms)) subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) - subject = f'phantom_{kind}' + subject = f"phantom_{kind}" subject_dir = subjects_dir / subject os.makedirs(subject_dir, exist_ok=True) _manifest_check_download( - manifest_path=op.join(PHANTOM_MANIFEST_PATH, f'{subject}.txt'), + manifest_path=op.join(PHANTOM_MANIFEST_PATH, f"{subject}.txt"), destination=subjects_dir, - url=phantoms[kind]['url'], - hash_=phantoms[kind]['hash'], + url=phantoms[kind]["url"], + hash_=phantoms[kind]["hash"], ) return subject_dir diff --git a/mne/datasets/brainstorm/__init__.py b/mne/datasets/brainstorm/__init__.py index 8dcf9b79811..e97790f52c6 100644 --- a/mne/datasets/brainstorm/__init__.py +++ b/mne/datasets/brainstorm/__init__.py @@ -1,4 +1,3 @@ """Brainstorm datasets.""" -from . 
import (bst_raw, bst_resting, bst_auditory, bst_phantom_ctf, - bst_phantom_elekta) +from . import bst_raw, bst_resting, bst_auditory, bst_phantom_ctf, bst_phantom_elekta diff --git a/mne/datasets/brainstorm/bst_auditory.py b/mne/datasets/brainstorm/bst_auditory.py index 41c2f078671..a45dc72b5cf 100644 --- a/mne/datasets/brainstorm/bst_auditory.py +++ b/mne/datasets/brainstorm/bst_auditory.py @@ -2,8 +2,12 @@ # # License: BSD-3-Clause from ...utils import verbose -from ..utils import (_get_version, _version_doc, - _data_path_doc_accept, _download_mne_dataset) +from ..utils import ( + _get_version, + _version_doc, + _data_path_doc_accept, + _download_mne_dataset, +) _description = """ URL: http://neuroimage.usc.edu/brainstorm/DatasetAuditory @@ -22,26 +26,40 @@ @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, accept=False, *, verbose=None): # noqa: D103 +def data_path( + path=None, + force_update=False, + update_path=True, + download=True, + accept=False, + *, + verbose=None +): # noqa: D103 return _download_mne_dataset( - name='bst_auditory', processor='nested_untar', path=path, - force_update=force_update, update_path=update_path, - download=download, accept=accept) + name="bst_auditory", + processor="nested_untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + accept=accept, + ) _data_path_doc = _data_path_doc_accept.format( - name='brainstorm', conf='MNE_DATASETS_BRAINSTORM_DATA_PATH') -_data_path_doc = _data_path_doc.replace('brainstorm dataset', - 'brainstorm (bst_auditory) dataset') + name="brainstorm", conf="MNE_DATASETS_BRAINSTORM_DATA_PATH" +) +_data_path_doc = _data_path_doc.replace( + "brainstorm dataset", "brainstorm (bst_auditory) dataset" +) data_path.__doc__ = _data_path_doc def get_version(): # noqa: D103 - return _get_version('bst_auditory') + return _get_version("bst_auditory") -get_version.__doc__ = _version_doc.format(name='brainstorm') +get_version.__doc__ = _version_doc.format(name="brainstorm") def description(): diff --git a/mne/datasets/brainstorm/bst_phantom_ctf.py b/mne/datasets/brainstorm/bst_phantom_ctf.py index 87300a82971..147626d33b6 100644 --- a/mne/datasets/brainstorm/bst_phantom_ctf.py +++ b/mne/datasets/brainstorm/bst_phantom_ctf.py @@ -2,8 +2,12 @@ # # License: BSD-3-Clause from ...utils import verbose -from ..utils import (_get_version, _version_doc, - _data_path_doc_accept, _download_mne_dataset) +from ..utils import ( + _get_version, + _version_doc, + _data_path_doc_accept, + _download_mne_dataset, +) _description = """ URL: http://neuroimage.usc.edu/brainstorm/Tutorials/PhantomCtf @@ -11,26 +15,40 @@ @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, accept=False, *, verbose=None): # noqa: D103 +def data_path( + path=None, + force_update=False, + update_path=True, + download=True, + accept=False, + *, + verbose=None +): # noqa: D103 return _download_mne_dataset( - name='bst_phantom_ctf', processor='nested_untar', path=path, - force_update=force_update, update_path=update_path, - download=download, accept=accept) + name="bst_phantom_ctf", + processor="nested_untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + accept=accept, + ) _data_path_doc = _data_path_doc_accept.format( - name='brainstorm', conf='MNE_DATASETS_BRAINSTORM_DATA_PATH') -_data_path_doc = _data_path_doc.replace('brainstorm dataset', - 'brainstorm (bst_phantom_ctf) dataset') + name="brainstorm", 
conf="MNE_DATASETS_BRAINSTORM_DATA_PATH" +) +_data_path_doc = _data_path_doc.replace( + "brainstorm dataset", "brainstorm (bst_phantom_ctf) dataset" +) data_path.__doc__ = _data_path_doc def get_version(): # noqa: D103 - return _get_version('bst_phantom_ctf') + return _get_version("bst_phantom_ctf") -get_version.__doc__ = _version_doc.format(name='brainstorm') +get_version.__doc__ = _version_doc.format(name="brainstorm") def description(): diff --git a/mne/datasets/brainstorm/bst_phantom_elekta.py b/mne/datasets/brainstorm/bst_phantom_elekta.py index abfa5a68aca..8e5b5a8a69c 100644 --- a/mne/datasets/brainstorm/bst_phantom_elekta.py +++ b/mne/datasets/brainstorm/bst_phantom_elekta.py @@ -2,8 +2,12 @@ # # License: BSD-3-Clause from ...utils import verbose -from ..utils import (_get_version, _version_doc, - _data_path_doc_accept, _download_mne_dataset) +from ..utils import ( + _get_version, + _version_doc, + _data_path_doc_accept, + _download_mne_dataset, +) _description = """ URL: http://neuroimage.usc.edu/brainstorm/Tutorials/PhantomElekta @@ -11,27 +15,40 @@ @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, accept=False, *, verbose=None): # noqa: D103 +def data_path( + path=None, + force_update=False, + update_path=True, + download=True, + accept=False, + *, + verbose=None +): # noqa: D103 return _download_mne_dataset( - name='bst_phantom_elekta', processor='nested_untar', path=path, - force_update=force_update, update_path=update_path, - download=download, accept=accept) + name="bst_phantom_elekta", + processor="nested_untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + accept=accept, + ) _data_path_doc = _data_path_doc_accept.format( - name='brainstorm', conf='MNE_DATASETS_BRAINSTORM_DATA_PATH') -_data_path_doc = _data_path_doc.replace('brainstorm dataset', - 'brainstorm (bst_phantom_elekta) ' - 'dataset') + name="brainstorm", conf="MNE_DATASETS_BRAINSTORM_DATA_PATH" +) +_data_path_doc = _data_path_doc.replace( + "brainstorm dataset", "brainstorm (bst_phantom_elekta) " "dataset" +) data_path.__doc__ = _data_path_doc def get_version(): # noqa: D103 - return _get_version('bst_phantom_elekta') + return _get_version("bst_phantom_elekta") -get_version.__doc__ = _version_doc.format(name='brainstorm') +get_version.__doc__ = _version_doc.format(name="brainstorm") def description(): diff --git a/mne/datasets/brainstorm/bst_raw.py b/mne/datasets/brainstorm/bst_raw.py index 0616ca176d5..f8d92e0b26c 100644 --- a/mne/datasets/brainstorm/bst_raw.py +++ b/mne/datasets/brainstorm/bst_raw.py @@ -4,11 +4,16 @@ from functools import partial from ...utils import verbose, get_config -from ..utils import (has_dataset, _get_version, _version_doc, - _data_path_doc_accept, _download_mne_dataset) +from ..utils import ( + has_dataset, + _get_version, + _version_doc, + _data_path_doc_accept, + _download_mne_dataset, +) -has_brainstorm_data = partial(has_dataset, name='bst_raw') +has_brainstorm_data = partial(has_dataset, name="bst_raw") _description = """ URL: http://neuroimage.usc.edu/brainstorm/DatasetMedianNerveCtf @@ -26,26 +31,40 @@ @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, accept=False, *, verbose=None): # noqa: D103 +def data_path( + path=None, + force_update=False, + update_path=True, + download=True, + accept=False, + *, + verbose=None +): # noqa: D103 return _download_mne_dataset( - name='bst_raw', processor='nested_untar', path=path, - force_update=force_update, 
update_path=update_path, - download=download, accept=accept) + name="bst_raw", + processor="nested_untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + accept=accept, + ) _data_path_doc = _data_path_doc_accept.format( - name='brainstorm', conf='MNE_DATASETS_BRAINSTORM_DATA_PATH') -_data_path_doc = _data_path_doc.replace('brainstorm dataset', - 'brainstorm (bst_raw) dataset') + name="brainstorm", conf="MNE_DATASETS_BRAINSTORM_DATA_PATH" +) +_data_path_doc = _data_path_doc.replace( + "brainstorm dataset", "brainstorm (bst_raw) dataset" +) data_path.__doc__ = _data_path_doc def get_version(): # noqa: D103 - return _get_version('bst_raw') + return _get_version("bst_raw") -get_version.__doc__ = _version_doc.format(name='brainstorm') +get_version.__doc__ = _version_doc.format(name="brainstorm") def description(): # noqa: D103 @@ -55,8 +74,7 @@ def description(): # noqa: D103 def _skip_bstraw_data(): - skip_testing = (get_config('MNE_SKIP_TESTING_DATASET_TESTS', 'false') == - 'true') + skip_testing = get_config("MNE_SKIP_TESTING_DATASET_TESTS", "false") == "true" skip = skip_testing or not has_brainstorm_data() return skip @@ -64,5 +82,7 @@ def _skip_bstraw_data(): def requires_bstraw_data(func): """Skip testing data test.""" import pytest - return pytest.mark.skipif(_skip_bstraw_data(), - reason='Requires brainstorm dataset')(func) + + return pytest.mark.skipif( + _skip_bstraw_data(), reason="Requires brainstorm dataset" + )(func) diff --git a/mne/datasets/brainstorm/bst_resting.py b/mne/datasets/brainstorm/bst_resting.py index e0eb226e863..9e2f8f7e73b 100644 --- a/mne/datasets/brainstorm/bst_resting.py +++ b/mne/datasets/brainstorm/bst_resting.py @@ -2,8 +2,12 @@ # # License: BSD-3-Clause from ...utils import verbose -from ..utils import (_get_version, _version_doc, - _data_path_doc_accept, _download_mne_dataset) +from ..utils import ( + _get_version, + _version_doc, + _data_path_doc_accept, + _download_mne_dataset, +) _description = """ URL: http://neuroimage.usc.edu/brainstorm/DatasetResting @@ -14,26 +18,40 @@ @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, accept=False, *, verbose=None): # noqa: D103 +def data_path( + path=None, + force_update=False, + update_path=True, + download=True, + accept=False, + *, + verbose=None +): # noqa: D103 return _download_mne_dataset( - name='bst_resting', processor='nested_untar', path=path, - force_update=force_update, update_path=update_path, - download=download, accept=accept) + name="bst_resting", + processor="nested_untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + accept=accept, + ) _data_path_doc = _data_path_doc_accept.format( - name='brainstorm', conf='MNE_DATASETS_BRAINSTORM_DATA_PATH') -_data_path_doc = _data_path_doc.replace('brainstorm dataset', - 'brainstorm (bst_resting) dataset') + name="brainstorm", conf="MNE_DATASETS_BRAINSTORM_DATA_PATH" +) +_data_path_doc = _data_path_doc.replace( + "brainstorm dataset", "brainstorm (bst_resting) dataset" +) data_path.__doc__ = _data_path_doc def get_version(): # noqa: D103 - return _get_version('bst_resting') + return _get_version("bst_resting") -get_version.__doc__ = _version_doc.format(name='brainstorm') +get_version.__doc__ = _version_doc.format(name="brainstorm") def description(): diff --git a/mne/datasets/config.py b/mne/datasets/config.py index ec45dbbf91b..7869f97a78e 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # 
respective repos, and make a new release of the dataset on GitHub. Then # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓ ↓↓↓ -RELEASES = dict(testing='0.146', misc='0.26') +RELEASES = dict(testing="0.146", misc="0.26") TESTING_VERSIONED = f'mne-testing-data-{RELEASES["testing"]}' MISC_VERSIONED = f'mne-misc-data-{RELEASES["misc"]}' @@ -109,240 +109,245 @@ # of the downloaded dataset (ex: "MNE_DATASETS_EEGBCI_PATH"). # Testing and misc are at the top as they're updated most often -MNE_DATASETS['testing'] = dict( - archive_name=f'{TESTING_VERSIONED}.tar.gz', - hash='md5:a2e86fe404f4321408b22f38711d11b7', - url=('https://codeload.github.com/mne-tools/mne-testing-data/' - f'tar.gz/{RELEASES["testing"]}'), +MNE_DATASETS["testing"] = dict( + archive_name=f"{TESTING_VERSIONED}.tar.gz", + hash="md5:a2e86fe404f4321408b22f38711d11b7", + url=( + "https://codeload.github.com/mne-tools/mne-testing-data/" + f'tar.gz/{RELEASES["testing"]}' + ), # In case we ever have to resort to osf.io again... # archive_name='mne-testing-data.tar.gz', # hash='md5:c805a5fed8ca46f723e7eec828d90824', # url='https://osf.io/dqfgy/download?version=1', # 0.136 - folder_name='MNE-testing-data', - config_key='MNE_DATASETS_TESTING_PATH', + folder_name="MNE-testing-data", + config_key="MNE_DATASETS_TESTING_PATH", ) -MNE_DATASETS['misc'] = dict( - archive_name=f'{MISC_VERSIONED}.tar.gz', # 'mne-misc-data', - hash='md5:868b484fadd73b1d1a3535b7194a0d03', - url=('https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/' - f'{RELEASES["misc"]}'), - folder_name='MNE-misc-data', - config_key='MNE_DATASETS_MISC_PATH' +MNE_DATASETS["misc"] = dict( + archive_name=f"{MISC_VERSIONED}.tar.gz", # 'mne-misc-data', + hash="md5:868b484fadd73b1d1a3535b7194a0d03", + url=( + "https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/" + f'{RELEASES["misc"]}' + ), + folder_name="MNE-misc-data", + config_key="MNE_DATASETS_MISC_PATH", ) -MNE_DATASETS['fnirs_motor'] = dict( - archive_name='MNE-fNIRS-motor-data.tgz', - hash='md5:c4935d19ddab35422a69f3326a01fef8', - url='https://osf.io/dj3eh/download?version=1', - folder_name='MNE-fNIRS-motor-data', - config_key='MNE_DATASETS_FNIRS_MOTOR_PATH', +MNE_DATASETS["fnirs_motor"] = dict( + archive_name="MNE-fNIRS-motor-data.tgz", + hash="md5:c4935d19ddab35422a69f3326a01fef8", + url="https://osf.io/dj3eh/download?version=1", + folder_name="MNE-fNIRS-motor-data", + config_key="MNE_DATASETS_FNIRS_MOTOR_PATH", ) -MNE_DATASETS['ucl_opm_auditory'] = dict( - archive_name='auditory_OPM_stationary.zip', - hash='md5:9ed0d8d554894542b56f8e7c4c0041fe', - url='https://osf.io/download/mwrt3/?version=1', - folder_name='auditory_OPM_stationary', - config_key='MNE_DATASETS_UCL_OPM_AUDITORY_PATH', +MNE_DATASETS["ucl_opm_auditory"] = dict( + archive_name="auditory_OPM_stationary.zip", + hash="md5:9ed0d8d554894542b56f8e7c4c0041fe", + url="https://osf.io/download/mwrt3/?version=1", + folder_name="auditory_OPM_stationary", + config_key="MNE_DATASETS_UCL_OPM_AUDITORY_PATH", ) -MNE_DATASETS['kiloword'] = dict( - archive_name='MNE-kiloword-data.tar.gz', - hash='md5:3a124170795abbd2e48aae8727e719a8', - url='https://osf.io/qkvf9/download?version=1', - folder_name='MNE-kiloword-data', - config_key='MNE_DATASETS_KILOWORD_PATH', +MNE_DATASETS["kiloword"] = dict( + archive_name="MNE-kiloword-data.tar.gz", + hash="md5:3a124170795abbd2e48aae8727e719a8", + url="https://osf.io/qkvf9/download?version=1", + folder_name="MNE-kiloword-data", + config_key="MNE_DATASETS_KILOWORD_PATH", ) 
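For reference, every MNE_DATASETS entry in config.py follows one shape: the same five keys, with the versioned archive names derived from the RELEASES dict above. A minimal, self-contained sketch of that shape, reusing the values shown in the surrounding hunk; the _info() accessor is hypothetical and not part of the patch:

RELEASES = dict(testing="0.146", misc="0.26")
TESTING_VERSIONED = f'mne-testing-data-{RELEASES["testing"]}'

MNE_DATASETS = {
    "testing": dict(
        archive_name=f"{TESTING_VERSIONED}.tar.gz",  # version baked into the name
        hash="md5:a2e86fe404f4321408b22f38711d11b7",
        url=(
            "https://codeload.github.com/mne-tools/mne-testing-data/"
            f'tar.gz/{RELEASES["testing"]}'
        ),
        folder_name="MNE-testing-data",
        config_key="MNE_DATASETS_TESTING_PATH",
    ),
}

REQUIRED_KEYS = {"archive_name", "hash", "url", "folder_name", "config_key"}

def _info(name):
    # hypothetical accessor: fail fast if an entry is missing a field
    info = MNE_DATASETS[name]
    missing = REQUIRED_KEYS - set(info)
    if missing:
        raise KeyError(f"{name!r} entry is missing {sorted(missing)}")
    return info

The accessor only guards the invariant here; in the real code the downloader helpers read these fields directly.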
-MNE_DATASETS['multimodal'] = dict( - archive_name='MNE-multimodal-data.tar.gz', - hash='md5:26ec847ae9ab80f58f204d09e2c08367', - url='https://ndownloader.figshare.com/files/5999598', - folder_name='MNE-multimodal-data', - config_key='MNE_DATASETS_MULTIMODAL_PATH', +MNE_DATASETS["multimodal"] = dict( + archive_name="MNE-multimodal-data.tar.gz", + hash="md5:26ec847ae9ab80f58f204d09e2c08367", + url="https://ndownloader.figshare.com/files/5999598", + folder_name="MNE-multimodal-data", + config_key="MNE_DATASETS_MULTIMODAL_PATH", ) -MNE_DATASETS['opm'] = dict( - archive_name='MNE-OPM-data.tar.gz', - hash='md5:370ad1dcfd5c47e029e692c85358a374', - url='https://osf.io/p6ae7/download?version=2', - folder_name='MNE-OPM-data', - config_key='MNE_DATASETS_OPM_PATH', +MNE_DATASETS["opm"] = dict( + archive_name="MNE-OPM-data.tar.gz", + hash="md5:370ad1dcfd5c47e029e692c85358a374", + url="https://osf.io/p6ae7/download?version=2", + folder_name="MNE-OPM-data", + config_key="MNE_DATASETS_OPM_PATH", ) -MNE_DATASETS['phantom_4dbti'] = dict( - archive_name='MNE-phantom-4DBTi.zip', - hash='md5:938a601440f3ffa780d20a17bae039ff', - url='https://osf.io/v2brw/download?version=2', - folder_name='MNE-phantom-4DBTi', - config_key='MNE_DATASETS_PHANTOM_4DBTI_PATH', +MNE_DATASETS["phantom_4dbti"] = dict( + archive_name="MNE-phantom-4DBTi.zip", + hash="md5:938a601440f3ffa780d20a17bae039ff", + url="https://osf.io/v2brw/download?version=2", + folder_name="MNE-phantom-4DBTi", + config_key="MNE_DATASETS_PHANTOM_4DBTI_PATH", ) -MNE_DATASETS['sample'] = dict( - archive_name='MNE-sample-data-processed.tar.gz', - hash='md5:e8f30c4516abdc12a0c08e6bae57409c', - url='https://osf.io/86qa2/download?version=6', - folder_name='MNE-sample-data', - config_key='MNE_DATASETS_SAMPLE_PATH', +MNE_DATASETS["sample"] = dict( + archive_name="MNE-sample-data-processed.tar.gz", + hash="md5:e8f30c4516abdc12a0c08e6bae57409c", + url="https://osf.io/86qa2/download?version=6", + folder_name="MNE-sample-data", + config_key="MNE_DATASETS_SAMPLE_PATH", ) -MNE_DATASETS['somato'] = dict( - archive_name='MNE-somato-data.tar.gz', - hash='md5:32fd2f6c8c7eb0784a1de6435273c48b', - url='https://osf.io/tp4sg/download?version=7', - folder_name='MNE-somato-data', - config_key='MNE_DATASETS_SOMATO_PATH' +MNE_DATASETS["somato"] = dict( + archive_name="MNE-somato-data.tar.gz", + hash="md5:32fd2f6c8c7eb0784a1de6435273c48b", + url="https://osf.io/tp4sg/download?version=7", + folder_name="MNE-somato-data", + config_key="MNE_DATASETS_SOMATO_PATH", ) -MNE_DATASETS['spm'] = dict( - archive_name='MNE-spm-face.tar.gz', - hash='md5:9f43f67150e3b694b523a21eb929ea75', - url='https://osf.io/je4s8/download?version=2', - folder_name='MNE-spm-face', - config_key='MNE_DATASETS_SPM_FACE_PATH', +MNE_DATASETS["spm"] = dict( + archive_name="MNE-spm-face.tar.gz", + hash="md5:9f43f67150e3b694b523a21eb929ea75", + url="https://osf.io/je4s8/download?version=2", + folder_name="MNE-spm-face", + config_key="MNE_DATASETS_SPM_FACE_PATH", ) # Visual 92 categories has the dataset split into 2 files. # We define a dictionary holding the items with the same # value across both files: folder name and configuration key. 
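The comment closing the hunk above flags the one irregular case: visual_92_categories ships as two archives, so three registry entries (the parent plus the _1 and _2 parts shown next) share a single folder_name and config_key. A small sketch of that invariant, assuming the MNE_DATASETS dict from the sketch above has been extended with those entries:

parent = MNE_DATASETS["visual_92_categories"]
for part in ("visual_92_categories_1", "visual_92_categories_2"):
    info = MNE_DATASETS[part]
    # both halves must unpack into the same directory and share one config key
    assert info["folder_name"] == parent["folder_name"]
    assert info["config_key"] == parent["config_key"]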
-MNE_DATASETS['visual_92_categories'] = dict( - folder_name='MNE-visual_92_categories-data', - config_key='MNE_DATASETS_VISUAL_92_CATEGORIES_PATH', +MNE_DATASETS["visual_92_categories"] = dict( + folder_name="MNE-visual_92_categories-data", + config_key="MNE_DATASETS_VISUAL_92_CATEGORIES_PATH", ) -MNE_DATASETS['visual_92_categories_1'] = dict( - archive_name='MNE-visual_92_categories-data-part1.tar.gz', - hash='md5:74f50bbeb65740903eadc229c9fa759f', - url='https://osf.io/8ejrs/download?version=1', - folder_name='MNE-visual_92_categories-data', - config_key='MNE_DATASETS_VISUAL_92_CATEGORIES_PATH', +MNE_DATASETS["visual_92_categories_1"] = dict( + archive_name="MNE-visual_92_categories-data-part1.tar.gz", + hash="md5:74f50bbeb65740903eadc229c9fa759f", + url="https://osf.io/8ejrs/download?version=1", + folder_name="MNE-visual_92_categories-data", + config_key="MNE_DATASETS_VISUAL_92_CATEGORIES_PATH", ) -MNE_DATASETS['visual_92_categories_2'] = dict( - archive_name='MNE-visual_92_categories-data-part2.tar.gz', - hash='md5:203410a98afc9df9ae8ba9f933370e20', - url='https://osf.io/t4yjp/download?version=1', - folder_name='MNE-visual_92_categories-data', - config_key='MNE_DATASETS_VISUAL_92_CATEGORIES_PATH', +MNE_DATASETS["visual_92_categories_2"] = dict( + archive_name="MNE-visual_92_categories-data-part2.tar.gz", + hash="md5:203410a98afc9df9ae8ba9f933370e20", + url="https://osf.io/t4yjp/download?version=1", + folder_name="MNE-visual_92_categories-data", + config_key="MNE_DATASETS_VISUAL_92_CATEGORIES_PATH", ) -MNE_DATASETS['mtrf'] = dict( - archive_name='mTRF_1.5.zip', - hash='md5:273a390ebbc48da2c3184b01a82e4636', - url='https://osf.io/h85s2/download?version=1', - folder_name='mTRF_1.5', - config_key='MNE_DATASETS_MTRF_PATH' +MNE_DATASETS["mtrf"] = dict( + archive_name="mTRF_1.5.zip", + hash="md5:273a390ebbc48da2c3184b01a82e4636", + url="https://osf.io/h85s2/download?version=1", + folder_name="mTRF_1.5", + config_key="MNE_DATASETS_MTRF_PATH", ) -MNE_DATASETS['refmeg_noise'] = dict( - archive_name='sample_reference_MEG_noise-raw.zip', - hash='md5:779fecd890d98b73a4832e717d7c7c45', - url='https://osf.io/drt6v/download?version=1', - folder_name='MNE-refmeg-noise-data', - config_key='MNE_DATASETS_REFMEG_NOISE_PATH' +MNE_DATASETS["refmeg_noise"] = dict( + archive_name="sample_reference_MEG_noise-raw.zip", + hash="md5:779fecd890d98b73a4832e717d7c7c45", + url="https://osf.io/drt6v/download?version=1", + folder_name="MNE-refmeg-noise-data", + config_key="MNE_DATASETS_REFMEG_NOISE_PATH", ) -MNE_DATASETS['ssvep'] = dict( - archive_name='ssvep_example_data.zip', - hash='md5:af866bbc0f921114ac9d683494fe87d6', - url='https://osf.io/z8h6k/download?version=5', - folder_name='ssvep-example-data', - config_key='MNE_DATASETS_SSVEP_PATH' +MNE_DATASETS["ssvep"] = dict( + archive_name="ssvep_example_data.zip", + hash="md5:af866bbc0f921114ac9d683494fe87d6", + url="https://osf.io/z8h6k/download?version=5", + folder_name="ssvep-example-data", + config_key="MNE_DATASETS_SSVEP_PATH", ) -MNE_DATASETS['erp_core'] = dict( - archive_name='MNE-ERP-CORE-data.tar.gz', - hash='md5:5866c0d6213bd7ac97f254c776f6c4b1', - url='https://osf.io/rzgba/download?version=1', - folder_name='MNE-ERP-CORE-data', - config_key='MNE_DATASETS_ERP_CORE_PATH', +MNE_DATASETS["erp_core"] = dict( + archive_name="MNE-ERP-CORE-data.tar.gz", + hash="md5:5866c0d6213bd7ac97f254c776f6c4b1", + url="https://osf.io/rzgba/download?version=1", + folder_name="MNE-ERP-CORE-data", + config_key="MNE_DATASETS_ERP_CORE_PATH", ) -MNE_DATASETS['epilepsy_ecog'] = dict( 
- archive_name='MNE-epilepsy-ecog-data.tar.gz', - hash='md5:ffb139174afa0f71ec98adbbb1729dea', - url='https://osf.io/z4epq/download?version=1', - folder_name='MNE-epilepsy-ecog-data', - config_key='MNE_DATASETS_EPILEPSY_ECOG_PATH', +MNE_DATASETS["epilepsy_ecog"] = dict( + archive_name="MNE-epilepsy-ecog-data.tar.gz", + hash="md5:ffb139174afa0f71ec98adbbb1729dea", + url="https://osf.io/z4epq/download?version=1", + folder_name="MNE-epilepsy-ecog-data", + config_key="MNE_DATASETS_EPILEPSY_ECOG_PATH", ) # Fieldtrip CMC dataset -MNE_DATASETS['fieldtrip_cmc'] = dict( - archive_name='SubjectCMC.zip', - hash='md5:6f9fd6520f9a66e20994423808d2528c', - url='https://osf.io/j9b6s/download?version=1', - folder_name='MNE-fieldtrip_cmc-data', - config_key='MNE_DATASETS_FIELDTRIP_CMC_PATH' +MNE_DATASETS["fieldtrip_cmc"] = dict( + archive_name="SubjectCMC.zip", + hash="md5:6f9fd6520f9a66e20994423808d2528c", + url="https://osf.io/j9b6s/download?version=1", + folder_name="MNE-fieldtrip_cmc-data", + config_key="MNE_DATASETS_FIELDTRIP_CMC_PATH", ) # brainstorm datasets: -MNE_DATASETS['bst_auditory'] = dict( - archive_name='bst_auditory.tar.gz', - hash='md5:fa371a889a5688258896bfa29dd1700b', - url='https://osf.io/5t9n8/download?version=1', - folder_name='MNE-brainstorm-data', - config_key='MNE_DATASETS_BRAINSTORM_PATH', +MNE_DATASETS["bst_auditory"] = dict( + archive_name="bst_auditory.tar.gz", + hash="md5:fa371a889a5688258896bfa29dd1700b", + url="https://osf.io/5t9n8/download?version=1", + folder_name="MNE-brainstorm-data", + config_key="MNE_DATASETS_BRAINSTORM_PATH", ) -MNE_DATASETS['bst_phantom_ctf'] = dict( - archive_name='bst_phantom_ctf.tar.gz', - hash='md5:80819cb7f5b92d1a5289db3fb6acb33c', - url='https://osf.io/sxr8y/download?version=1', - folder_name='MNE-brainstorm-data', - config_key='MNE_DATASETS_BRAINSTORM_PATH', +MNE_DATASETS["bst_phantom_ctf"] = dict( + archive_name="bst_phantom_ctf.tar.gz", + hash="md5:80819cb7f5b92d1a5289db3fb6acb33c", + url="https://osf.io/sxr8y/download?version=1", + folder_name="MNE-brainstorm-data", + config_key="MNE_DATASETS_BRAINSTORM_PATH", ) -MNE_DATASETS['bst_phantom_elekta'] = dict( - archive_name='bst_phantom_elekta.tar.gz', - hash='md5:1badccbe17998d18cc373526e86a7aaf', - url='https://osf.io/dpcku/download?version=1', - folder_name='MNE-brainstorm-data', - config_key='MNE_DATASETS_BRAINSTORM_PATH', +MNE_DATASETS["bst_phantom_elekta"] = dict( + archive_name="bst_phantom_elekta.tar.gz", + hash="md5:1badccbe17998d18cc373526e86a7aaf", + url="https://osf.io/dpcku/download?version=1", + folder_name="MNE-brainstorm-data", + config_key="MNE_DATASETS_BRAINSTORM_PATH", ) -MNE_DATASETS['bst_raw'] = dict( - archive_name='bst_raw.tar.gz', - hash='md5:fa2efaaec3f3d462b319bc24898f440c', - url='https://osf.io/9675n/download?version=2', - folder_name='MNE-brainstorm-data', - config_key='MNE_DATASETS_BRAINSTORM_PATH', +MNE_DATASETS["bst_raw"] = dict( + archive_name="bst_raw.tar.gz", + hash="md5:fa2efaaec3f3d462b319bc24898f440c", + url="https://osf.io/9675n/download?version=2", + folder_name="MNE-brainstorm-data", + config_key="MNE_DATASETS_BRAINSTORM_PATH", ) -MNE_DATASETS['bst_resting'] = dict( - archive_name='bst_resting.tar.gz', - hash='md5:70fc7bf9c3b97c4f2eab6260ee4a0430', - url='https://osf.io/m7bd3/download?version=3', - folder_name='MNE-brainstorm-data', - config_key='MNE_DATASETS_BRAINSTORM_PATH', +MNE_DATASETS["bst_resting"] = dict( + archive_name="bst_resting.tar.gz", + hash="md5:70fc7bf9c3b97c4f2eab6260ee4a0430", + url="https://osf.io/m7bd3/download?version=3", + 
folder_name="MNE-brainstorm-data", + config_key="MNE_DATASETS_BRAINSTORM_PATH", ) # HF-SEF -MNE_DATASETS['hf_sef_raw'] = dict( - archive_name='hf_sef_raw.tar.gz', - hash='md5:33934351e558542bafa9b262ac071168', - url='https://zenodo.org/record/889296/files/hf_sef_raw.tar.gz', - folder_name='hf_sef', - config_key='MNE_DATASETS_HF_SEF_PATH', +MNE_DATASETS["hf_sef_raw"] = dict( + archive_name="hf_sef_raw.tar.gz", + hash="md5:33934351e558542bafa9b262ac071168", + url="https://zenodo.org/record/889296/files/hf_sef_raw.tar.gz", + folder_name="hf_sef", + config_key="MNE_DATASETS_HF_SEF_PATH", ) -MNE_DATASETS['hf_sef_evoked'] = dict( - archive_name='hf_sef_evoked.tar.gz', - hash='md5:13d34cb5db584e00868677d8fb0aab2b', +MNE_DATASETS["hf_sef_evoked"] = dict( + archive_name="hf_sef_evoked.tar.gz", + hash="md5:13d34cb5db584e00868677d8fb0aab2b", # Zenodo can be slow, so we use the OSF mirror # url=('https://zenodo.org/record/3523071/files/' # 'hf_sef_evoked.tar.gz'), - url='https://osf.io/25f8d/download?version=2', - folder_name='hf_sef', - config_key='MNE_DATASETS_HF_SEF_PATH', + url="https://osf.io/25f8d/download?version=2", + folder_name="hf_sef", + config_key="MNE_DATASETS_HF_SEF_PATH", ) # "fake" dataset (for testing) -MNE_DATASETS['fake'] = dict( - archive_name='foo.tgz', - hash='md5:3194e9f7b46039bb050a74f3e1ae9908', - url=('https://github.com/mne-tools/mne-testing-data/raw/master/' - 'datasets/foo.tgz'), - folder_name='foo', - config_key='MNE_DATASETS_FAKE_PATH' +MNE_DATASETS["fake"] = dict( + archive_name="foo.tgz", + hash="md5:3194e9f7b46039bb050a74f3e1ae9908", + url=( + "https://github.com/mne-tools/mne-testing-data/raw/master/" "datasets/foo.tgz" + ), + folder_name="foo", + config_key="MNE_DATASETS_FAKE_PATH", ) # eyelink dataset -MNE_DATASETS['eyelink'] = dict( - archive_name='eyelink_example_data.zip', - hash='md5:081950c05f35267458d9c751e178f161', - url=('https://osf.io/r5ndq/download?version=1'), - folder_name='eyelink-example-data', - config_key='MNE_DATASETS_EYELINK_PATH' +MNE_DATASETS["eyelink"] = dict( + archive_name="eyelink_example_data.zip", + hash="md5:081950c05f35267458d9c751e178f161", + url=("https://osf.io/r5ndq/download?version=1"), + folder_name="eyelink-example-data", + config_key="MNE_DATASETS_EYELINK_PATH", ) diff --git a/mne/datasets/eegbci/eegbci.py b/mne/datasets/eegbci/eegbci.py index fd2b0a71e24..4d5b3f9b7d6 100644 --- a/mne/datasets/eegbci/eegbci.py +++ b/mne/datasets/eegbci/eegbci.py @@ -10,8 +10,7 @@ import time from ...utils import _url_to_local_path, verbose, logger -from ..utils import (_do_path_update, _get_path, _log_time_size, - _downloader_params) +from ..utils import _do_path_update, _get_path, _log_time_size, _downloader_params # TODO: remove try/except when our min version is py 3.9 try: @@ -20,12 +19,11 @@ from importlib_resources import files -EEGMI_URL = 'https://physionet.org/files/eegmmidb/1.0.0/' +EEGMI_URL = "https://physionet.org/files/eegmmidb/1.0.0/" @verbose -def data_path(url, path=None, force_update=False, update_path=None, *, - verbose=None): +def data_path(url, path=None, force_update=False, update_path=None, *, verbose=None): """Get path to local copy of EEGMMI dataset URL. 
This is a low-level function useful for getting a local copy of a @@ -73,10 +71,10 @@ def data_path(url, path=None, force_update=False, update_path=None, *, """ # noqa: E501 import pooch - key = 'MNE_DATASETS_EEGBCI_PATH' - name = 'EEGBCI' + key = "MNE_DATASETS_EEGBCI_PATH" + name = "EEGBCI" path = _get_path(path, key, name) - fname = 'MNE-eegbci-data' + fname = "MNE-eegbci-data" destination = _url_to_local_path(url, op.join(path, fname)) destinations = [destination] @@ -101,8 +99,15 @@ def data_path(url, path=None, force_update=False, update_path=None, *, @verbose -def load_data(subject, runs, path=None, force_update=False, update_path=None, - base_url=EEGMI_URL, verbose=None): # noqa: D301 +def load_data( + subject, + runs, + path=None, + force_update=False, + update_path=None, + base_url=EEGMI_URL, + verbose=None, +): # noqa: D301 """Get paths to local copies of EEGBCI dataset files. This will fetch data for the EEGBCI dataset :footcite:`SchalkEtAl2004`, which is also @@ -165,43 +170,46 @@ def load_data(subject, runs, path=None, force_update=False, update_path=None, .. footbibliography:: """ # noqa: E501 import pooch + t0 = time.time() - if not hasattr(runs, '__iter__'): + if not hasattr(runs, "__iter__"): runs = [runs] # get local storage path - config_key = 'MNE_DATASETS_EEGBCI_PATH' - folder = 'MNE-eegbci-data' - name = 'EEGBCI' + config_key = "MNE_DATASETS_EEGBCI_PATH" + folder = "MNE-eegbci-data" + name = "EEGBCI" path = _get_path(path, config_key, name) # extract path parts - pattern = r'(?:https?://.*)(files)/(eegmmidb)/(\d+\.\d+\.\d+)/?' + pattern = r"(?:https?://.*)(files)/(eegmmidb)/(\d+\.\d+\.\d+)/?" match = re.compile(pattern).match(base_url) if match is None: - raise ValueError('base_url does not match the expected EEGMI folder ' - 'structure. Please notify MNE-Python developers.') + raise ValueError( + "base_url does not match the expected EEGMI folder " + "structure. Please notify MNE-Python developers." + ) base_path = op.join(path, folder, *match.groups()) # create the download manager fetcher = pooch.create( path=base_path, base_url=base_url, - version=None, # Data versioning is decoupled from MNE-Python version. + version=None, # Data versioning is decoupled from MNE-Python version. registry=None, # Registry is loaded from file, below. 
- retry_if_failed=2 # 2 retries = 3 total attempts + retry_if_failed=2, # 2 retries = 3 total attempts ) # load the checksum registry - registry = files('mne').joinpath('data', 'eegbci_checksums.txt') + registry = files("mne").joinpath("data", "eegbci_checksums.txt") fetcher.load_registry(registry) # fetch the file(s) data_paths = [] sz = 0 for run in runs: - file_part = f'S{subject:03d}/S{subject:03d}R{run:02d}.edf' + file_part = f"S{subject:03d}/S{subject:03d}R{run:02d}.edf" destination = Path(base_path, file_part) data_paths.append(destination) if destination.exists(): @@ -210,7 +218,7 @@ def load_data(subject, runs, path=None, force_update=False, update_path=None, else: continue if sz == 0: # log once - logger.info('Downloading EEGBCI data') + logger.info("Downloading EEGBCI data") fetcher.fetch(file_part) # update path in config if desired sz += destination.stat().st_size @@ -230,11 +238,11 @@ def standardize(raw): """ rename = dict() for name in raw.ch_names: - std_name = name.strip('.') + std_name = name.strip(".") std_name = std_name.upper() - if std_name.endswith('Z'): - std_name = std_name[:-1] + 'z' - if std_name.startswith('FP'): - std_name = 'Fp' + std_name[2:] + if std_name.endswith("Z"): + std_name = std_name[:-1] + "z" + if std_name.startswith("FP"): + std_name = "Fp" + std_name[2:] rename[name] = std_name raw.rename_channels(rename) diff --git a/mne/datasets/eegbci/tests/test_eegbci.py b/mne/datasets/eegbci/tests/test_eegbci.py index e60988ff36c..c59c6802ede 100644 --- a/mne/datasets/eegbci/tests/test_eegbci.py +++ b/mne/datasets/eegbci/tests/test_eegbci.py @@ -8,7 +8,6 @@ def test_eegbci_download(tmp_path, fake_retrieve): """Test Sleep Physionet URL handling.""" for subj in range(4): - fnames = eegbci.load_data( - subj + 1, runs=[3], path=tmp_path, update_path=False) + fnames = eegbci.load_data(subj + 1, runs=[3], path=tmp_path, update_path=False) assert len(fnames) == 1, subj assert fake_retrieve.call_count == 4 diff --git a/mne/datasets/epilepsy_ecog/_data.py b/mne/datasets/epilepsy_ecog/_data.py index 33535c1aff0..b6cc93b92bd 100644 --- a/mne/datasets/epilepsy_ecog/_data.py +++ b/mne/datasets/epilepsy_ecog/_data.py @@ -3,25 +3,30 @@ # License: BSD Style. 
from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='epilepsy_ecog', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="epilepsy_ecog", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) data_path.__doc__ = _data_path_doc.format( - name='epilepsy_ecog', conf='MNE_DATASETS_EPILEPSY_ECOG_PATH') + name="epilepsy_ecog", conf="MNE_DATASETS_EPILEPSY_ECOG_PATH" +) def get_version(): # noqa: D103 - return _get_version('epilepsy_ecog') + return _get_version("epilepsy_ecog") -get_version.__doc__ = _version_doc.format(name='epilepsy_ecog') +get_version.__doc__ = _version_doc.format(name="epilepsy_ecog") diff --git a/mne/datasets/erp_core/erp_core.py b/mne/datasets/erp_core/erp_core.py index 76bd62ca209..8f3aa1e2663 100644 --- a/mne/datasets/erp_core/erp_core.py +++ b/mne/datasets/erp_core/erp_core.py @@ -1,23 +1,28 @@ from ...utils import verbose -from ..utils import (_data_path_doc, - _get_version, _version_doc, _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='erp_core', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="erp_core", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='erp_core', - conf='MNE_DATASETS_ERP_CORE_PATH') +data_path.__doc__ = _data_path_doc.format( + name="erp_core", conf="MNE_DATASETS_ERP_CORE_PATH" +) def get_version(): # noqa: D103 - return _get_version('erp_core') + return _get_version("erp_core") -get_version.__doc__ = _version_doc.format(name='erp_core') +get_version.__doc__ = _version_doc.format(name="erp_core") diff --git a/mne/datasets/eyelink/eyelink.py b/mne/datasets/eyelink/eyelink.py index a08e338ab33..f0a349c3c16 100644 --- a/mne/datasets/eyelink/eyelink.py +++ b/mne/datasets/eyelink/eyelink.py @@ -2,25 +2,30 @@ # License: BSD Style. 
from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='eyelink', processor='unzip', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="eyelink", + processor="unzip", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='eyelink', - conf='MNE_DATASETS_EYELINK_PATH') +data_path.__doc__ = _data_path_doc.format( + name="eyelink", conf="MNE_DATASETS_EYELINK_PATH" +) def get_version(): # noqa: D103 - return _get_version('eyelink') + return _get_version("eyelink") -get_version.__doc__ = _version_doc.format(name='eyelink') +get_version.__doc__ = _version_doc.format(name="eyelink") diff --git a/mne/datasets/fieldtrip_cmc/fieldtrip_cmc.py b/mne/datasets/fieldtrip_cmc/fieldtrip_cmc.py index d7abe1c68f0..cdce53d57a8 100644 --- a/mne/datasets/fieldtrip_cmc/fieldtrip_cmc.py +++ b/mne/datasets/fieldtrip_cmc/fieldtrip_cmc.py @@ -3,25 +3,30 @@ # # License: BSD Style. from ...utils import verbose -from ..utils import (_data_path_doc, - _get_version, _version_doc, _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='fieldtrip_cmc', processor='nested_unzip', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="fieldtrip_cmc", + processor="nested_unzip", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) data_path.__doc__ = _data_path_doc.format( - name='fieldtrip_cmc', conf='MNE_DATASETS_FIELDTRIP_CMC_PATH') + name="fieldtrip_cmc", conf="MNE_DATASETS_FIELDTRIP_CMC_PATH" +) def get_version(): # noqa: D103 - return _get_version('fieldtrip_cmc') + return _get_version("fieldtrip_cmc") -get_version.__doc__ = _version_doc.format(name='fieldtrip_cmc') +get_version.__doc__ = _version_doc.format(name="fieldtrip_cmc") diff --git a/mne/datasets/fnirs_motor/fnirs_motor.py b/mne/datasets/fnirs_motor/fnirs_motor.py index ce0294f9f4e..2c49a32c891 100644 --- a/mne/datasets/fnirs_motor/fnirs_motor.py +++ b/mne/datasets/fnirs_motor/fnirs_motor.py @@ -2,25 +2,30 @@ # License: BSD Style. 
from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='fnirs_motor', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="fnirs_motor", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='fnirs_motor', - conf='MNE_DATASETS_FNIRS_MOTOR_PATH') +data_path.__doc__ = _data_path_doc.format( + name="fnirs_motor", conf="MNE_DATASETS_FNIRS_MOTOR_PATH" +) def get_version(): # noqa: D103 - return _get_version('fnirs_motor') + return _get_version("fnirs_motor") -get_version.__doc__ = _version_doc.format(name='fnirs_motor') +get_version.__doc__ = _version_doc.format(name="fnirs_motor") diff --git a/mne/datasets/hf_sef/hf_sef.py b/mne/datasets/hf_sef/hf_sef.py index 401c3636017..66c25ad12be 100644 --- a/mne/datasets/hf_sef/hf_sef.py +++ b/mne/datasets/hf_sef/hf_sef.py @@ -11,8 +11,9 @@ @verbose -def data_path(dataset='evoked', path=None, force_update=False, - update_path=True, *, verbose=None): +def data_path( + dataset="evoked", path=None, force_update=False, update_path=True, *, verbose=None +): """Get path to local copy of the high frequency SEF dataset. Gets a local copy of the high frequency SEF MEG dataset @@ -46,33 +47,38 @@ def data_path(dataset='evoked', path=None, force_update=False, ---------- .. 
footbibliography:: """ - _check_option('dataset', dataset, ('evoked', 'raw')) - if dataset == 'raw': - data_dict = MNE_DATASETS['hf_sef_raw'] - data_dict['dataset_name'] = 'hf_sef_raw' + _check_option("dataset", dataset, ("evoked", "raw")) + if dataset == "raw": + data_dict = MNE_DATASETS["hf_sef_raw"] + data_dict["dataset_name"] = "hf_sef_raw" else: - data_dict = MNE_DATASETS['hf_sef_evoked'] - data_dict['dataset_name'] = 'hf_sef_evoked' - config_key = data_dict['config_key'] - folder_name = data_dict['folder_name'] + data_dict = MNE_DATASETS["hf_sef_evoked"] + data_dict["dataset_name"] = "hf_sef_evoked" + config_key = data_dict["config_key"] + folder_name = data_dict["folder_name"] # get download path for specific dataset path = _get_path(path=path, key=config_key, name=folder_name) final_path = op.join(path, folder_name) - megdir = op.join(final_path, 'MEG', 'subject_a') - has_raw = (dataset == 'raw' and op.isdir(megdir) and - any('raw' in filename for filename in os.listdir(megdir))) - has_evoked = (dataset == 'evoked' and - op.isdir(op.join(final_path, 'subjects'))) + megdir = op.join(final_path, "MEG", "subject_a") + has_raw = ( + dataset == "raw" + and op.isdir(megdir) + and any("raw" in filename for filename in os.listdir(megdir)) + ) + has_evoked = dataset == "evoked" and op.isdir(op.join(final_path, "subjects")) # data not there, or force_update requested: if has_raw or has_evoked and not force_update: - _do_path_update(path, update_path, config_key, - folder_name) + _do_path_update(path, update_path, config_key, folder_name) return final_path # instantiate processor that unzips file - data_path = _download_mne_dataset(name=data_dict['dataset_name'], - processor='untar', path=path, - force_update=force_update, - update_path=update_path, download=True) + data_path = _download_mne_dataset( + name=data_dict["dataset_name"], + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=True, + ) return data_path diff --git a/mne/datasets/kiloword/kiloword.py b/mne/datasets/kiloword/kiloword.py index c011365bad3..c6f437ab36e 100644 --- a/mne/datasets/kiloword/kiloword.py +++ b/mne/datasets/kiloword/kiloword.py @@ -1,12 +1,13 @@ # License: BSD Style. from ...utils import verbose -from ..utils import (_get_version, _version_doc, _download_mne_dataset) +from ..utils import _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): """Get path to local copy of the kiloword dataset. This is the dataset from :footcite:`DufauEtAl2015`. @@ -44,14 +45,18 @@ def data_path(path=None, force_update=False, update_path=True, .. 
footbibliography:: """ return _download_mne_dataset( - name='kiloword', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="kiloword", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) def get_version(): """Get dataset version.""" - return _get_version('kiloword') + return _get_version("kiloword") -get_version.__doc__ = _version_doc.format(name='kiloword') +get_version.__doc__ = _version_doc.format(name="kiloword") diff --git a/mne/datasets/limo/limo.py b/mne/datasets/limo/limo.py index e0f1d0f9fa9..47055a0bd91 100644 --- a/mne/datasets/limo/limo.py +++ b/mne/datasets/limo/limo.py @@ -12,17 +12,17 @@ from ...epochs import EpochsArray from ...io.meas_info import create_info from ...utils import _check_pandas_installed, verbose, logger -from ..utils import (_get_path, _do_path_update, _log_time_size, - _downloader_params) +from ..utils import _get_path, _do_path_update, _log_time_size, _downloader_params # root url for LIMO files -root_url = 'https://files.de-1.osf.io/v1/resources/52rea/providers/osfstorage/' +root_url = "https://files.de-1.osf.io/v1/resources/52rea/providers/osfstorage/" @verbose -def data_path(subject, path=None, force_update=False, update_path=None, *, - verbose=None): +def data_path( + subject, path=None, force_update=False, update_path=None, *, verbose=None +): """Get path to local copy of LIMO dataset URL. This is a low-level function useful for getting a local copy of the @@ -69,110 +69,183 @@ def data_path(subject, path=None, force_update=False, update_path=None, *, .. footbibliography:: """ # noqa: E501 import pooch + t0 = time.time() downloader = pooch.HTTPDownloader(**_downloader_params()) # local storage patch - config_key = 'MNE_DATASETS_LIMO_PATH' - name = 'LIMO' - subj = f'S{subject}' + config_key = "MNE_DATASETS_LIMO_PATH" + name = "LIMO" + subj = f"S{subject}" path = _get_path(path, config_key, name) - base_path = op.join(path, 'MNE-limo-data') + base_path = op.join(path, "MNE-limo-data") subject_path = op.join(base_path, subj) # the remote URLs are in the form of UUIDs: urls = dict( - S18={'Yr.mat': '5cf839833a4d9500178a6ff8', - 'LIMO.mat': '5cf83907e650a2001ad592e4'}, - S17={'Yr.mat': '5cf838e83a4d9500168aeb76', - 'LIMO.mat': '5cf83867a542b80019c87602'}, - S16={'Yr.mat': '5cf83857e650a20019d5778f', - 'LIMO.mat': '5cf837dc3a4d9500188a64fe'}, - S15={'Yr.mat': '5cf837cce650a2001ad591e8', - 'LIMO.mat': '5cf83758a542b8001ac7d11d'}, - S14={'Yr.mat': '5cf837493a4d9500198a938f', - 'LIMO.mat': '5cf836e4a542b8001bc7cc53'}, - S13={'Yr.mat': '5cf836d23a4d9500178a6df7', - 'LIMO.mat': '5cf836543a4d9500168ae7cb'}, - S12={'Yr.mat': '5cf83643d4c7d700193e5954', - 'LIMO.mat': '5cf835193a4d9500178a6c92'}, - S11={'Yr.mat': '5cf8356ea542b8001cc81517', - 'LIMO.mat': '5cf834f7d4c7d700163daab8'}, - S10={'Yr.mat': '5cf833b0e650a20019d57454', - 'LIMO.mat': '5cf83204e650a20018d59eb2'}, - S9={'Yr.mat': '5cf83201a542b8001cc811cf', - 'LIMO.mat': '5cf8316c3a4d9500168ae13b'}, - S8={'Yr.mat': '5cf8326ce650a20017d60373', - 'LIMO.mat': '5cf8316d3a4d9500198a8dc5'}, - S7={'Yr.mat': '5cf834a03a4d9500168ae59b', - 'LIMO.mat': '5cf83069e650a20017d600d7'}, - S6={'Yr.mat': '5cf830e6a542b80019c86a70', - 'LIMO.mat': '5cf83057a542b80019c869ca'}, - S5={'Yr.mat': '5cf8115be650a20018d58041', - 'LIMO.mat': '5cf80c0bd4c7d700193e213c'}, - S4={'Yr.mat': '5cf810c9a542b80019c8450a', - 'LIMO.mat': '5cf80bf83a4d9500198a6eb4'}, - S3={'Yr.mat': '5cf80c55d4c7d700163d8f52', - 'LIMO.mat': 
'5cf80bdea542b80019c83cab'}, - S2={'Yr.mat': '5cde827123fec40019e01300', - 'LIMO.mat': '5cde82682a50c4001677c259'}, - S1={'Yr.mat': '5d6d3071536cf5001a8b0c78', - 'LIMO.mat': '5d6d305f6f41fc001a3151d8'}, + S18={ + "Yr.mat": "5cf839833a4d9500178a6ff8", + "LIMO.mat": "5cf83907e650a2001ad592e4", + }, + S17={ + "Yr.mat": "5cf838e83a4d9500168aeb76", + "LIMO.mat": "5cf83867a542b80019c87602", + }, + S16={ + "Yr.mat": "5cf83857e650a20019d5778f", + "LIMO.mat": "5cf837dc3a4d9500188a64fe", + }, + S15={ + "Yr.mat": "5cf837cce650a2001ad591e8", + "LIMO.mat": "5cf83758a542b8001ac7d11d", + }, + S14={ + "Yr.mat": "5cf837493a4d9500198a938f", + "LIMO.mat": "5cf836e4a542b8001bc7cc53", + }, + S13={ + "Yr.mat": "5cf836d23a4d9500178a6df7", + "LIMO.mat": "5cf836543a4d9500168ae7cb", + }, + S12={ + "Yr.mat": "5cf83643d4c7d700193e5954", + "LIMO.mat": "5cf835193a4d9500178a6c92", + }, + S11={ + "Yr.mat": "5cf8356ea542b8001cc81517", + "LIMO.mat": "5cf834f7d4c7d700163daab8", + }, + S10={ + "Yr.mat": "5cf833b0e650a20019d57454", + "LIMO.mat": "5cf83204e650a20018d59eb2", + }, + S9={ + "Yr.mat": "5cf83201a542b8001cc811cf", + "LIMO.mat": "5cf8316c3a4d9500168ae13b", + }, + S8={ + "Yr.mat": "5cf8326ce650a20017d60373", + "LIMO.mat": "5cf8316d3a4d9500198a8dc5", + }, + S7={ + "Yr.mat": "5cf834a03a4d9500168ae59b", + "LIMO.mat": "5cf83069e650a20017d600d7", + }, + S6={ + "Yr.mat": "5cf830e6a542b80019c86a70", + "LIMO.mat": "5cf83057a542b80019c869ca", + }, + S5={ + "Yr.mat": "5cf8115be650a20018d58041", + "LIMO.mat": "5cf80c0bd4c7d700193e213c", + }, + S4={ + "Yr.mat": "5cf810c9a542b80019c8450a", + "LIMO.mat": "5cf80bf83a4d9500198a6eb4", + }, + S3={ + "Yr.mat": "5cf80c55d4c7d700163d8f52", + "LIMO.mat": "5cf80bdea542b80019c83cab", + }, + S2={ + "Yr.mat": "5cde827123fec40019e01300", + "LIMO.mat": "5cde82682a50c4001677c259", + }, + S1={ + "Yr.mat": "5d6d3071536cf5001a8b0c78", + "LIMO.mat": "5d6d305f6f41fc001a3151d8", + }, ) # these can't be in the registry file (mne/data/dataset_checksums.txt) # because of filename duplication hashes = dict( - S18={'Yr.mat': 'md5:87f883d442737971a80fc0a35d057e51', - 'LIMO.mat': 'md5:8b4879646f65d7876fa4adf2e40162c5'}, - S17={'Yr.mat': 'md5:7b667ec9eefd7a9996f61ae270e295ee', - 'LIMO.mat': 'md5:22eaca4e6fad54431fd61b307fc426b8'}, - S16={'Yr.mat': 'md5:c877afdb4897426421577e863a45921a', - 'LIMO.mat': 'md5:86672d7afbea1e8c39305bc3f852c8c2'}, - S15={'Yr.mat': 'md5:eea9e0140af598fefc08c886a6f05de5', - 'LIMO.mat': 'md5:aed5cb71ddbfd27c6a3ac7d3e613d07f'}, - S14={'Yr.mat': 'md5:8bd842cfd8588bd5d32e72fdbe70b66e', - 'LIMO.mat': 'md5:1e07d1f36f2eefad435a77530daf2680'}, - S13={'Yr.mat': 'md5:d7925d2af7288b8a5186dfb5dbb63d34', - 'LIMO.mat': 'md5:ba891015d2f9e447955fffa9833404ca'}, - S12={'Yr.mat': 'md5:0e1d05beaa4bf2726e0d0671b78fe41e', - 'LIMO.mat': 'md5:423fd479d71097995b6614ecb11df9ad'}, - S11={'Yr.mat': 'md5:1b0016fb9832e43b71f79c1992fcbbb1', - 'LIMO.mat': 'md5:1a281348c2a41ee899f42731d30cda70'}, - S10={'Yr.mat': 'md5:13c66f60e241b9a9cc576eaf1b55a417', - 'LIMO.mat': 'md5:3c4b41e221eb352a21bbef1a7e006f06'}, - S9={'Yr.mat': 'md5:3ae1d9c3a1d9325deea2f2dddd1ab507', - 'LIMO.mat': 'md5:5e204e2a4bcfe4f535b4b1af469b37f7'}, - S8={'Yr.mat': 'md5:7e9adbca4e03d8d7ce8ea07ccecdc8fd', - 'LIMO.mat': 'md5:88313c21d34428863590e586b2bc3408'}, - S7={'Yr.mat': 'md5:6b5290a6725ecebf1022d5d2789b186d', - 'LIMO.mat': 'md5:8c769219ebc14ce3f595063e84bfc0a9'}, - S6={'Yr.mat': 'md5:420c858a8340bf7c28910b7b0425dc5d', - 'LIMO.mat': 'md5:9cf4e1a405366d6bd0cc6d996e32fd63'}, - S5={'Yr.mat': 'md5:946436cfb474c8debae56ffb1685ecf3', - 'LIMO.mat': 
'md5:241fac95d3a79d2cea081391fb7078bd'}, - S4={'Yr.mat': 'md5:c8216af78ac87b739e86e57b345cafdd', - 'LIMO.mat': 'md5:8e10ef36c2e075edc2f787581ba33459'}, - S3={'Yr.mat': 'md5:ff02e885b65b7b807146f259a30b1b5e', - 'LIMO.mat': 'md5:59b5fb3a9749003133608b5871309e2c'}, - S2={'Yr.mat': 'md5:a4329022e57fd07ceceb7d1735fd2718', - 'LIMO.mat': 'md5:98b284b567f2dd395c936366e404f2c6'}, - S1={'Yr.mat': 'md5:076c0ae78fb71d43409c1877707df30e', - 'LIMO.mat': 'md5:136c8cf89f8f111a11f531bd9fa6ae69'}, + S18={ + "Yr.mat": "md5:87f883d442737971a80fc0a35d057e51", + "LIMO.mat": "md5:8b4879646f65d7876fa4adf2e40162c5", + }, + S17={ + "Yr.mat": "md5:7b667ec9eefd7a9996f61ae270e295ee", + "LIMO.mat": "md5:22eaca4e6fad54431fd61b307fc426b8", + }, + S16={ + "Yr.mat": "md5:c877afdb4897426421577e863a45921a", + "LIMO.mat": "md5:86672d7afbea1e8c39305bc3f852c8c2", + }, + S15={ + "Yr.mat": "md5:eea9e0140af598fefc08c886a6f05de5", + "LIMO.mat": "md5:aed5cb71ddbfd27c6a3ac7d3e613d07f", + }, + S14={ + "Yr.mat": "md5:8bd842cfd8588bd5d32e72fdbe70b66e", + "LIMO.mat": "md5:1e07d1f36f2eefad435a77530daf2680", + }, + S13={ + "Yr.mat": "md5:d7925d2af7288b8a5186dfb5dbb63d34", + "LIMO.mat": "md5:ba891015d2f9e447955fffa9833404ca", + }, + S12={ + "Yr.mat": "md5:0e1d05beaa4bf2726e0d0671b78fe41e", + "LIMO.mat": "md5:423fd479d71097995b6614ecb11df9ad", + }, + S11={ + "Yr.mat": "md5:1b0016fb9832e43b71f79c1992fcbbb1", + "LIMO.mat": "md5:1a281348c2a41ee899f42731d30cda70", + }, + S10={ + "Yr.mat": "md5:13c66f60e241b9a9cc576eaf1b55a417", + "LIMO.mat": "md5:3c4b41e221eb352a21bbef1a7e006f06", + }, + S9={ + "Yr.mat": "md5:3ae1d9c3a1d9325deea2f2dddd1ab507", + "LIMO.mat": "md5:5e204e2a4bcfe4f535b4b1af469b37f7", + }, + S8={ + "Yr.mat": "md5:7e9adbca4e03d8d7ce8ea07ccecdc8fd", + "LIMO.mat": "md5:88313c21d34428863590e586b2bc3408", + }, + S7={ + "Yr.mat": "md5:6b5290a6725ecebf1022d5d2789b186d", + "LIMO.mat": "md5:8c769219ebc14ce3f595063e84bfc0a9", + }, + S6={ + "Yr.mat": "md5:420c858a8340bf7c28910b7b0425dc5d", + "LIMO.mat": "md5:9cf4e1a405366d6bd0cc6d996e32fd63", + }, + S5={ + "Yr.mat": "md5:946436cfb474c8debae56ffb1685ecf3", + "LIMO.mat": "md5:241fac95d3a79d2cea081391fb7078bd", + }, + S4={ + "Yr.mat": "md5:c8216af78ac87b739e86e57b345cafdd", + "LIMO.mat": "md5:8e10ef36c2e075edc2f787581ba33459", + }, + S3={ + "Yr.mat": "md5:ff02e885b65b7b807146f259a30b1b5e", + "LIMO.mat": "md5:59b5fb3a9749003133608b5871309e2c", + }, + S2={ + "Yr.mat": "md5:a4329022e57fd07ceceb7d1735fd2718", + "LIMO.mat": "md5:98b284b567f2dd395c936366e404f2c6", + }, + S1={ + "Yr.mat": "md5:076c0ae78fb71d43409c1877707df30e", + "LIMO.mat": "md5:136c8cf89f8f111a11f531bd9fa6ae69", + }, ) # create the download manager fetcher = pooch.create( path=subject_path, - base_url='', - version=None, # Data versioning is decoupled from MNE-Python version. + base_url="", + version=None, # Data versioning is decoupled from MNE-Python version. 
registry=hashes[subj], - urls={key: f'{root_url}{uuid}' for key, uuid in urls[subj].items()}, - retry_if_failed=2 # 2 retries = 3 total attempts + urls={key: f"{root_url}{uuid}" for key, uuid in urls[subj].items()}, + retry_if_failed=2, # 2 retries = 3 total attempts ) # use our logger level for pooch's logger too pooch.get_logger().setLevel(logger.getEffectiveLevel()) # fetch the data sz = 0 - for fname in ('LIMO.mat', 'Yr.mat'): + for fname in ("LIMO.mat", "Yr.mat"): destination = Path(subject_path, fname) if destination.exists(): if force_update: @@ -180,7 +253,7 @@ def data_path(subject, path=None, force_update=False, update_path=None, *, else: continue if sz == 0: # log once - logger.info('Downloading LIMO data') + logger.info("Downloading LIMO data") # fetch the remote file (if local file missing or has hash mismatch) fetcher.fetch(fname=fname, downloader=downloader) sz += destination.stat().st_size @@ -192,8 +265,7 @@ def data_path(subject, path=None, force_update=False, update_path=None, *, @verbose -def load_data(subject, path=None, force_update=False, update_path=None, - verbose=None): +def load_data(subject, path=None, force_update=False, update_path=None, verbose=None): """Fetch subjects epochs data for the LIMO data set. Parameters @@ -222,45 +294,45 @@ def load_data(subject, path=None, force_update=False, update_path=None, # subject in question if isinstance(subject, int) and 1 <= subject <= 18: - subj = 'S%i' % subject + subj = "S%i" % subject else: - raise ValueError('subject must be an int in the range from 1 to 18') + raise ValueError("subject must be an int in the range from 1 to 18") # set limo path, download and decompress files if not found limo_path = data_path(subject, path, force_update, update_path) # -- 1) import .mat files # epochs info - fname_info = op.join(limo_path, subj, 'LIMO.mat') + fname_info = op.join(limo_path, subj, "LIMO.mat") data_info = loadmat(fname_info) # number of epochs per condition - design = data_info['LIMO']['design'][0][0]['X'][0][0] - data_info = data_info['LIMO']['data'][0][0][0][0] + design = data_info["LIMO"]["design"][0][0]["X"][0][0] + data_info = data_info["LIMO"]["data"][0][0][0][0] # epochs data - fname_eeg = op.join(limo_path, subj, 'Yr.mat') + fname_eeg = op.join(limo_path, subj, "Yr.mat") data = loadmat(fname_eeg) # -- 2) get epochs information from structure # sampling rate - sfreq = data_info['sampling_rate'][0][0] + sfreq = data_info["sampling_rate"][0][0] # tmin and tmax - tmin = data_info['start'][0][0] + tmin = data_info["start"][0][0] # create events matrix sample = np.arange(len(design)) prev_id = np.zeros(len(design)) ev_id = design[:, 1] events = np.array([sample, prev_id, ev_id]).astype(int).T # event ids, such that Face B == 1 - event_id = {'Face/A': 0, 'Face/B': 1} + event_id = {"Face/A": 0, "Face/B": 1} # -- 3) extract channel labels from LIMO structure # get individual labels - labels = data_info['chanlocs']['labels'] + labels = data_info["chanlocs"]["labels"] labels = [label for label, *_ in labels[0]] # get montage - montage = make_standard_montage('biosemi128') + montage = make_standard_montage("biosemi128") # add external electrodes (e.g., eogs) - ch_names = montage.ch_names + ['EXG1', 'EXG2', 'EXG3', 'EXG4'] + ch_names = montage.ch_names + ["EXG1", "EXG2", "EXG3", "EXG4"] # match individual labels to labels in montage found_inds = [ind for ind, name in enumerate(ch_names) if name in labels] missing_chans = [name for name in ch_names if name not in labels] @@ -270,7 +342,7 @@ def load_data(subject, 
path=None, force_update=False, update_path=None, # data is stored as channels x time points x epochs # data['Yr'].shape # <-- see here # transpose to epochs x channels time points - data = np.transpose(data['Yr'], (2, 0, 1)) + data = np.transpose(data["Yr"], (2, 0, 1)) # initialize data in expected order temp_data = np.empty((data.shape[0], len(ch_names), data.shape[2])) # copy over the non-missing data @@ -287,15 +359,16 @@ def load_data(subject, path=None, force_update=False, update_path=None, info = create_info(ch_names, sfreq, types).set_montage(montage) # get faces and noise variables from design matrix event_list = list(events[:, 2]) - faces = ['B' if event else 'A' for event in event_list] + faces = ["B" if event else "A" for event in event_list] noise = list(design[:, 2]) # create epochs metadata - metadata = {'face': faces, 'phase-coherence': noise} + metadata = {"face": faces, "phase-coherence": noise} metadata = pd.DataFrame(metadata) # -- 6) Create custom epochs array - epochs = EpochsArray(data, info, events, tmin, event_id, metadata=metadata, - verbose=False) - epochs.info['bads'] = missing_chans # missing channels are marked as bad. + epochs = EpochsArray( + data, info, events, tmin, event_id, metadata=metadata, verbose=False + ) + epochs.info["bads"] = missing_chans # missing channels are marked as bad. return epochs diff --git a/mne/datasets/misc/_misc.py b/mne/datasets/misc/_misc.py index 85f65332ad1..443aa24787b 100644 --- a/mne/datasets/misc/_misc.py +++ b/mne/datasets/misc/_misc.py @@ -8,19 +8,25 @@ @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='misc', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="misc", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) def _pytest_mark(): import pytest + return pytest.mark.skipif( - not has_dataset(name='misc'), reason='Requires misc dataset') + not has_dataset(name="misc"), reason="Requires misc dataset" + ) -data_path.__doc__ = _data_path_doc.format(name='misc', - conf='MNE_DATASETS_MISC_PATH') +data_path.__doc__ = _data_path_doc.format(name="misc", conf="MNE_DATASETS_MISC_PATH") diff --git a/mne/datasets/mtrf/mtrf.py b/mne/datasets/mtrf/mtrf.py index bfc5cd0ba58..1ce4f741a4f 100644 --- a/mne/datasets/mtrf/mtrf.py +++ b/mne/datasets/mtrf/mtrf.py @@ -3,24 +3,27 @@ # License: BSD Style. 
from ...utils import verbose -from ..utils import (_data_path_doc, - _get_version, _version_doc, _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset -data_name = 'mtrf' +data_name = "mtrf" @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name=data_name, processor='unzip', path=path, - force_update=force_update, update_path=update_path, - download=download) + name=data_name, + processor="unzip", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name=data_name, - conf='MNE_DATASETS_MTRF_PATH') +data_path.__doc__ = _data_path_doc.format(name=data_name, conf="MNE_DATASETS_MTRF_PATH") def get_version(): # noqa: D103 diff --git a/mne/datasets/multimodal/multimodal.py b/mne/datasets/multimodal/multimodal.py index 4ef0fd38efb..84fbf662e5f 100644 --- a/mne/datasets/multimodal/multimodal.py +++ b/mne/datasets/multimodal/multimodal.py @@ -4,25 +4,30 @@ # License: BSD Style. from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='multimodal', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="multimodal", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='multimodal', - conf='MNE_DATASETS_MULTIMODAL_PATH') +data_path.__doc__ = _data_path_doc.format( + name="multimodal", conf="MNE_DATASETS_MULTIMODAL_PATH" +) def get_version(): # noqa: D103 - return _get_version('multimodal') + return _get_version("multimodal") -get_version.__doc__ = _version_doc.format(name='multimodal') +get_version.__doc__ = _version_doc.format(name="multimodal") diff --git a/mne/datasets/opm/opm.py b/mne/datasets/opm/opm.py index 014e91f2029..b2b24f2e3f8 100644 --- a/mne/datasets/opm/opm.py +++ b/mne/datasets/opm/opm.py @@ -4,25 +4,28 @@ # License: BSD Style. 
from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='opm', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="opm", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='opm', - conf='MNE_DATASETS_OPML_PATH') +data_path.__doc__ = _data_path_doc.format(name="opm", conf="MNE_DATASETS_OPML_PATH") def get_version(): # noqa: D103 - return _get_version('opm') + return _get_version("opm") -get_version.__doc__ = _version_doc.format(name='opm') +get_version.__doc__ = _version_doc.format(name="opm") diff --git a/mne/datasets/phantom_4dbti/phantom_4dbti.py b/mne/datasets/phantom_4dbti/phantom_4dbti.py index 2154dee99ce..59c42416d5a 100644 --- a/mne/datasets/phantom_4dbti/phantom_4dbti.py +++ b/mne/datasets/phantom_4dbti/phantom_4dbti.py @@ -3,25 +3,30 @@ # License: BSD Style. from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='phantom_4dbti', processor='unzip', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="phantom_4dbti", + processor="unzip", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) data_path.__doc__ = _data_path_doc.format( - name='phantom_4dbti', conf='MNE_DATASETS_PHANTOM_4DBTI_PATH') + name="phantom_4dbti", conf="MNE_DATASETS_PHANTOM_4DBTI_PATH" +) def get_version(): # noqa: D103 - return _get_version('phantom_4dbti') + return _get_version("phantom_4dbti") -get_version.__doc__ = _version_doc.format(name='phantom_4dbti') +get_version.__doc__ = _version_doc.format(name="phantom_4dbti") diff --git a/mne/datasets/refmeg_noise/refmeg_noise.py b/mne/datasets/refmeg_noise/refmeg_noise.py index 2027a31bacc..e77f3eefaf0 100644 --- a/mne/datasets/refmeg_noise/refmeg_noise.py +++ b/mne/datasets/refmeg_noise/refmeg_noise.py @@ -2,25 +2,30 @@ # License: BSD Style. 
from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='refmeg_noise', processor='unzip', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="refmeg_noise", + processor="unzip", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) data_path.__doc__ = _data_path_doc.format( - name='refmeg_noise', conf='MNE_DATASETS_REFMEG_NOISE_PATH') + name="refmeg_noise", conf="MNE_DATASETS_REFMEG_NOISE_PATH" +) def get_version(): # noqa: D103 - return _get_version('refmeg_noise') + return _get_version("refmeg_noise") -get_version.__doc__ = _version_doc.format(name='refmeg_noise') +get_version.__doc__ = _version_doc.format(name="refmeg_noise") diff --git a/mne/datasets/sample/sample.py b/mne/datasets/sample/sample.py index 4876b7bc7f7..f5ca6de24c4 100644 --- a/mne/datasets/sample/sample.py +++ b/mne/datasets/sample/sample.py @@ -4,25 +4,30 @@ # License: BSD Style. from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name='sample', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="sample", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) -data_path.__doc__ = _data_path_doc.format(name='sample', - conf='MNE_DATASETS_SAMPLE_PATH') +data_path.__doc__ = _data_path_doc.format( + name="sample", conf="MNE_DATASETS_SAMPLE_PATH" +) def get_version(): # noqa: D103 - return _get_version('sample') + return _get_version("sample") -get_version.__doc__ = _version_doc.format(name='sample') +get_version.__doc__ = _version_doc.format(name="sample") diff --git a/mne/datasets/sleep_physionet/_utils.py b/mne/datasets/sleep_physionet/_utils.py index 50f992e7803..bca3284d73b 100644 --- a/mne/datasets/sleep_physionet/_utils.py +++ b/mne/datasets/sleep_physionet/_utils.py @@ -8,27 +8,30 @@ import numpy as np -from ...utils import (verbose, _TempDir, _check_pandas_installed, - _on_missing) +from ...utils import verbose, _TempDir, _check_pandas_installed, _on_missing from ..utils import _get_path, _downloader_params -AGE_SLEEP_RECORDS = op.join(op.dirname(__file__), 'age_records.csv') -TEMAZEPAM_SLEEP_RECORDS = op.join(op.dirname(__file__), - 'temazepam_records.csv') +AGE_SLEEP_RECORDS = op.join(op.dirname(__file__), "age_records.csv") +TEMAZEPAM_SLEEP_RECORDS = op.join(op.dirname(__file__), "temazepam_records.csv") -TEMAZEPAM_RECORDS_URL = 'https://physionet.org/physiobank/database/sleep-edfx/ST-subjects.xls' # noqa: E501 -TEMAZEPAM_RECORDS_URL_SHA1 = 'f52fffe5c18826a2bd4c5d5cb375bb4a9008c885' +TEMAZEPAM_RECORDS_URL = ( + "https://physionet.org/physiobank/database/sleep-edfx/ST-subjects.xls" # 
noqa: E501 +) +TEMAZEPAM_RECORDS_URL_SHA1 = "f52fffe5c18826a2bd4c5d5cb375bb4a9008c885" -AGE_RECORDS_URL = 'https://physionet.org/physiobank/database/sleep-edfx/SC-subjects.xls' # noqa: E501 -AGE_RECORDS_URL_SHA1 = '0ba6650892c5d33a8e2b3f62ce1cc9f30438c54f' +AGE_RECORDS_URL = ( + "https://physionet.org/physiobank/database/sleep-edfx/SC-subjects.xls" # noqa: E501 +) +AGE_RECORDS_URL_SHA1 = "0ba6650892c5d33a8e2b3f62ce1cc9f30438c54f" -sha1sums_fname = op.join(op.dirname(__file__), 'SHA1SUMS') +sha1sums_fname = op.join(op.dirname(__file__), "SHA1SUMS") def _fetch_one(fname, hashsum, path, force_update, base_url): import pooch + # Fetch the file - url = base_url + '/' + fname + url = base_url + "/" + fname destination = op.join(path, fname) if op.isfile(destination) and not force_update: return destination, False @@ -42,7 +45,7 @@ def _fetch_one(fname, hashsum, path, force_update, base_url): known_hash=f"sha1:{hashsum}", path=path, downloader=downloader, - fname=fname + fname=fname, ) return destination, True @@ -75,10 +78,10 @@ def _data_path(path=None, verbose=None): ---------- .. footbibliography:: """ # noqa: E501 - key = 'PHYSIONET_SLEEP_PATH' - name = 'PHYSIONET_SLEEP' + key = "PHYSIONET_SLEEP_PATH" + name = "PHYSIONET_SLEEP" path = _get_path(path, key, name) - return op.join(path, 'physionet-sleep-data') + return op.join(path, "physionet-sleep-data") def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS): @@ -89,7 +92,7 @@ def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS): tmp = _TempDir() # Download subjects info. - subjects_fname = op.join(tmp, 'ST-subjects.xls') + subjects_fname = op.join(tmp, "ST-subjects.xls") downloader = pooch.HTTPDownloader(**_downloader_params()) pooch.retrieve( url=TEMAZEPAM_RECORDS_URL, @@ -100,44 +103,60 @@ def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS): ) # Load and Massage the checksums. - sha1_df = pd.read_csv(sha1sums_fname, sep=' ', header=None, - names=['sha', 'fname'], engine='python') - select_age_records = (sha1_df.fname.str.startswith('ST') & - sha1_df.fname.str.endswith('edf')) + sha1_df = pd.read_csv( + sha1sums_fname, sep=" ", header=None, names=["sha", "fname"], engine="python" + ) + select_age_records = sha1_df.fname.str.startswith( + "ST" + ) & sha1_df.fname.str.endswith("edf") sha1_df = sha1_df[select_age_records] - sha1_df['id'] = [name[:6] for name in sha1_df.fname] + sha1_df["id"] = [name[:6] for name in sha1_df.fname] # Load and massage the data. 
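# The wide-to-long pivot below hinges on stacking level 0 of the MultiIndex
# columns produced by header=[0, 1]; a toy sketch of the same idiom, with
# illustrative column names rather than the real ST-subjects.xls headers:
#
#     import pandas as pd
#     wide = pd.DataFrame(
#         {("night 1", "fname"): ["a.edf"], ("night 2", "fname"): ["b.edf"]},
#         index=pd.Index([0], name="subject"),
#     )
#     tidy = wide.stack(level=0).reset_index()  # one row per (subject, night)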
data = pd.read_excel(subjects_fname, header=[0, 1]) - data = data.set_index(('Subject - age - sex', 'Nr')) - data.index.name = 'subject' + data = data.set_index(("Subject - age - sex", "Nr")) + data.index.name = "subject" data.columns.names = [None, None] - data = (data.set_index([('Subject - age - sex', 'Age'), - ('Subject - age - sex', 'M1/F2')], append=True) - .stack(level=0).reset_index()) - - data = data.rename(columns={('Subject - age - sex', 'Age'): 'age', - ('Subject - age - sex', 'M1/F2'): 'sex', - 'level_3': 'drug'}) - data['id'] = ['ST7{:02d}{:1d}'.format(s, n) - for s, n in zip(data.subject, data['night nr'])] + data = ( + data.set_index( + [("Subject - age - sex", "Age"), ("Subject - age - sex", "M1/F2")], + append=True, + ) + .stack(level=0) + .reset_index() + ) - data = pd.merge(sha1_df, data, how='outer', on='id') - data['record type'] = (data.fname.str.split('-', expand=True)[1] - .str.split('.', expand=True)[0] - .astype('category')) + data = data.rename( + columns={ + ("Subject - age - sex", "Age"): "age", + ("Subject - age - sex", "M1/F2"): "sex", + "level_3": "drug", + } + ) + data["id"] = [ + "ST7{:02d}{:1d}".format(s, n) for s, n in zip(data.subject, data["night nr"]) + ] + + data = pd.merge(sha1_df, data, how="outer", on="id") + data["record type"] = ( + data.fname.str.split("-", expand=True)[1] + .str.split(".", expand=True)[0] + .astype("category") + ) - data = data.set_index(['id', 'subject', 'age', 'sex', 'drug', - 'lights off', 'night nr', 'record type']).unstack() - data.columns = [l1 + '_' + l2 for l1, l2 in data.columns] - data = data.reset_index().drop(columns=['id']) + data = data.set_index( + ["id", "subject", "age", "sex", "drug", "lights off", "night nr", "record type"] + ).unstack() + data.columns = [l1 + "_" + l2 for l1, l2 in data.columns] + data = data.reset_index().drop(columns=["id"]) - data['sex'] = (data.sex.astype('category') - .cat.rename_categories({1: 'male', 2: 'female'})) + data["sex"] = data.sex.astype("category").cat.rename_categories( + {1: "male", 2: "female"} + ) - data['drug'] = data['drug'].str.split(expand=True)[0] - data['subject_orig'] = data['subject'] - data['subject'] = data.index // 2 # to make sure index is from 0 to 21 + data["drug"] = data["drug"].str.split(expand=True)[0] + data["subject_orig"] = data["subject"] + data["subject"] = data.index // 2 # to make sure index is from 0 to 21 # Save the data. data.to_csv(fname, index=False) @@ -146,11 +165,12 @@ def _update_sleep_temazepam_records(fname=TEMAZEPAM_SLEEP_RECORDS): def _update_sleep_age_records(fname=AGE_SLEEP_RECORDS): """Help function to download Physionet's age dataset records.""" import pooch + pd = _check_pandas_installed() tmp = _TempDir() # Download subjects info. - subjects_fname = op.join(tmp, 'SC-subjects.xls') + subjects_fname = op.join(tmp, "SC-subjects.xls") downloader = pooch.HTTPDownloader(**_downloader_params()) pooch.retrieve( url=AGE_RECORDS_URL, @@ -161,38 +181,46 @@ def _update_sleep_age_records(fname=AGE_SLEEP_RECORDS): ) # Load and Massage the checksums. 
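# SHA1SUMS is a plain sha1sum-style manifest that read_csv parses below:
# two whitespace-separated columns (hash, filename) per line, for example
# (illustrative, hash shortened):
#
#     f52fffe5c188...  SC4001E0-PSG.edf
#
# The first six characters of each fname then become the join key "id".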
- sha1_df = pd.read_csv(sha1sums_fname, sep=' ', header=None, - names=['sha', 'fname'], engine='python') - select_age_records = (sha1_df.fname.str.startswith('SC') & - sha1_df.fname.str.endswith('edf')) + sha1_df = pd.read_csv( + sha1sums_fname, sep=" ", header=None, names=["sha", "fname"], engine="python" + ) + select_age_records = sha1_df.fname.str.startswith( + "SC" + ) & sha1_df.fname.str.endswith("edf") sha1_df = sha1_df[select_age_records] - sha1_df['id'] = [name[:6] for name in sha1_df.fname] + sha1_df["id"] = [name[:6] for name in sha1_df.fname] # Load and massage the data. data = pd.read_excel(subjects_fname) - data = data.rename(index=str, columns={'sex (F=1)': 'sex', - 'LightsOff': 'lights off'}) - data['sex'] = (data.sex.astype('category') - .cat.rename_categories({1: 'female', 2: 'male'})) + data = data.rename( + index=str, columns={"sex (F=1)": "sex", "LightsOff": "lights off"} + ) + data["sex"] = data.sex.astype("category").cat.rename_categories( + {1: "female", 2: "male"} + ) - data['id'] = ['SC4{:02d}{:1d}'.format(s, n) - for s, n in zip(data.subject, data.night)] + data["id"] = [ + "SC4{:02d}{:1d}".format(s, n) for s, n in zip(data.subject, data.night) + ] - data = data.set_index('id').join(sha1_df.set_index('id')).dropna() + data = data.set_index("id").join(sha1_df.set_index("id")).dropna() - data['record type'] = (data.fname.str.split('-', expand=True)[1] - .str.split('.', expand=True)[0] - .astype('category')) + data["record type"] = ( + data.fname.str.split("-", expand=True)[1] + .str.split(".", expand=True)[0] + .astype("category") + ) - data = data.reset_index().drop(columns=['id']) - data = data[['subject', 'night', 'record type', 'age', 'sex', 'lights off', - 'sha', 'fname']] + data = data.reset_index().drop(columns=["id"]) + data = data[ + ["subject", "night", "record type", "age", "sex", "lights off", "sha", "fname"] + ] # Save the data. data.to_csv(fname, index=False) -def _check_subjects(subjects, n_subjects, missing=None, on_missing='raise'): +def _check_subjects(subjects, n_subjects, missing=None, on_missing="raise"): """Check whether subjects are available. Parameters @@ -214,8 +242,10 @@ def _check_subjects(subjects, n_subjects, missing=None, on_missing='raise'): valid_subjects = np.setdiff1d(valid_subjects, missing) unknown_subjects = np.setdiff1d(subjects, valid_subjects) if unknown_subjects.size > 0: - subjects_list = ', '.join([str(s) for s in unknown_subjects]) - msg = (f'This dataset contains subjects 0 to {n_subjects - 1} with ' - f'missing subjects {missing}. Unknown subjects: ' - f'{subjects_list}.') + subjects_list = ", ".join([str(s) for s in unknown_subjects]) + msg = ( + f"This dataset contains subjects 0 to {n_subjects - 1} with " + f"missing subjects {missing}. Unknown subjects: " + f"{subjects_list}." + ) _on_missing(on_missing, msg) diff --git a/mne/datasets/sleep_physionet/age.py b/mne/datasets/sleep_physionet/age.py index 106d39d4e32..0a7fb174d1c 100644 --- a/mne/datasets/sleep_physionet/age.py +++ b/mne/datasets/sleep_physionet/age.py @@ -15,12 +15,22 @@ data_path = _data_path # expose _data_path(..) as data_path(..) 
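A minimal usage sketch for the fetcher reformatted in the next hunk; the argument values are the same ones the _download_all_example_data() hunk later in this patch uses, and the return value is a list of [PSG, hypnogram] file-path pairs:

    from mne.datasets.sleep_physionet.age import fetch_data

    # Download (or just locate, if already cached) the first recording of
    # the first two subjects.
    paths = fetch_data(subjects=[0, 1], recording=[1])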
-BASE_URL = 'https://physionet.org/physiobank/database/sleep-edfx/sleep-cassette/' # noqa: E501 +BASE_URL = ( + "https://physionet.org/physiobank/database/sleep-edfx/sleep-cassette/" # noqa: E501 +) @verbose -def fetch_data(subjects, recording=(1, 2), path=None, force_update=False, - base_url=BASE_URL, on_missing='raise', *, verbose=None): # noqa: D301, E501 +def fetch_data( + subjects, + recording=(1, 2), + path=None, + force_update=False, + base_url=BASE_URL, + on_missing="raise", + *, + verbose=None +): # noqa: D301, E501 """Get paths to local copies of PhysioNet Polysomnography dataset files. This will fetch data from the publicly available subjects from PhysioNet's @@ -84,46 +94,53 @@ def fetch_data(subjects, recording=(1, 2), path=None, force_update=False, .. footbibliography:: """ # noqa: E501 t0 = time.time() - records = np.loadtxt(AGE_SLEEP_RECORDS, - skiprows=1, - delimiter=',', - usecols=(0, 1, 2, 6, 7), - dtype={'names': ('subject', 'record', 'type', 'sha', - 'fname'), - 'formats': (' 0: - os.makedirs(op.join(destination, 'foo')) - assert op.isdir(op.join(destination, 'foo')) + os.makedirs(op.join(destination, "foo")) + assert op.isdir(op.join(destination, "foo")) for fname in _zip_fnames: assert not op.isfile(op.join(destination, fname)) for fname in _zip_fnames[:n_have]: - with open(op.join(destination, fname), 'w'): + with open(op.join(destination, fname), "w"): pass with catch_logging() as log: with use_log_level(True): # we mock the pooch.retrieve so these are not used - url = hash_ = '' + url = hash_ = "" _manifest_check_download(manifest_path, destination, url, hash_) log = log.getvalue() n_missing = 3 - n_have - assert ('%d file%s missing from' % (n_missing, _pl(n_missing))) in log - for want in ('Extracting missing', 'Successfully '): + assert ("%d file%s missing from" % (n_missing, _pl(n_missing))) in log + for want in ("Extracting missing", "Successfully "): if n_missing > 0: assert want in log else: @@ -236,10 +264,9 @@ def test_manifest_check_download(tmp_path, n_have, monkeypatch): assert op.isfile(op.join(destination, fname)) -def _fake_mcd(manifest_path, destination, url, hash_, name=None, - fake_files=False): +def _fake_mcd(manifest_path, destination, url, hash_, name=None, fake_files=False): if name is None: - name = url.split('/')[-1].split('.')[0] + name = url.split("/")[-1].split(".")[0] assert name in url assert name in str(destination) assert name in manifest_path @@ -252,16 +279,16 @@ def _fake_mcd(manifest_path, destination, url, hash_, name=None, continue fname = op.join(destination, path) os.makedirs(op.dirname(fname), exist_ok=True) - with open(fname, 'wb'): + with open(fname, "wb"): pass def test_infant(tmp_path, monkeypatch): """Test fetch_infant_template.""" - monkeypatch.setattr(infant_base, '_manifest_check_download', _fake_mcd) - fetch_infant_template('12mo', subjects_dir=tmp_path) - with pytest.raises(ValueError, match='Invalid value for'): - fetch_infant_template('0mo', subjects_dir=tmp_path) + monkeypatch.setattr(infant_base, "_manifest_check_download", _fake_mcd) + fetch_infant_template("12mo", subjects_dir=tmp_path) + with pytest.raises(ValueError, match="Invalid value for"): + fetch_infant_template("0mo", subjects_dir=tmp_path) def test_phantom(tmp_path, monkeypatch): @@ -270,21 +297,25 @@ def test_phantom(tmp_path, monkeypatch): # an actual download here. But it doesn't seem worth it given that # CircleCI will at least test the VectorView one, and this file should # not change often. 
- monkeypatch.setattr(phantom_base, '_manifest_check_download', - partial(_fake_mcd, name='phantom_otaniemi', - fake_files=True)) - fetch_phantom('otaniemi', subjects_dir=tmp_path) - assert op.isfile(tmp_path / 'phantom_otaniemi' / 'mri' / 'T1.mgz') + monkeypatch.setattr( + phantom_base, + "_manifest_check_download", + partial(_fake_mcd, name="phantom_otaniemi", fake_files=True), + ) + fetch_phantom("otaniemi", subjects_dir=tmp_path) + assert op.isfile(tmp_path / "phantom_otaniemi" / "mri" / "T1.mgz") def test_fetch_uncompressed_file(tmp_path): """Test downloading an uncompressed file with our fetch function.""" dataset_dict = dict( - dataset_name='license', - url=('https://raw.githubusercontent.com/mne-tools/mne-python/main/' - 'LICENSE.txt'), - archive_name='LICENSE.foo', - folder_name=op.join(tmp_path, 'foo'), - hash=None) + dataset_name="license", + url=( + "https://raw.githubusercontent.com/mne-tools/mne-python/main/" "LICENSE.txt" + ), + archive_name="LICENSE.foo", + folder_name=op.join(tmp_path, "foo"), + hash=None, + ) fetch_dataset(dataset_dict, path=None, force_update=True) - assert (tmp_path / 'foo' / 'LICENSE.foo').is_file() + assert (tmp_path / "foo" / "LICENSE.foo").is_file() diff --git a/mne/datasets/ucl_opm_auditory/ucl_opm_auditory.py b/mne/datasets/ucl_opm_auditory/ucl_opm_auditory.py index e43443d1480..09853e640de 100644 --- a/mne/datasets/ucl_opm_auditory/ucl_opm_auditory.py +++ b/mne/datasets/ucl_opm_auditory/ucl_opm_auditory.py @@ -2,26 +2,30 @@ # License: BSD Style. from ...utils import verbose -from ..utils import (_data_path_doc, _get_version, _version_doc, - _download_mne_dataset) +from ..utils import _data_path_doc, _get_version, _version_doc, _download_mne_dataset -_NAME = 'ucl_opm_auditory' -_PROCESSOR = 'unzip' +_NAME = "ucl_opm_auditory" +_PROCESSOR = "unzip" @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): # noqa: D103 return _download_mne_dataset( - name=_NAME, processor=_PROCESSOR, path=path, - force_update=force_update, update_path=update_path, - download=download) + name=_NAME, + processor=_PROCESSOR, + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) data_path.__doc__ = _data_path_doc.format( name=_NAME, - conf=f'MNE_DATASETS_{_NAME.upper()}_PATH', + conf=f"MNE_DATASETS_{_NAME.upper()}_PATH", ) diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index 1fba832abb0..32ff152cd5e 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -24,8 +24,16 @@ from .config import _hcp_mmp_license_text, MNE_DATASETS from ..label import read_labels_from_annot, Label, write_labels_to_annot -from ..utils import (get_config, set_config, logger, _validate_type, - verbose, get_subjects_dir, _pl, _safe_input) +from ..utils import ( + get_config, + set_config, + logger, + _validate_type, + verbose, + get_subjects_dir, + _pl, + _safe_input, +) from ..utils.docs import docdict, _docformat @@ -58,10 +66,10 @@ path : instance of Path Path to {name} dataset directory. 
""" -_data_path_doc_accept = _data_path_doc.split('%(verbose)s') -_data_path_doc_accept[-1] = '%(verbose)s' + _data_path_doc_accept[-1] -_data_path_doc_accept.insert(1, ' %(accept)s') -_data_path_doc_accept = ''.join(_data_path_doc_accept) +_data_path_doc_accept = _data_path_doc.split("%(verbose)s") +_data_path_doc_accept[-1] = "%(verbose)s" + _data_path_doc_accept[-1] +_data_path_doc_accept.insert(1, " %(accept)s") +_data_path_doc_accept = "".join(_data_path_doc_accept) _data_path_doc = _docformat(_data_path_doc, docdict) _data_path_doc_accept = _docformat(_data_path_doc_accept, docdict) @@ -77,69 +85,73 @@ def _dataset_version(path, name): """Get the version of the dataset.""" - ver_fname = op.join(path, 'version.txt') + ver_fname = op.join(path, "version.txt") if op.exists(ver_fname): - with open(ver_fname, 'r') as fid: + with open(ver_fname, "r") as fid: version = fid.readline().strip() # version is on first line else: - logger.debug(f'Version file missing: {ver_fname}') + logger.debug(f"Version file missing: {ver_fname}") # Sample dataset versioning was introduced after 0.3 # SPM dataset was introduced with 0.7 - versions = dict(sample='0.7', spm='0.3') - version = versions.get(name, '0.0') + versions = dict(sample="0.7", spm="0.3") + version = versions.get(name, "0.0") return version def _get_path(path, key, name): """Get a dataset path.""" # 1. Input - _validate_type(path, ('path-like', None), path) + _validate_type(path, ("path-like", None), path) if path is not None: return path # 2. get_config(key) — unless key is None or "" (special get_config values) # 3. get_config('MNE_DATA') - path = get_config(key or 'MNE_DATA', get_config('MNE_DATA')) + path = get_config(key or "MNE_DATA", get_config("MNE_DATA")) if path is not None: path = Path(path).expanduser() if not path.exists(): - msg = (f"Download location {path} as specified by MNE_DATA does " - f"not exist. Either create this directory manually and try " - f"again, or set MNE_DATA to an existing directory.") + msg = ( + f"Download location {path} as specified by MNE_DATA does " + f"not exist. Either create this directory manually and try " + f"again, or set MNE_DATA to an existing directory." + ) raise FileNotFoundError(msg) return path # 4. ~/mne_data (but use a fake home during testing so we don't # unnecessarily create ~/mne_data) - logger.info('Using default location ~/mne_data for %s...' % name) - path = op.join(os.getenv('_MNE_FAKE_HOME_DIR', - op.expanduser("~")), 'mne_data') + logger.info("Using default location ~/mne_data for %s..." 
% name) + path = op.join(os.getenv("_MNE_FAKE_HOME_DIR", op.expanduser("~")), "mne_data") if not op.exists(path): - logger.info('Creating ~/mne_data') + logger.info("Creating ~/mne_data") try: os.mkdir(path) except OSError: - raise OSError("User does not have write permissions " - "at '%s', try giving the path as an " - "argument to data_path() where user has " - "write permissions, for ex:data_path" - "('/home/xyz/me2/')" % (path)) + raise OSError( + "User does not have write permissions " + "at '%s', try giving the path as an " + "argument to data_path() where user has " + "write permissions, for ex:data_path" + "('/home/xyz/me2/')" % (path) + ) return Path(path) def _do_path_update(path, update_path, key, name): """Update path.""" path = op.abspath(path) - identical = get_config(key, '', use_env=False) == path + identical = get_config(key, "", use_env=False) == path if not identical: if update_path is None: update_path = True - if '--update-dataset-path' in sys.argv: - answer = 'y' + if "--update-dataset-path" in sys.argv: + answer = "y" else: - msg = ('Do you want to set the path:\n %s\nas the default ' - '%s dataset path in the mne-python config [y]/n? ' - % (path, name)) - answer = _safe_input(msg, alt='pass update_path=True') - if answer.lower() == 'n': + msg = ( + "Do you want to set the path:\n %s\nas the default " + "%s dataset path in the mne-python config [y]/n? " % (path, name) + ) + answer = _safe_input(msg, alt="pass update_path=True") + if answer.lower() == "n": update_path = False if update_path: @@ -149,14 +161,15 @@ def _do_path_update(path, update_path, key, name): # This is meant to be semi-public: let packages like mne-bids use it to make # sure they don't accidentally set download=True in their tests, too -_MODULES_TO_ENSURE_DOWNLOAD_IS_FALSE_IN_TESTS = ('mne',) +_MODULES_TO_ENSURE_DOWNLOAD_IS_FALSE_IN_TESTS = ("mne",) def _check_in_testing_and_raise(name, download): """Check if we're in an MNE test and raise an error if download!=False.""" root_dirs = [ importlib.import_module(ns) - for ns in _MODULES_TO_ENSURE_DOWNLOAD_IS_FALSE_IN_TESTS] + for ns in _MODULES_TO_ENSURE_DOWNLOAD_IS_FALSE_IN_TESTS + ] root_dirs = [str(Path(ns.__file__).parent) for ns in root_dirs] check = False func = None @@ -164,7 +177,7 @@ def _check_in_testing_and_raise(name, download): try: # First, traverse out of the data_path() call while frame: - if frame.f_code.co_name in ('data_path', 'load_data'): + if frame.f_code.co_name in ("data_path", "load_data"): func = frame.f_code.co_name frame = frame.f_back.f_back # out of verbose decorator break @@ -177,10 +190,12 @@ def _check_in_testing_and_raise(name, download): # in mne namespace, and # (can't use is_relative_to here until 3.9) if any(str(fname).startswith(rd) for rd in root_dirs) and ( - # in tests/*.py - fname.parent.stem == 'tests' or - # or in a conftest.py - fname.stem == 'conftest.py'): + # in tests/*.py + fname.parent.stem == "tests" + or + # or in a conftest.py + fname.stem == "conftest.py" + ): check = True break frame = frame.f_back @@ -188,12 +203,14 @@ def _check_in_testing_and_raise(name, download): del frame if check and download is not False: raise RuntimeError( - f'Do not download dataset {repr(name)} in tests, pass ' - f'{func}(download=False) to prevent accidental downloads') + f"Do not download dataset {repr(name)} in tests, pass " + f"{func}(download=False) to prevent accidental downloads" + ) -def _download_mne_dataset(name, processor, path, force_update, - update_path, download, accept=False): +def 
_download_mne_dataset( + name, processor, path, force_update, update_path, download, accept=False +): """Aux function for downloading internal MNE datasets.""" import pooch from mne.datasets._fetch import fetch_dataset @@ -202,33 +219,38 @@ def _download_mne_dataset(name, processor, path, force_update, # import pooch library for handling the dataset downloading dataset_params = MNE_DATASETS[name] - dataset_params['dataset_name'] = name - config_key = MNE_DATASETS[name]['config_key'] - folder_name = MNE_DATASETS[name]['folder_name'] + dataset_params["dataset_name"] = name + config_key = MNE_DATASETS[name]["config_key"] + folder_name = MNE_DATASETS[name]["folder_name"] # get download path for specific dataset path = _get_path(path=path, key=config_key, name=name) # instantiate processor that unzips file - if processor == 'nested_untar': + if processor == "nested_untar": processor_ = pooch.Untar(extract_dir=op.join(path, folder_name)) - elif processor == 'nested_unzip': + elif processor == "nested_unzip": processor_ = pooch.Unzip(extract_dir=op.join(path, folder_name)) else: processor_ = processor # handle case of multiple sub-datasets with different urls - if name == 'visual_92_categories': + if name == "visual_92_categories": dataset_params = [] - for name in ['visual_92_categories_1', 'visual_92_categories_2']: + for name in ["visual_92_categories_1", "visual_92_categories_2"]: this_dataset = MNE_DATASETS[name] - this_dataset['dataset_name'] = name + this_dataset["dataset_name"] = name dataset_params.append(this_dataset) - return fetch_dataset(dataset_params=dataset_params, processor=processor_, - path=path, force_update=force_update, - update_path=update_path, download=download, - accept=accept) + return fetch_dataset( + dataset_params=dataset_params, + processor=processor_, + path=path, + force_update=force_update, + update_path=update_path, + download=download, + accept=accept, + ) def _get_version(name): @@ -238,14 +260,13 @@ def _get_version(name): if not has_dataset(name): return None dataset_params = MNE_DATASETS[name] - dataset_params['dataset_name'] = name - config_key = MNE_DATASETS[name]['config_key'] + dataset_params["dataset_name"] = name + config_key = MNE_DATASETS[name]["config_key"] # get download path for specific dataset path = _get_path(path=None, key=config_key, name=name) - return fetch_dataset(dataset_params, path=path, - return_version=True)[1] + return fetch_dataset(dataset_params, path=path, return_version=True)[1] def has_dataset(name): @@ -268,24 +289,23 @@ def has_dataset(name): from mne.datasets._fetch import fetch_dataset if isinstance(name, dict): - dataset_name = name['dataset_name'] + dataset_name = name["dataset_name"] dataset_params = name else: - dataset_name = 'spm' if name == 'spm_face' else name + dataset_name = "spm" if name == "spm_face" else name dataset_params = MNE_DATASETS[dataset_name] - dataset_params['dataset_name'] = dataset_name + dataset_params["dataset_name"] = dataset_name - config_key = dataset_params['config_key'] + config_key = dataset_params["config_key"] # get download path for specific dataset path = _get_path(path=None, key=config_key, name=dataset_name) - dp = fetch_dataset(dataset_params, path=path, download=False, - check_version=False) - if dataset_name.startswith('bst_'): + dp = fetch_dataset(dataset_params, path=path, download=False, check_version=False) + if dataset_name.startswith("bst_"): check = dataset_name else: - check = MNE_DATASETS[dataset_name]['folder_name'] + check = MNE_DATASETS[dataset_name]["folder_name"] 
return str(dp).endswith(check) @@ -302,51 +322,57 @@ def _download_all_example_data(verbose=True): # verbose=True by default so we get nice status messages. # Consider adding datasets from here to CircleCI for PR-auto-build paths = dict() - for kind in ('sample testing misc spm_face somato hf_sef multimodal ' - 'fnirs_motor opm mtrf fieldtrip_cmc kiloword phantom_4dbti ' - 'refmeg_noise ssvep epilepsy_ecog ucl_opm_auditory eyelink ' - 'erp_core brainstorm.bst_raw brainstorm.bst_auditory ' - 'brainstorm.bst_resting brainstorm.bst_phantom_ctf ' - 'brainstorm.bst_phantom_elekta' - ).split(): - mod = importlib.import_module(f'mne.datasets.{kind}') - data_path_func = getattr(mod, 'data_path') + for kind in ( + "sample testing misc spm_face somato hf_sef multimodal " + "fnirs_motor opm mtrf fieldtrip_cmc kiloword phantom_4dbti " + "refmeg_noise ssvep epilepsy_ecog ucl_opm_auditory eyelink " + "erp_core brainstorm.bst_raw brainstorm.bst_auditory " + "brainstorm.bst_resting brainstorm.bst_phantom_ctf " + "brainstorm.bst_phantom_elekta" + ).split(): + mod = importlib.import_module(f"mne.datasets.{kind}") + data_path_func = getattr(mod, "data_path") kwargs = dict() - if 'accept' in inspect.getfullargspec(data_path_func).args: - kwargs['accept'] = True + if "accept" in inspect.getfullargspec(data_path_func).args: + kwargs["accept"] = True paths[kind] = data_path_func(**kwargs) - logger.info(f'[done {kind}]') + logger.info(f"[done {kind}]") # Now for the exceptions: from . import ( - eegbci, sleep_physionet, limo, fetch_fsaverage, fetch_infant_template, - fetch_hcp_mmp_parcellation, fetch_phantom) + eegbci, + sleep_physionet, + limo, + fetch_fsaverage, + fetch_infant_template, + fetch_hcp_mmp_parcellation, + fetch_phantom, + ) + eegbci.load_data(1, [6, 10, 14], update_path=True) for subj in range(4): eegbci.load_data(subj + 1, runs=[3], update_path=True) - logger.info('[done eegbci]') + logger.info("[done eegbci]") sleep_physionet.age.fetch_data(subjects=[0, 1], recording=[1]) - logger.info('[done sleep_physionet]') + logger.info("[done sleep_physionet]") # If the user has SUBJECTS_DIR, respect it, if not, set it to the EEG one # (probably on CircleCI, or otherwise advanced user) fetch_fsaverage(None) - logger.info('[done fsaverage]') + logger.info("[done fsaverage]") - fetch_infant_template('6mo') - logger.info('[done infant_template]') + fetch_infant_template("6mo") + logger.info("[done infant_template]") - fetch_hcp_mmp_parcellation( - subjects_dir=paths['sample'] / 'subjects', accept=True) - logger.info('[done hcp_mmp_parcellation]') + fetch_hcp_mmp_parcellation(subjects_dir=paths["sample"] / "subjects", accept=True) + logger.info("[done hcp_mmp_parcellation]") - fetch_phantom( - 'otaniemi', subjects_dir=paths['brainstorm.bst_phantom_elekta']) - logger.info('[done phantom]') + fetch_phantom("otaniemi", subjects_dir=paths["brainstorm.bst_phantom_elekta"]) + logger.info("[done phantom]") limo.load_data(subject=1, update_path=True) - logger.info('[done limo]') + logger.info("[done limo]") @verbose @@ -372,13 +398,13 @@ def fetch_aparc_sub_parcellation(subjects_dir=None, verbose=None): subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) destination = subjects_dir / "fsaverage" / "label" - urls = dict(lh='https://osf.io/p92yb/download', - rh='https://osf.io/4kxny/download') - hashes = dict(lh='9e4d8d6b90242b7e4b0145353436ef77', - rh='dd6464db8e7762d969fc1d8087cd211b') + urls = dict(lh="https://osf.io/p92yb/download", rh="https://osf.io/4kxny/download") + hashes = dict( + 
lh="9e4d8d6b90242b7e4b0145353436ef77", rh="dd6464db8e7762d969fc1d8087cd211b" + ) downloader = pooch.HTTPDownloader(**_downloader_params()) - for hemi in ('lh', 'rh'): - fname = f'{hemi}.aparc_sub.annot' + for hemi in ("lh", "rh"): + fname = f"{hemi}.aparc_sub.annot" fpath = destination / fname if not fpath.is_file(): pooch.retrieve( @@ -391,8 +417,9 @@ def fetch_aparc_sub_parcellation(subjects_dir=None, verbose=None): @verbose -def fetch_hcp_mmp_parcellation(subjects_dir=None, combine=True, *, - accept=False, verbose=None): +def fetch_hcp_mmp_parcellation( + subjects_dir=None, combine=True, *, accept=False, verbose=None +): """Fetch the HCP-MMP parcellation. This will download and install the HCP-MMP parcellation @@ -425,20 +452,22 @@ def fetch_hcp_mmp_parcellation(subjects_dir=None, combine=True, *, subjects_dir = get_subjects_dir(subjects_dir, raise_error=True) destination = subjects_dir / "fsaverage" / "label" fnames = [destination / f"{hemi}.HCPMMP1.annot" for hemi in ("lh", "rh")] - urls = dict(lh='https://ndownloader.figshare.com/files/5528816', - rh='https://ndownloader.figshare.com/files/5528819') - hashes = dict(lh='46a102b59b2fb1bb4bd62d51bf02e975', - rh='75e96b331940227bbcb07c1c791c2463') + urls = dict( + lh="https://ndownloader.figshare.com/files/5528816", + rh="https://ndownloader.figshare.com/files/5528819", + ) + hashes = dict( + lh="46a102b59b2fb1bb4bd62d51bf02e975", rh="75e96b331940227bbcb07c1c791c2463" + ) if not all(fname.exists() for fname in fnames): - if accept or '--accept-hcpmmp-license' in sys.argv: - answer = 'y' + if accept or "--accept-hcpmmp-license" in sys.argv: + answer = "y" else: - answer = _safe_input('%s\nAgree (y/[n])? ' % _hcp_mmp_license_text) - if answer.lower() != 'y': - raise RuntimeError('You must agree to the license to use this ' - 'dataset') + answer = _safe_input("%s\nAgree (y/[n])? 
" % _hcp_mmp_license_text) + if answer.lower() != "y": + raise RuntimeError("You must agree to the license to use this " "dataset") downloader = pooch.HTTPDownloader(**_downloader_params()) - for hemi, fpath in zip(('lh', 'rh'), fnames): + for hemi, fpath in zip(("lh", "rh"), fnames): if not op.isfile(fpath): fname = fpath.name pooch.retrieve( @@ -450,82 +479,255 @@ def fetch_hcp_mmp_parcellation(subjects_dir=None, combine=True, *, ) if combine: - fnames = [op.join(destination, '%s.HCPMMP1_combined.annot' % hemi) - for hemi in ('lh', 'rh')] + fnames = [ + op.join(destination, "%s.HCPMMP1_combined.annot" % hemi) + for hemi in ("lh", "rh") + ] if all(op.isfile(fname) for fname in fnames): return # otherwise, let's make them - logger.info('Creating combined labels') - groups = OrderedDict([ - ('Primary Visual Cortex (V1)', - ('V1',)), - ('Early Visual Cortex', - ('V2', 'V3', 'V4')), - ('Dorsal Stream Visual Cortex', - ('V3A', 'V3B', 'V6', 'V6A', 'V7', 'IPS1')), - ('Ventral Stream Visual Cortex', - ('V8', 'VVC', 'PIT', 'FFC', 'VMV1', 'VMV2', 'VMV3')), - ('MT+ Complex and Neighboring Visual Areas', - ('V3CD', 'LO1', 'LO2', 'LO3', 'V4t', 'FST', 'MT', 'MST', 'PH')), - ('Somatosensory and Motor Cortex', - ('4', '3a', '3b', '1', '2')), - ('Paracentral Lobular and Mid Cingulate Cortex', - ('24dd', '24dv', '6mp', '6ma', 'SCEF', '5m', '5L', '5mv',)), - ('Premotor Cortex', - ('55b', '6d', '6a', 'FEF', '6v', '6r', 'PEF')), - ('Posterior Opercular Cortex', - ('43', 'FOP1', 'OP4', 'OP1', 'OP2-3', 'PFcm')), - ('Early Auditory Cortex', - ('A1', 'LBelt', 'MBelt', 'PBelt', 'RI')), - ('Auditory Association Cortex', - ('A4', 'A5', 'STSdp', 'STSda', 'STSvp', 'STSva', 'STGa', 'TA2',)), - ('Insular and Frontal Opercular Cortex', - ('52', 'PI', 'Ig', 'PoI1', 'PoI2', 'FOP2', 'FOP3', - 'MI', 'AVI', 'AAIC', 'Pir', 'FOP4', 'FOP5')), - ('Medial Temporal Cortex', - ('H', 'PreS', 'EC', 'PeEc', 'PHA1', 'PHA2', 'PHA3',)), - ('Lateral Temporal Cortex', - ('PHT', 'TE1p', 'TE1m', 'TE1a', 'TE2p', 'TE2a', - 'TGv', 'TGd', 'TF',)), - ('Temporo-Parieto-Occipital Junction', - ('TPOJ1', 'TPOJ2', 'TPOJ3', 'STV', 'PSL',)), - ('Superior Parietal Cortex', - ('LIPv', 'LIPd', 'VIP', 'AIP', 'MIP', - '7PC', '7AL', '7Am', '7PL', '7Pm',)), - ('Inferior Parietal Cortex', - ('PGp', 'PGs', 'PGi', 'PFm', 'PF', 'PFt', 'PFop', - 'IP0', 'IP1', 'IP2',)), - ('Posterior Cingulate Cortex', - ('DVT', 'ProS', 'POS1', 'POS2', 'RSC', 'v23ab', 'd23ab', - '31pv', '31pd', '31a', '23d', '23c', 'PCV', '7m',)), - ('Anterior Cingulate and Medial Prefrontal Cortex', - ('33pr', 'p24pr', 'a24pr', 'p24', 'a24', 'p32pr', 'a32pr', 'd32', - 'p32', 's32', '8BM', '9m', '10v', '10r', '25',)), - ('Orbital and Polar Frontal Cortex', - ('47s', '47m', 'a47r', '11l', '13l', - 'a10p', 'p10p', '10pp', '10d', 'OFC', 'pOFC',)), - ('Inferior Frontal Cortex', - ('44', '45', 'IFJp', 'IFJa', 'IFSp', 'IFSa', '47l', 'p47r',)), - ('DorsoLateral Prefrontal Cortex', - ('8C', '8Av', 'i6-8', 's6-8', 'SFL', '8BL', '9p', '9a', '8Ad', - 'p9-46v', 'a9-46v', '46', '9-46d',)), - ('???', - ('???',))]) + logger.info("Creating combined labels") + groups = OrderedDict( + [ + ("Primary Visual Cortex (V1)", ("V1",)), + ("Early Visual Cortex", ("V2", "V3", "V4")), + ( + "Dorsal Stream Visual Cortex", + ("V3A", "V3B", "V6", "V6A", "V7", "IPS1"), + ), + ( + "Ventral Stream Visual Cortex", + ("V8", "VVC", "PIT", "FFC", "VMV1", "VMV2", "VMV3"), + ), + ( + "MT+ Complex and Neighboring Visual Areas", + ("V3CD", "LO1", "LO2", "LO3", "V4t", "FST", "MT", "MST", "PH"), + ), + ("Somatosensory and Motor Cortex", ("4", 
"3a", "3b", "1", "2")), + ( + "Paracentral Lobular and Mid Cingulate Cortex", + ( + "24dd", + "24dv", + "6mp", + "6ma", + "SCEF", + "5m", + "5L", + "5mv", + ), + ), + ("Premotor Cortex", ("55b", "6d", "6a", "FEF", "6v", "6r", "PEF")), + ( + "Posterior Opercular Cortex", + ("43", "FOP1", "OP4", "OP1", "OP2-3", "PFcm"), + ), + ("Early Auditory Cortex", ("A1", "LBelt", "MBelt", "PBelt", "RI")), + ( + "Auditory Association Cortex", + ( + "A4", + "A5", + "STSdp", + "STSda", + "STSvp", + "STSva", + "STGa", + "TA2", + ), + ), + ( + "Insular and Frontal Opercular Cortex", + ( + "52", + "PI", + "Ig", + "PoI1", + "PoI2", + "FOP2", + "FOP3", + "MI", + "AVI", + "AAIC", + "Pir", + "FOP4", + "FOP5", + ), + ), + ( + "Medial Temporal Cortex", + ( + "H", + "PreS", + "EC", + "PeEc", + "PHA1", + "PHA2", + "PHA3", + ), + ), + ( + "Lateral Temporal Cortex", + ( + "PHT", + "TE1p", + "TE1m", + "TE1a", + "TE2p", + "TE2a", + "TGv", + "TGd", + "TF", + ), + ), + ( + "Temporo-Parieto-Occipital Junction", + ( + "TPOJ1", + "TPOJ2", + "TPOJ3", + "STV", + "PSL", + ), + ), + ( + "Superior Parietal Cortex", + ( + "LIPv", + "LIPd", + "VIP", + "AIP", + "MIP", + "7PC", + "7AL", + "7Am", + "7PL", + "7Pm", + ), + ), + ( + "Inferior Parietal Cortex", + ( + "PGp", + "PGs", + "PGi", + "PFm", + "PF", + "PFt", + "PFop", + "IP0", + "IP1", + "IP2", + ), + ), + ( + "Posterior Cingulate Cortex", + ( + "DVT", + "ProS", + "POS1", + "POS2", + "RSC", + "v23ab", + "d23ab", + "31pv", + "31pd", + "31a", + "23d", + "23c", + "PCV", + "7m", + ), + ), + ( + "Anterior Cingulate and Medial Prefrontal Cortex", + ( + "33pr", + "p24pr", + "a24pr", + "p24", + "a24", + "p32pr", + "a32pr", + "d32", + "p32", + "s32", + "8BM", + "9m", + "10v", + "10r", + "25", + ), + ), + ( + "Orbital and Polar Frontal Cortex", + ( + "47s", + "47m", + "a47r", + "11l", + "13l", + "a10p", + "p10p", + "10pp", + "10d", + "OFC", + "pOFC", + ), + ), + ( + "Inferior Frontal Cortex", + ( + "44", + "45", + "IFJp", + "IFJa", + "IFSp", + "IFSa", + "47l", + "p47r", + ), + ), + ( + "DorsoLateral Prefrontal Cortex", + ( + "8C", + "8Av", + "i6-8", + "s6-8", + "SFL", + "8BL", + "9p", + "9a", + "8Ad", + "p9-46v", + "a9-46v", + "46", + "9-46d", + ), + ), + ("???", ("???",)), + ] + ) assert len(groups) == 23 labels_out = list() - for hemi in ('lh', 'rh'): - labels = read_labels_from_annot('fsaverage', 'HCPMMP1', hemi=hemi, - subjects_dir=subjects_dir, - sort=False) + for hemi in ("lh", "rh"): + labels = read_labels_from_annot( + "fsaverage", "HCPMMP1", hemi=hemi, subjects_dir=subjects_dir, sort=False + ) label_names = [ - '???' if label.name.startswith('???') else - label.name.split('_')[1] for label in labels] + "???" 
if label.name.startswith("???") else label.name.split("_")[1] + for label in labels + ] used = np.zeros(len(labels), bool) for key, want in groups.items(): - assert '\t' not in key - these_labels = [li for li, label_name in enumerate(label_names) - if label_name in want] + assert "\t" not in key + these_labels = [ + li + for li, label_name in enumerate(label_names) + if label_name in want + ] assert not used[these_labels].any() assert len(these_labels) == len(want) used[these_labels] = True @@ -535,38 +737,47 @@ def fetch_hcp_mmp_parcellation(subjects_dir=None, combine=True, *, w = np.array([len(label.vertices) for label in these_labels]) w = w / float(w.sum()) color = np.dot(w, [label.color for label in these_labels]) - these_labels = sum(these_labels, - Label([], subject='fsaverage', hemi=hemi)) + these_labels = sum( + these_labels, Label([], subject="fsaverage", hemi=hemi) + ) these_labels.name = key these_labels.color = color labels_out.append(these_labels) assert used.all() assert len(labels_out) == 46 - for hemi, side in (('lh', 'left'), ('rh', 'right')): - table_name = './%s.fsaverage164.label.gii' % (side,) - write_labels_to_annot(labels_out, 'fsaverage', 'HCPMMP1_combined', - hemi=hemi, subjects_dir=subjects_dir, - sort=False, table_name=table_name) + for hemi, side in (("lh", "left"), ("rh", "right")): + table_name = "./%s.fsaverage164.label.gii" % (side,) + write_labels_to_annot( + labels_out, + "fsaverage", + "HCPMMP1_combined", + hemi=hemi, + subjects_dir=subjects_dir, + sort=False, + table_name=table_name, + ) def _manifest_check_download(manifest_path, destination, url, hash_): import pooch - with open(manifest_path, 'r') as fid: + with open(manifest_path, "r") as fid: names = [name.strip() for name in fid.readlines()] manifest_path = op.basename(manifest_path) need = list() for name in names: if not op.isfile(op.join(destination, name)): need.append(name) - logger.info('%d file%s missing from %s in %s' - % (len(need), _pl(need), manifest_path, destination)) + logger.info( + "%d file%s missing from %s in %s" + % (len(need), _pl(need), manifest_path, destination) + ) if len(need) > 0: downloader = pooch.HTTPDownloader(**_downloader_params()) with tempfile.TemporaryDirectory() as path: - logger.info('Downloading missing files remotely') + logger.info("Downloading missing files remotely") - fname_path = op.join(path, 'temp.zip') + fname_path = op.join(path, "temp.zip") pooch.retrieve( url=url, known_hash=f"md5:{hash_}", @@ -575,36 +786,36 @@ def _manifest_check_download(manifest_path, destination, url, hash_): fname=op.basename(fname_path), ) - logger.info('Extracting missing file%s' % (_pl(need),)) - with zipfile.ZipFile(fname_path, 'r') as ff: - members = set(f for f in ff.namelist() if not f.endswith('/')) + logger.info("Extracting missing file%s" % (_pl(need),)) + with zipfile.ZipFile(fname_path, "r") as ff: + members = set(f for f in ff.namelist() if not f.endswith("/")) missing = sorted(members.symmetric_difference(set(names))) if len(missing): - raise RuntimeError('Zip file did not have correct names:' - '\n%s' % ('\n'.join(missing))) + raise RuntimeError( + "Zip file did not have correct names:" + "\n%s" % ("\n".join(missing)) + ) for name in need: ff.extract(name, path=destination) - logger.info('Successfully extracted %d file%s' - % (len(need), _pl(need))) + logger.info("Successfully extracted %d file%s" % (len(need), _pl(need))) def _log_time_size(t0, sz): t = time.time() - t0 - fmt = '%Ss' + fmt = "%Ss" if t > 60: - fmt = f'%Mm{fmt}' + fmt = f"%Mm{fmt}" if t > 3600: 
- fmt = f'%Hh{fmt}' + fmt = f"%Hh{fmt}" sz = sz / 1048576 # 1024 ** 2 t = time.strftime(fmt, time.gmtime(t)) - logger.info(f'Download complete in {t} ({sz:.1f} MB)') + logger.info(f"Download complete in {t} ({sz:.1f} MB)") def _downloader_params(*, auth=None, token=None): params = dict() - params['progressbar'] = ( - logger.level <= logging.INFO and - get_config('MNE_TQDM', 'tqdm.auto') != 'off' + params["progressbar"] = ( + logger.level <= logging.INFO and get_config("MNE_TQDM", "tqdm.auto") != "off" ) if auth is not None: params["auth"] = auth diff --git a/mne/datasets/visual_92_categories/visual_92_categories.py b/mne/datasets/visual_92_categories/visual_92_categories.py index df687aafb6c..d5fb1c1c8bb 100644 --- a/mne/datasets/visual_92_categories/visual_92_categories.py +++ b/mne/datasets/visual_92_categories/visual_92_categories.py @@ -1,13 +1,13 @@ # License: BSD Style. from ...utils import verbose -from ..utils import (_download_mne_dataset, _data_path_doc, _get_version, - _version_doc) +from ..utils import _download_mne_dataset, _data_path_doc, _get_version, _version_doc @verbose -def data_path(path=None, force_update=False, update_path=True, - download=True, *, verbose=None): +def data_path( + path=None, force_update=False, update_path=True, download=True, *, verbose=None +): """ Get path to local copy of visual_92_categories dataset. @@ -43,18 +43,23 @@ def data_path(path=None, force_update=False, update_path=True, human object recognition in space and time. doi: 10.1038/NN.3635 """ return _download_mne_dataset( - name='visual_92_categories', processor='untar', path=path, - force_update=force_update, update_path=update_path, - download=download) + name="visual_92_categories", + processor="untar", + path=path, + force_update=force_update, + update_path=update_path, + download=download, + ) data_path.__doc__ = _data_path_doc.format( - name='visual_92_categories', conf='MNE_DATASETS_VISUAL_92_CATEGORIES_PATH') + name="visual_92_categories", conf="MNE_DATASETS_VISUAL_92_CATEGORIES_PATH" +) def get_version(): """Get dataset version.""" - return _get_version('visual_92_categories') + return _get_version("visual_92_categories") -get_version.__doc__ = _version_doc.format(name='visual_92_categories') +get_version.__doc__ = _version_doc.format(name="visual_92_categories") diff --git a/mne/decoding/__init__.py b/mne/decoding/__init__.py index 2b0136256b6..099e3c0dd30 100644 --- a/mne/decoding/__init__.py +++ b/mne/decoding/__init__.py @@ -1,8 +1,13 @@ """Decoding and encoding, including machine learning and receptive fields.""" -from .transformer import (PSDEstimator, Vectorizer, - UnsupervisedSpatialFilter, TemporalFilter, - Scaler, FilterEstimator) +from .transformer import ( + PSDEstimator, + Vectorizer, + UnsupervisedSpatialFilter, + TemporalFilter, + Scaler, + FilterEstimator, +) from .mixin import TransformerMixin from .base import BaseEstimator, LinearModel, get_coef, cross_val_multiscore from .csp import CSP, SPoC diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 9d8070b8179..348ee2ee0f7 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -54,7 +54,8 @@ class LinearModel(BaseEstimator): def __init__(self, model=None): # noqa: D102 if model is None: from sklearn.linear_model import LogisticRegression - model = LogisticRegression(solver='liblinear') + + model = LogisticRegression(solver="liblinear") self.model = model self._estimator_type = getattr(model, "_estimator_type", None) @@ -81,18 +82,22 @@ def fit(self, X, y, **fit_params): """ X, y = np.asarray(X), 
np.asarray(y) if X.ndim != 2: - raise ValueError('LinearModel only accepts 2-dimensional X, got ' - '%s instead.' % (X.shape,)) + raise ValueError( + "LinearModel only accepts 2-dimensional X, got " + "%s instead." % (X.shape,) + ) if y.ndim > 2: - raise ValueError('LinearModel only accepts up to 2-dimensional y, ' - 'got %s instead.' % (y.shape,)) + raise ValueError( + "LinearModel only accepts up to 2-dimensional y, " + "got %s instead." % (y.shape,) + ) # fit the Model self.model.fit(X, y, **fit_params) # Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y - inv_Y = 1. + inv_Y = 1.0 X = X - X.mean(0, keepdims=True) if y.ndim == 2 and y.shape[1] != 1: y = y - y.mean(0, keepdims=True) @@ -103,14 +108,14 @@ def fit(self, X, y, **fit_params): @property def filters_(self): - if hasattr(self.model, 'coef_'): + if hasattr(self.model, "coef_"): # Standard Linear Model filters = self.model.coef_ - elif hasattr(self.model.best_estimator_, 'coef_'): + elif hasattr(self.model.best_estimator_, "coef_"): # Linear Model with GridSearchCV filters = self.model.best_estimator_.coef_ else: - raise ValueError('model does not have a `coef_` attribute.') + raise ValueError("model does not have a `coef_` attribute.") if filters.ndim == 2 and filters.shape[0] == 1: filters = filters[0] return filters @@ -213,60 +218,62 @@ def score(self, X, y): def _set_cv(cv, estimator=None, X=None, y=None): """Set the default CV depending on whether clf is classifier/regressor.""" # Detect whether classification or regression - if estimator in ['classifier', 'regressor']: - est_is_classifier = estimator == 'classifier' + if estimator in ["classifier", "regressor"]: + est_is_classifier = estimator == "classifier" else: est_is_classifier = is_classifier(estimator) # Setup CV from sklearn import model_selection as models - from sklearn.model_selection import (check_cv, StratifiedKFold, KFold) + from sklearn.model_selection import check_cv, StratifiedKFold, KFold + if isinstance(cv, (int, np.int64)): XFold = StratifiedKFold if est_is_classifier else KFold cv = XFold(n_splits=cv) elif isinstance(cv, str): if not hasattr(models, cv): - raise ValueError('Unknown cross-validation') + raise ValueError("Unknown cross-validation") cv = getattr(models, cv) cv = cv() cv = check_cv(cv=cv, y=y, classifier=est_is_classifier) # Extract train and test set to retrieve them at predict time - cv_splits = [(train, test) for train, test in - cv.split(X=np.zeros_like(y), y=y)] + cv_splits = [(train, test) for train, test in cv.split(X=np.zeros_like(y), y=y)] if not np.all([len(train) for train, _ in cv_splits]): - raise ValueError('Some folds do not have any train epochs.') + raise ValueError("Some folds do not have any train epochs.") return cv, cv_splits def _check_estimator(estimator, get_params=True): """Check whether an object has the methods required by sklearn.""" - valid_methods = ('predict', 'transform', 'predict_proba', - 'decision_function') - if ( - (not hasattr(estimator, 'fit')) or - (not any(hasattr(estimator, method) for method in valid_methods)) + valid_methods = ("predict", "transform", "predict_proba", "decision_function") + if (not hasattr(estimator, "fit")) or ( + not any(hasattr(estimator, method) for method in valid_methods) ): - raise ValueError('estimator must be a scikit-learn transformer or ' - 'an estimator with the fit and a predict-like (e.g. 
' - 'predict_proba) or a transform method.') + raise ValueError( + "estimator must be a scikit-learn transformer or " + "an estimator with the fit and a predict-like (e.g. " + "predict_proba) or a transform method." + ) - if get_params and not hasattr(estimator, 'get_params'): - raise ValueError('estimator must be a scikit-learn transformer or an ' - 'estimator with the get_params method that allows ' - 'cloning.') + if get_params and not hasattr(estimator, "get_params"): + raise ValueError( + "estimator must be a scikit-learn transformer or an " + "estimator with the get_params method that allows " + "cloning." + ) def _get_inverse_funcs(estimator, terminal=True): """Retrieve the inverse functions of an pipeline or an estimator.""" inverse_func = [False] - if hasattr(estimator, 'steps'): + if hasattr(estimator, "steps"): # if pipeline, retrieve all steps by nesting inverse_func = list() for _, est in estimator.steps: inverse_func.extend(_get_inverse_funcs(est, terminal=False)) - elif hasattr(estimator, 'inverse_transform'): + elif hasattr(estimator, "inverse_transform"): # if not pipeline attempt to retrieve inverse function inverse_func = [estimator.inverse_transform] @@ -284,7 +291,7 @@ def _get_inverse_funcs(estimator, terminal=True): return inverse_func -def get_coef(estimator, attr='filters_', inverse_transform=False): +def get_coef(estimator, attr="filters_", inverse_transform=False): """Retrieve the coefficients of an estimator ending with a Linear Model. This is typically useful to retrieve "spatial filters" or "spatial @@ -312,13 +319,13 @@ def get_coef(estimator, attr='filters_', inverse_transform=False): """ # Get the coefficients of the last estimator in case of nested pipeline est = estimator - while hasattr(est, 'steps'): + while hasattr(est, "steps"): est = est.steps[-1][1] squeeze_first_dim = False # If SlidingEstimator, loop across estimators - if hasattr(est, 'estimators_'): + if hasattr(est, "estimators_"): coef = list() for this_est in est.estimators_: coef.append(get_coef(this_est, attr, inverse_transform)) @@ -326,8 +333,9 @@ def get_coef(estimator, attr='filters_', inverse_transform=False): coef = coef[np.newaxis] # fake a sample dimension squeeze_first_dim = True elif not hasattr(est, attr): - raise ValueError('This estimator does not have a %s attribute:\n%s' - % (attr, est)) + raise ValueError( + "This estimator does not have a %s attribute:\n%s" % (attr, est) + ) else: coef = getattr(est, attr) @@ -337,9 +345,10 @@ def get_coef(estimator, attr='filters_', inverse_transform=False): # inverse pattern e.g. to get back physical units if inverse_transform: - if not hasattr(estimator, 'steps') and not hasattr(est, 'estimators_'): - raise ValueError('inverse_transform can only be applied onto ' - 'pipeline estimators.') + if not hasattr(estimator, "steps") and not hasattr(est, "estimators_"): + raise ValueError( + "inverse_transform can only be applied onto " "pipeline estimators." + ) # The inverse_transform parameter will call this method on any # estimator contained in the pipeline, in reverse order. 
for inverse_func in _get_inverse_funcs(estimator)[::-1]: @@ -352,9 +361,18 @@ def get_coef(estimator, attr='filters_', inverse_transform=False): @verbose -def cross_val_multiscore(estimator, X, y=None, groups=None, scoring=None, - cv=None, n_jobs=None, verbose=None, fit_params=None, - pre_dispatch='2*n_jobs'): +def cross_val_multiscore( + estimator, + X, + y=None, + groups=None, + scoring=None, + cv=None, + n_jobs=None, + verbose=None, + fit_params=None, + pre_dispatch="2*n_jobs", +): """Evaluate a score by cross-validation. Parameters @@ -420,6 +438,7 @@ def cross_val_multiscore(estimator, X, y=None, groups=None, scoring=None, from sklearn.base import clone from sklearn.utils import indexable from sklearn.model_selection._split import check_cv + check_scoring = _get_check_scoring() X, y, groups = indexable(X, y, groups) @@ -430,15 +449,23 @@ def cross_val_multiscore(estimator, X, y=None, groups=None, scoring=None, # We clone the estimator to make sure that all the folds are # independent, and that it is pickle-able. # Note: this parallelization is implemented using MNE Parallel - parallel, p_func, n_jobs = parallel_func(_fit_and_score, n_jobs, - pre_dispatch=pre_dispatch) - position = hasattr(estimator, 'position') + parallel, p_func, n_jobs = parallel_func( + _fit_and_score, n_jobs, pre_dispatch=pre_dispatch + ) + position = hasattr(estimator, "position") scores = parallel( p_func( - estimator=clone(estimator), X=X, y=y, scorer=scorer, train=train, - test=test, fit_params=fit_params, verbose=verbose, + estimator=clone(estimator), + X=X, + y=y, + scorer=scorer, + train=train, + test=test, + fit_params=fit_params, + verbose=verbose, parameters=dict(position=ii % n_jobs) if position else None, - ) for ii, (train, test) in enumerate(cv_iter) + ) + for ii, (train, test) in enumerate(cv_iter) ) return np.array(scores)[:, 0, ...] # flatten over joblib output. @@ -446,11 +473,24 @@ def cross_val_multiscore(estimator, X, y=None, groups=None, scoring=None, # This verbose is necessary to properly set the verbosity level # during parallelization @verbose -def _fit_and_score(estimator, X, y, scorer, train, test, - parameters, fit_params, return_train_score=False, - return_parameters=False, return_n_test_samples=False, - return_times=False, error_score='raise', *, verbose=None, - position=0): +def _fit_and_score( + estimator, + X, + y, + scorer, + train, + test, + parameters, + fit_params, + return_train_score=False, + return_parameters=False, + return_n_test_samples=False, + return_times=False, + error_score="raise", + *, + verbose=None, + position=0 +): """Fit estimator and compute scores for a given dataset split.""" # This code is adapted from sklearn from ..fixes import _check_fit_params @@ -479,19 +519,23 @@ def _fit_and_score(estimator, X, y, scorer, train, test, # Note fit time as time until error fit_duration = dt.datetime.now() - start_time score_duration = dt.timedelta(0) - if error_score == 'raise': + if error_score == "raise": raise elif isinstance(error_score, numbers.Number): test_score = error_score if return_train_score: train_score = error_score - warn("Classifier fit failed. The score on this train-test" - " partition for these parameters will be set to %f. " - "Details: \n%r" % (error_score, e)) + warn( + "Classifier fit failed. The score on this train-test" + " partition for these parameters will be set to %f. " + "Details: \n%r" % (error_score, e) + ) else: - raise ValueError("error_score must be the string 'raise' or a" - " numeric value. 
(Hint: if using 'raise', please" - " make sure that it has been spelled correctly.)") + raise ValueError( + "error_score must be the string 'raise' or a" + " numeric value. (Hint: if using 'raise', please" + " make sure that it has been spelled correctly.)" + ) else: fit_duration = dt.datetime.now() - start_time @@ -505,10 +549,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, if return_n_test_samples: ret.append(_num_samples(X_test)) if return_times: - ret.extend([ - fit_duration.total_seconds(), - score_duration.total_seconds() - ]) + ret.extend([fit_duration.total_seconds(), score_duration.total_seconds()]) if return_parameters: ret.append(parameters) return ret @@ -524,7 +565,7 @@ def _score(estimator, X_test, y_test, scorer): score = scorer(estimator, X_test) else: score = scorer(estimator, X_test, y_test) - if hasattr(score, 'item'): + if hasattr(score, "item"): try: # e.g. unwrap memmapped scalars score = score.item() diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 6e3ed67c163..f4c74ad6d91 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -13,8 +13,7 @@ from .base import BaseEstimator from .mixin import TransformerMixin from ..cov import _regularized_covariance -from ..defaults import (_BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, - _INTERPOLATION_DEFAULT) +from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ..fixes import pinv from ..utils import fill_doc, _check_option, _validate_type, copy_doc @@ -97,13 +96,21 @@ class CSP(TransformerMixin, BaseEstimator): .. footbibliography:: """ - def __init__(self, n_components=4, reg=None, log=None, cov_est='concat', - transform_into='average_power', norm_trace=False, - cov_method_params=None, rank=None, - component_order='mutual_info'): + def __init__( + self, + n_components=4, + reg=None, + log=None, + cov_est="concat", + transform_into="average_power", + norm_trace=False, + cov_method_params=None, + rank=None, + component_order="mutual_info", + ): # Init default CSP if not isinstance(n_components, int): - raise ValueError('n_components must be an integer.') + raise ValueError("n_components must be an integer.") self.n_components = n_components self.rank = rank self.reg = reg @@ -114,37 +121,39 @@ def __init__(self, n_components=4, reg=None, log=None, cov_est='concat', self.cov_est = cov_est # Init default transform_into - self.transform_into = _check_option('transform_into', transform_into, - ['average_power', 'csp_space']) + self.transform_into = _check_option( + "transform_into", transform_into, ["average_power", "csp_space"] + ) # Init default log - if transform_into == 'average_power': + if transform_into == "average_power": if log is not None and not isinstance(log, bool): - raise ValueError('log must be a boolean if transform_into == ' - '"average_power".') + raise ValueError( + "log must be a boolean if transform_into == " '"average_power".' + ) else: if log is not None: - raise ValueError('log must be a None if transform_into == ' - '"csp_space".') + raise ValueError( + "log must be a None if transform_into == " '"csp_space".' 
+ ) self.log = log - _validate_type(norm_trace, bool, 'norm_trace') + _validate_type(norm_trace, bool, "norm_trace") self.norm_trace = norm_trace self.cov_method_params = cov_method_params - self.component_order = _check_option('component_order', - component_order, - ('mutual_info', 'alternate')) + self.component_order = _check_option( + "component_order", component_order, ("mutual_info", "alternate") + ) def _check_Xy(self, X, y=None): """Check input data.""" if not isinstance(X, np.ndarray): - raise ValueError("X should be of type ndarray (got %s)." - % type(X)) + raise ValueError("X should be of type ndarray (got %s)." % type(X)) if y is not None: if len(X) != len(y) or len(y) < 1: - raise ValueError('X and y must have the same length.') + raise ValueError("X and y must have the same length.") if X.ndim < 3: - raise ValueError('X must have at least 3 dimensions.') + raise ValueError("X must have at least 3 dimensions.") def fit(self, X, y): """Estimate the CSP decomposition on epochs. @@ -167,28 +176,30 @@ def fit(self, X, y): n_classes = len(self._classes) if n_classes < 2: raise ValueError("n_classes must be >= 2.") - if n_classes > 2 and self.component_order == 'alternate': - raise ValueError("component_order='alternate' requires two " - "classes, but data contains {} classes; use " - "component_order='mutual_info' " - "instead.".format(n_classes)) + if n_classes > 2 and self.component_order == "alternate": + raise ValueError( + "component_order='alternate' requires two " + "classes, but data contains {} classes; use " + "component_order='mutual_info' " + "instead.".format(n_classes) + ) covs, sample_weights = self._compute_covariance_matrices(X, y) - eigen_vectors, eigen_values = self._decompose_covs(covs, - sample_weights) - ix = self._order_components(covs, sample_weights, eigen_vectors, - eigen_values, self.component_order) + eigen_vectors, eigen_values = self._decompose_covs(covs, sample_weights) + ix = self._order_components( + covs, sample_weights, eigen_vectors, eigen_values, self.component_order + ) eigen_vectors = eigen_vectors[:, ix] self.filters_ = eigen_vectors.T self.patterns_ = pinv(eigen_vectors) - pick_filters = self.filters_[:self.n_components] + pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) # compute features (mean power) - X = (X ** 2).mean(axis=2) + X = (X**2).mean(axis=2) # To standardize features self.mean_ = X.mean(axis=0) @@ -215,15 +226,16 @@ def transform(self, X): if not isinstance(X, np.ndarray): raise ValueError("X should be of type ndarray (got %s)." % type(X)) if self.filters_ is None: - raise RuntimeError('No filters available. Please first fit CSP ' - 'decomposition.') + raise RuntimeError( + "No filters available. Please first fit CSP " "decomposition." 
+ ) - pick_filters = self.filters_[:self.n_components] + pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) # compute features (mean band power) - if self.transform_into == 'average_power': - X = (X ** 2).mean(axis=2) + if self.transform_into == "average_power": + X = (X**2).mean(axis=2) log = True if self.log is None else self.log if log: X = np.log(X) @@ -238,14 +250,37 @@ def fit_transform(self, X, y, **fit_params): # noqa: D102 @fill_doc def plot_patterns( - self, info, components=None, *, average=None, ch_type=None, - scalings=None, sensors=True, show_names=False, mask=None, - mask_params=None, contours=6, outlines='head', sphere=None, - image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, - size=1, cmap='RdBu_r', vlim=(None, None), cnorm=None, - colorbar=True, cbar_fmt='%3.1f', units=None, axes=None, - name_format='CSP%01d', nrows=1, ncols='auto', show=True): + self, + info, + components=None, + *, + average=None, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + name_format="CSP%01d", + nrows=1, + ncols="auto", + show=True + ): """Plot topographic patterns of components. The patterns explain how the measured data was generated from the @@ -304,38 +339,81 @@ def plot_patterns( from .. import EvokedArray if units is None: - units = 'AU' + units = "AU" if components is None: components = np.arange(self.n_components) # set sampling frequency to have 1 component per time point info = cp.deepcopy(info) with info._unlock(): - info['sfreq'] = 1. 
+ info["sfreq"] = 1.0 # create an evoked patterns = EvokedArray(self.patterns_.T, info, tmin=0) # the call plot_topomap fig = patterns.plot_topomap( - times=components, average=average, ch_type=ch_type, - scalings=scalings, sensors=sensors, show_names=show_names, - mask=mask, mask_params=mask_params, contours=contours, - outlines=outlines, sphere=sphere, image_interp=image_interp, - extrapolate=extrapolate, border=border, res=res, size=size, - cmap=cmap, vlim=vlim, cnorm=cnorm, colorbar=colorbar, - cbar_fmt=cbar_fmt, units=units, axes=axes, time_format=name_format, - nrows=nrows, ncols=ncols, show=show) + times=components, + average=average, + ch_type=ch_type, + scalings=scalings, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + time_format=name_format, + nrows=nrows, + ncols=ncols, + show=show, + ) return fig @fill_doc def plot_filters( - self, info, components=None, *, average=None, ch_type=None, - scalings=None, sensors=True, show_names=False, mask=None, - mask_params=None, contours=6, outlines='head', sphere=None, - image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, - size=1, cmap='RdBu_r', vlim=(None, None), cnorm=None, - colorbar=True, cbar_fmt='%3.1f', units=None, axes=None, - name_format='CSP%01d', nrows=1, ncols='auto', show=True): + self, + info, + components=None, + *, + average=None, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + name_format="CSP%01d", + nrows=1, + ncols="auto", + show=True + ): """Plot topographic filters of components. The filters are used to extract discriminant neural sources from @@ -394,26 +472,46 @@ def plot_filters( from .. import EvokedArray if units is None: - units = 'AU' + units = "AU" if components is None: components = np.arange(self.n_components) # set sampling frequency to have 1 component per time point info = cp.deepcopy(info) with info._unlock(): - info['sfreq'] = 1. 
+ info["sfreq"] = 1.0 # create an evoked filters = EvokedArray(self.filters_.T, info, tmin=0) # the call plot_topomap fig = filters.plot_topomap( - times=components, average=average, ch_type=ch_type, - scalings=scalings, sensors=sensors, show_names=show_names, - mask=mask, mask_params=mask_params, contours=contours, - outlines=outlines, sphere=sphere, image_interp=image_interp, - extrapolate=extrapolate, border=border, res=res, size=size, - cmap=cmap, vlim=vlim, cnorm=cnorm, colorbar=colorbar, - cbar_fmt=cbar_fmt, units=units, axes=axes, time_format=name_format, - nrows=nrows, ncols=ncols, show=show) + times=components, + average=average, + ch_type=ch_type, + scalings=scalings, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + time_format=name_format, + nrows=nrows, + ncols=ncols, + show=show, + ) return fig def _compute_covariance_matrices(self, X, y): @@ -444,18 +542,23 @@ def _concat_cov(self, x_class): x_class = np.transpose(x_class, [1, 0, 2]) x_class = x_class.reshape(n_channels, -1) cov = _regularized_covariance( - x_class, reg=self.reg, method_params=self.cov_method_params, - rank=self.rank) + x_class, reg=self.reg, method_params=self.cov_method_params, rank=self.rank + ) weight = x_class.shape[0] return cov, weight def _epoch_cov(self, x_class): """Mean of per-epoch covariances.""" - cov = sum(_regularized_covariance( - this_X, reg=self.reg, - method_params=self.cov_method_params, - rank=self.rank) for this_X in x_class) + cov = sum( + _regularized_covariance( + this_X, + reg=self.reg, + method_params=self.cov_method_params, + rank=self.rank, + ) + for this_X in x_class + ) cov /= len(x_class) weight = len(x_class) @@ -463,6 +566,7 @@ def _epoch_cov(self, x_class): def _decompose_covs(self, covs, sample_weights): from scipy import linalg + n_classes = len(covs) if n_classes == 2: eigen_values, eigen_vectors = linalg.eigh(covs[0], covs.sum(0)) @@ -470,8 +574,9 @@ def _decompose_covs(self, covs, sample_weights): # The multiclass case is adapted from # http://github.com/alexandrebarachant/pyRiemann eigen_vectors, D = _ajd_pham(covs) - eigen_vectors = self._normalize_eigenvectors(eigen_vectors.T, covs, - sample_weights) + eigen_vectors = self._normalize_eigenvectors( + eigen_vectors.T, covs, sample_weights + ) eigen_values = None return eigen_vectors, eigen_values @@ -481,12 +586,11 @@ def _compute_mutual_info(self, covs, sample_weights, eigen_vectors): mutual_info = [] for jj in range(eigen_vectors.shape[1]): aa, bb = 0, 0 - for (cov, prob) in zip(covs, class_probas): - tmp = np.dot(np.dot(eigen_vectors[:, jj].T, cov), - eigen_vectors[:, jj]) + for cov, prob in zip(covs, class_probas): + tmp = np.dot(np.dot(eigen_vectors[:, jj].T, cov), eigen_vectors[:, jj]) aa += prob * np.log(np.sqrt(tmp)) - bb += prob * (tmp ** 2 - 1) - mi = - (aa + (3.0 / 16) * (bb ** 2)) + bb += prob * (tmp**2 - 1) + mi = -(aa + (3.0 / 16) * (bb**2)) mutual_info.append(mi) return mutual_info @@ -496,25 +600,24 @@ def _normalize_eigenvectors(self, eigen_vectors, covs, sample_weights): mean_cov = np.average(covs, axis=0, weights=sample_weights) for ii in range(eigen_vectors.shape[1]): - tmp = np.dot(np.dot(eigen_vectors[:, ii].T, mean_cov), - eigen_vectors[:, ii]) + tmp = np.dot(np.dot(eigen_vectors[:, ii].T, 
mean_cov), eigen_vectors[:, ii]) eigen_vectors[:, ii] /= np.sqrt(tmp) return eigen_vectors - def _order_components(self, covs, sample_weights, eigen_vectors, - eigen_values, component_order): + def _order_components( + self, covs, sample_weights, eigen_vectors, eigen_values, component_order + ): n_classes = len(self._classes) - if component_order == 'mutual_info' and n_classes > 2: - mutual_info = self._compute_mutual_info(covs, sample_weights, - eigen_vectors) + if component_order == "mutual_info" and n_classes > 2: + mutual_info = self._compute_mutual_info(covs, sample_weights, eigen_vectors) ix = np.argsort(mutual_info)[::-1] - elif component_order == 'mutual_info' and n_classes == 2: + elif component_order == "mutual_info" and n_classes == 2: ix = np.argsort(np.abs(eigen_values - 0.5))[::-1] - elif component_order == 'alternate' and n_classes == 2: + elif component_order == "alternate" and n_classes == 2: i = np.argsort(eigen_values) ix = np.empty_like(i) - ix[1::2] = i[:len(i) // 2] - ix[0::2] = i[len(i) // 2:][::-1] + ix[1::2] = i[: len(i) // 2] + ix[0::2] = i[len(i) // 2 :][::-1] return ix @@ -583,16 +686,16 @@ def _ajd_pham(X, eps=1e-6, max_iter=15): decr += n_epochs * (g12 * np.conj(h12) + g21 * h21) / 2.0 - tmp = 1 + 1.j * 0.5 * np.imag(h12 * h21) - tmp = np.real(tmp + np.sqrt(tmp ** 2 - h12 * h21)) + tmp = 1 + 1.0j * 0.5 * np.imag(h12 * h21) + tmp = np.real(tmp + np.sqrt(tmp**2 - h12 * h21)) tau = np.array([[1, -h12 / tmp], [-h21 / tmp, 1]]) A[[ii, jj], :] = np.dot(tau, A[[ii, jj], :]) tmp = np.c_[A[:, Ii], A[:, Ij]] - tmp = np.reshape(tmp, (n_times * n_epochs, 2), order='F') + tmp = np.reshape(tmp, (n_times * n_epochs, 2), order="F") tmp = np.dot(tmp, tau.T) - tmp = np.reshape(tmp, (n_times, n_epochs * 2), order='F') + tmp = np.reshape(tmp, (n_times, n_epochs * 2), order="F") A[:, Ii] = tmp[:, :n_epochs] A[:, Ij] = tmp[:, n_epochs:] V[[ii, jj], :] = np.dot(tau, V[[ii, jj], :]) @@ -663,19 +766,31 @@ class SPoC(CSP): .. footbibliography:: """ - def __init__(self, n_components=4, reg=None, log=None, - transform_into='average_power', cov_method_params=None, - rank=None): + def __init__( + self, + n_components=4, + reg=None, + log=None, + transform_into="average_power", + cov_method_params=None, + rank=None, + ): """Init of SPoC.""" - super(SPoC, self).__init__(n_components=n_components, reg=reg, log=log, - cov_est="epoch", norm_trace=False, - transform_into=transform_into, rank=rank, - cov_method_params=cov_method_params) + super(SPoC, self).__init__( + n_components=n_components, + reg=reg, + log=log, + cov_est="epoch", + norm_trace=False, + transform_into=transform_into, + rank=rank, + cov_method_params=cov_method_params, + ) # Covariance estimation have to be done on the single epoch level, # unlike CSP where covariance estimation can also be achieved through # concatenation of all epochs from the same class. - delattr(self, 'cov_est') - delattr(self, 'norm_trace') + delattr(self, "cov_est") + delattr(self, "norm_trace") def fit(self, X, y): """Estimate the SPoC decomposition on epochs. @@ -693,6 +808,7 @@ def fit(self, X, y): Returns the modified instance. 
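For orientation, a minimal fit/transform sketch of this estimator on random data (the shapes and the `reg="oas"` shrinkage choice are illustrative assumptions):

    import numpy as np
    from mne.decoding import SPoC

    rng = np.random.RandomState(42)
    X = rng.randn(50, 8, 200)  # 50 epochs, 8 channels, 200 time points
    y = rng.randn(50)          # continuous target, e.g. muscular activity

    spoc = SPoC(n_components=2, log=True, reg="oas")
    features = spoc.fit_transform(X, y)  # (50, 2) log band-power features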
""" from scipy import linalg + self._check_Xy(X, y) if len(np.unique(y)) < 2: @@ -711,8 +827,11 @@ def fit(self, X, y): covs = np.empty((n_epochs, n_channels, n_channels)) for ii, epoch in enumerate(X): covs[ii] = _regularized_covariance( - epoch, reg=self.reg, method_params=self.cov_method_params, - rank=self.rank) + epoch, + reg=self.reg, + method_params=self.cov_method_params, + rank=self.rank, + ) C = covs.mean(0) Cz = np.mean(covs * target[:, np.newaxis, np.newaxis], axis=0) @@ -731,11 +850,11 @@ def fit(self, X, y): self.patterns_ = linalg.pinv(evecs).T # n_channels x n_channels self.filters_ = evecs # n_channels x n_channels - pick_filters = self.filters_[:self.n_components] + pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) # compute features (mean band power) - X = (X ** 2).mean(axis=-1) + X = (X**2).mean(axis=-1) # To standardize features self.mean_ = X.mean(axis=0) diff --git a/mne/decoding/ems.py b/mne/decoding/ems.py index 3f125bfb74a..f0dabe4d681 100644 --- a/mne/decoding/ems.py +++ b/mne/decoding/ems.py @@ -41,11 +41,13 @@ class EMS(TransformerMixin, EstimatorMixin): """ def __repr__(self): # noqa: D105 - if hasattr(self, 'filters_'): - return '' % ( - len(self.filters_), len(self.classes_)) + if hasattr(self, "filters_"): + return "" % ( + len(self.filters_), + len(self.classes_), + ) else: - return '' + return "" def fit(self, X, y): """Fit the spatial filters. @@ -67,7 +69,7 @@ def fit(self, X, y): """ classes = np.unique(y) if len(classes) != 2: - raise ValueError('EMS only works for binary classification.') + raise ValueError("EMS only works for binary classification.") self.classes_ = classes filters = X[y == classes[0]].mean(0) - X[y == classes[1]].mean(0) filters /= np.linalg.norm(filters, axis=0)[None, :] @@ -92,8 +94,9 @@ def transform(self, X): @verbose -def compute_ems(epochs, conditions=None, picks=None, n_jobs=None, cv=None, - verbose=None): +def compute_ems( + epochs, conditions=None, picks=None, n_jobs=None, cv=None, verbose=None +): """Compute event-matched spatial filter on epochs. This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire @@ -141,16 +144,18 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=None, cv=None, ---------- .. footbibliography:: """ - logger.info('...computing surrogate time series. This can take some time') + logger.info("...computing surrogate time series. This can take some time") # Default to leave-one-out cv - cv = 'LeaveOneOut' if cv is None else cv + cv = "LeaveOneOut" if cv is None else cv picks = _picks_to_idx(epochs.info, picks) if not len(set(Counter(epochs.events[:, 2]).values())) == 1: - raise ValueError('The same number of epochs is required by ' - 'this function. Please consider ' - '`epochs.equalize_event_counts`') + raise ValueError( + "The same number of epochs is required by " + "this function. 
Please consider " + "`epochs.equalize_event_counts`" + ) if conditions is None: conditions = epochs.event_id.keys() @@ -161,9 +166,10 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=None, cv=None, epochs.drop_bad() if len(conditions) != 2: - raise ValueError('Currently this function expects exactly 2 ' - 'conditions but you gave me %i' % - len(conditions)) + raise ValueError( + "Currently this function expects exactly 2 " + "conditions but you gave me %i" % len(conditions) + ) ev = epochs.events[:, 2] # Special care to avoid path dependent mappings and orders @@ -175,10 +181,10 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=None, cv=None, # Scale (z-score) the data by channel type # XXX the z-scoring is applied outside the CV, which is not standard. - for ch_type in ['mag', 'grad', 'eeg']: + for ch_type in ["mag", "grad", "eeg"]: if ch_type in epochs: # FIXME should be applied to all sort of data channels - if ch_type == 'eeg': + if ch_type == "eeg": this_picks = pick_types(info, meg=False, eeg=True) else: this_picks = pick_types(info, meg=ch_type, eeg=False) @@ -187,15 +193,16 @@ def compute_ems(epochs, conditions=None, picks=None, n_jobs=None, cv=None, # Setup cross-validation. Need to use _set_cv to deal with sklearn # deprecation of cv objects. y = epochs.events[:, 2] - _, cv_splits = _set_cv(cv, 'classifier', X=y, y=y) + _, cv_splits = _set_cv(cv, "classifier", X=y, y=y) parallel, p_func, n_jobs = parallel_func(_run_ems, n_jobs=n_jobs) # FIXME this parallelization should be removed. # 1) it's numpy computation so it's already efficient, # 2) it duplicates the data in RAM, # 3) the computation is already super fast. - out = parallel(p_func(_ems_diff, data, cond_idx, train, test) - for train, test in cv_splits) + out = parallel( + p_func(_ems_diff, data, cond_idx, train, test) for train, test in cv_splits + ) surrogate_trials, spatial_filter = zip(*out) surrogate_trials = np.array(surrogate_trials) @@ -212,6 +219,6 @@ def _ems_diff(data0, data1): def _run_ems(objective_function, data, cond_idx, train, test): """Run EMS.""" d = objective_function(*(data[np.intersect1d(c, train)] for c in cond_idx)) - d /= np.sqrt(np.sum(d ** 2, axis=0))[None, :] + d /= np.sqrt(np.sum(d**2, axis=0))[None, :] # compute surrogates return np.sum(data[test[0]] * d, axis=0), d diff --git a/mne/decoding/mixin.py b/mne/decoding/mixin.py index c000ae4b74d..d009e0a23ba 100644 --- a/mne/decoding/mixin.py +++ b/mne/decoding/mixin.py @@ -61,23 +61,26 @@ def set_params(self, **params): return self valid_params = self.get_params(deep=True) for key, value in params.items(): - split = key.split('__', 1) + split = key.split("__", 1) if len(split) > 1: # nested objects case name, sub_name = split if name not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (name, self)) + raise ValueError( + "Invalid parameter %s for estimator %s. " + "Check the list of available parameters " + "with `estimator.get_params().keys()`." % (name, self) + ) sub_object = valid_params[name] sub_object.set_params(**{sub_name: value}) else: # simple objects case if key not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (key, self.__class__.__name__)) + raise ValueError( + "Invalid parameter %s for estimator %s. 
" + "Check the list of available parameters " + "with `estimator.get_params().keys()`." + % (key, self.__class__.__name__) + ) setattr(self, key, value) return self diff --git a/mne/decoding/receptive_field.py b/mne/decoding/receptive_field.py index cf6e6dd35bc..6fa38a4f72f 100644 --- a/mne/decoding/receptive_field.py +++ b/mne/decoding/receptive_field.py @@ -103,14 +103,25 @@ class ReceptiveField(BaseEstimator): """ # noqa E501 @verbose - def __init__(self, tmin, tmax, sfreq, feature_names=None, estimator=None, - fit_intercept=None, scoring='r2', patterns=False, - n_jobs=None, edge_correction=True, verbose=None): + def __init__( + self, + tmin, + tmax, + sfreq, + feature_names=None, + estimator=None, + fit_intercept=None, + scoring="r2", + patterns=False, + n_jobs=None, + edge_correction=True, + verbose=None, + ): self.feature_names = feature_names self.sfreq = float(sfreq) self.tmin = tmin self.tmax = tmax - self.estimator = 0. if estimator is None else estimator + self.estimator = 0.0 if estimator is None else estimator self.fit_intercept = fit_intercept self.scoring = scoring self.patterns = patterns @@ -123,7 +134,7 @@ def __repr__(self): # noqa: D105 if not isinstance(estimator, str): estimator = type(self.estimator) s += "estimator : %s, " % (estimator,) - if hasattr(self, 'coef_'): + if hasattr(self, "coef_"): if self.feature_names is not None: feats = self.feature_names if len(feats) == 1: @@ -133,7 +144,7 @@ def __repr__(self): # noqa: D105 s += "fit: True" else: s += "fit: False" - if hasattr(self, 'scores_'): + if hasattr(self, "scores_"): s += "scored (%s)" % self.scoring return "" % s @@ -141,12 +152,13 @@ def _delay_and_reshape(self, X, y=None): """Delay and reshape the variables.""" if not isinstance(self.estimator_, TimeDelayingRidge): # X is now shape (n_times, n_epochs, n_feats, n_delays) - X = _delay_time_series(X, self.tmin, self.tmax, self.sfreq, - fill_mean=self.fit_intercept) + X = _delay_time_series( + X, self.tmin, self.tmax, self.sfreq, fill_mean=self.fit_intercept + ) X = _reshape_for_est(X) # Concat times + epochs if y is not None: - y = y.reshape(-1, y.shape[-1], order='F') + y = y.reshape(-1, y.shape[-1], order="F") return X, y def fit(self, X, y): @@ -165,15 +177,20 @@ def fit(self, X, y): The instance so you can chain operations. 
""" from scipy import linalg + if self.scoring not in _SCORERS.keys(): - raise ValueError('scoring must be one of %s, got' - '%s ' % (sorted(_SCORERS.keys()), self.scoring)) + raise ValueError( + "scoring must be one of %s, got" + "%s " % (sorted(_SCORERS.keys()), self.scoring) + ) from sklearn.base import clone + X, y, _, self._y_dim = self._check_dimensions(X, y) if self.tmin > self.tmax: - raise ValueError('tmin (%s) must be at most tmax (%s)' - % (self.tmin, self.tmax)) + raise ValueError( + "tmin (%s) must be at most tmax (%s)" % (self.tmin, self.tmax) + ) # Initialize delays self.delays_ = _times_to_delays(self.tmin, self.tmax, self.sfreq) @@ -184,23 +201,33 @@ def fit(self, X, y): if self.fit_intercept is None: self.fit_intercept = True estimator = TimeDelayingRidge( - self.tmin, self.tmax, self.sfreq, alpha=self.estimator, - fit_intercept=self.fit_intercept, n_jobs=self.n_jobs, - edge_correction=self.edge_correction) + self.tmin, + self.tmax, + self.sfreq, + alpha=self.estimator, + fit_intercept=self.fit_intercept, + n_jobs=self.n_jobs, + edge_correction=self.edge_correction, + ) elif is_regressor(self.estimator): estimator = clone(self.estimator) - if self.fit_intercept is not None and \ - estimator.fit_intercept != self.fit_intercept: + if ( + self.fit_intercept is not None + and estimator.fit_intercept != self.fit_intercept + ): raise ValueError( - 'Estimator fit_intercept (%s) != initialization ' - 'fit_intercept (%s), initialize ReceptiveField with the ' - 'same fit_intercept value or use fit_intercept=None' - % (estimator.fit_intercept, self.fit_intercept)) + "Estimator fit_intercept (%s) != initialization " + "fit_intercept (%s), initialize ReceptiveField with the " + "same fit_intercept value or use fit_intercept=None" + % (estimator.fit_intercept, self.fit_intercept) + ) self.fit_intercept = estimator.fit_intercept else: - raise ValueError('`estimator` must be a float or an instance' - ' of `BaseEstimator`,' - ' got type %s.' % type(self.estimator)) + raise ValueError( + "`estimator` must be a float or an instance" + " of `BaseEstimator`," + " got type %s." % type(self.estimator) + ) self.estimator_ = estimator del estimator _check_estimator(self.estimator_) @@ -211,16 +238,17 @@ def fit(self, X, y): n_delays = len(self.delays_) # Update feature names if we have none - if ((self.feature_names is not None) and - (len(self.feature_names) != n_feats)): - raise ValueError('n_features in X does not match feature names ' - '(%s != %s)' % (n_feats, len(self.feature_names))) + if (self.feature_names is not None) and (len(self.feature_names) != n_feats): + raise ValueError( + "n_features in X does not match feature names " + "(%s != %s)" % (n_feats, len(self.feature_names)) + ) # Create input features X, y = self._delay_and_reshape(X, y) self.estimator_.fit(X, y) - coef = get_coef(self.estimator_, 'coef_') # (n_targets, n_features) + coef = get_coef(self.estimator_, "coef_") # (n_targets, n_features) shape = [n_feats, n_delays] if self._y_dim > 1: shape.insert(0, -1) @@ -230,7 +258,7 @@ def fit(self, X, y): if self.patterns: if isinstance(self.estimator_, TimeDelayingRidge): cov_ = self.estimator_.cov_ / float(n_times * n_epochs - 1) - y = y.reshape(-1, y.shape[-1], order='F') + y = y.reshape(-1, y.shape[-1], order="F") else: X = X - X.mean(0, keepdims=True) cov_ = np.cov(X.T) @@ -241,7 +269,7 @@ def fit(self, X, y): y = y - y.mean(0, keepdims=True) inv_Y = linalg.pinv(np.cov(y.T)) else: - inv_Y = 1. 
/ float(n_times * n_epochs - 1) + inv_Y = 1.0 / float(n_times * n_epochs - 1) del y # Inverse coef according to Haufe's method @@ -267,8 +295,8 @@ def predict(self, X): unaffected by edge artifacts during the time delaying step) can be obtained using ``y_pred[rf.valid_samples_]``. """ - if not hasattr(self, 'delays_'): - raise ValueError('Estimator has not been fit yet.') + if not hasattr(self, "delays_"): + raise ValueError("Estimator has not been fit yet.") X, _, X_dim = self._check_dimensions(X, None, predict=True)[:3] del _ # convert to sklearn and back @@ -277,14 +305,14 @@ def predict(self, X): pred_shape = pred_shape + (self.coef_.shape[0],) X, _ = self._delay_and_reshape(X) y_pred = self.estimator_.predict(X) - y_pred = y_pred.reshape(pred_shape, order='F') + y_pred = y_pred.reshape(pred_shape, order="F") shape = list(y_pred.shape) if X_dim <= 2: shape.pop(1) # epochs extra = 0 else: extra = 1 - shape = shape[:self._y_dim + extra] + shape = shape[: self._y_dim + extra] y_pred.shape = shape return y_pred @@ -319,10 +347,10 @@ def score(self, X, y): y = y[self.valid_samples_] # Re-vectorize and call scorer - y = y.reshape([-1, n_outputs], order='F') - y_pred = y_pred.reshape([-1, n_outputs], order='F') + y = y.reshape([-1, n_outputs], order="F") + y_pred = y_pred.reshape([-1, n_outputs], order="F") assert y.shape == y_pred.shape - scores = scorer_(y, y_pred, multioutput='raw_values') + scores = scorer_(y, y_pred, multioutput="raw_values") return scores def _check_dimensions(self, X, y, predict=False): @@ -337,28 +365,39 @@ def _check_dimensions(self, X, y, predict=False): elif y_dim == 2: y = y[:, np.newaxis, :] # epochs else: - raise ValueError('y must be shape (n_times[, n_epochs]' - '[,n_outputs], got %s' % (y.shape,)) + raise ValueError( + "y must be shape (n_times[, n_epochs]" + "[,n_outputs], got %s" % (y.shape,) + ) elif X.ndim == 3: if y is not None: if y.ndim == 2: y = y[:, :, np.newaxis] # Add an outputs dim elif y.ndim != 3: - raise ValueError('If X has 3 dimensions, ' - 'y must have 2 or 3 dimensions') + raise ValueError( + "If X has 3 dimensions, " "y must have 2 or 3 dimensions" + ) else: - raise ValueError('X must be shape (n_times[, n_epochs],' - ' n_features), got %s' % (X.shape,)) + raise ValueError( + "X must be shape (n_times[, n_epochs]," + " n_features), got %s" % (X.shape,) + ) if y is not None: if X.shape[0] != y.shape[0]: - raise ValueError('X and y do not have the same n_times\n' - '%s != %s' % (X.shape[0], y.shape[0])) + raise ValueError( + "X and y do not have the same n_times\n" + "%s != %s" % (X.shape[0], y.shape[0]) + ) if X.shape[1] != y.shape[1]: - raise ValueError('X and y do not have the same n_epochs\n' - '%s != %s' % (X.shape[1], y.shape[1])) + raise ValueError( + "X and y do not have the same n_epochs\n" + "%s != %s" % (X.shape[1], y.shape[1]) + ) if predict and y.shape[-1] != len(self.estimator_.coef_): - raise ValueError('Number of outputs does not match' - ' estimator coefficients dimensions') + raise ValueError( + "Number of outputs does not match" + " estimator coefficients dimensions" + ) return X, y, X_dim, y_dim @@ -423,15 +462,14 @@ def _delay_time_series(X, tmin, tmax, sfreq, fill_mean=False): use_X = X out[:] = use_X if fill_mean: - out[:] += (mean_value - use_X.mean(axis=0)) + out[:] += mean_value - use_X.mean(axis=0) return delayed def _times_to_delays(tmin, tmax, sfreq): """Convert a tmin/tmax in seconds to delays.""" # Convert seconds to samples - delays = np.arange(int(np.round(tmin * sfreq)), - int(np.round(tmax * sfreq) + 1)) + 
delays = np.arange(int(np.round(tmin * sfreq)), int(np.round(tmax * sfreq) + 1)) return delays @@ -446,37 +484,39 @@ def _delays_to_slice(delays): def _check_delayer_params(tmin, tmax, sfreq): """Check delayer input parameters. For future custom delay support.""" - _validate_type(sfreq, 'numeric', '`sfreq`') + _validate_type(sfreq, "numeric", "`sfreq`") for tlim in (tmin, tmax): - _validate_type(tlim, 'numeric', 'tmin/tmax') + _validate_type(tlim, "numeric", "tmin/tmax") if not tmin <= tmax: - raise ValueError('tmin must be <= tmax') + raise ValueError("tmin must be <= tmax") def _reshape_for_est(X_del): """Convert X_del to a sklearn-compatible shape.""" n_times, n_epochs, n_feats, n_delays = X_del.shape X_del = X_del.reshape(n_times, n_epochs, -1) # concatenate feats - X_del = X_del.reshape(n_times * n_epochs, -1, order='F') + X_del = X_del.reshape(n_times * n_epochs, -1, order="F") return X_del # Create a correlation scikit-learn-style scorer def _corr_score(y_true, y, multioutput=None): from scipy.stats import pearsonr - assert multioutput == 'raw_values' + + assert multioutput == "raw_values" for this_y in (y_true, y): if this_y.ndim != 2: - raise ValueError('inputs must be shape (samples, outputs), got %s' - % (this_y.shape,)) - return np.array([pearsonr(y_true[:, ii], y[:, ii])[0] - for ii in range(y.shape[-1])]) + raise ValueError( + "inputs must be shape (samples, outputs), got %s" % (this_y.shape,) + ) + return np.array([pearsonr(y_true[:, ii], y[:, ii])[0] for ii in range(y.shape[-1])]) def _r2_score(y_true, y, multioutput=None): from sklearn.metrics import r2_score + return r2_score(y_true, y, multioutput=multioutput) -_SCORERS = {'r2': _r2_score, 'corrcoef': _corr_score} +_SCORERS = {"r2": _r2_score, "corrcoef": _corr_score} diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 81c83b256a4..f2671b7ea11 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -8,8 +8,7 @@ from .base import BaseEstimator, _check_estimator from ..fixes import _get_check_scoring from ..parallel import parallel_func -from ..utils import (array_split_idx, ProgressBar, - verbose, fill_doc, _parse_verbose) +from ..utils import array_split_idx, ProgressBar, verbose, fill_doc, _parse_verbose @fill_doc @@ -35,8 +34,9 @@ class SlidingEstimator(BaseEstimator, TransformerMixin): """ @verbose - def __init__(self, base_estimator, scoring=None, n_jobs=None, *, - position=0, verbose=None): # noqa: D102 + def __init__( + self, base_estimator, scoring=None, n_jobs=None, *, position=0, verbose=None + ): # noqa: D102 _check_estimator(base_estimator) self._estimator_type = getattr(base_estimator, "_estimator_type", None) self.base_estimator = base_estimator @@ -46,11 +46,11 @@ def __init__(self, base_estimator, scoring=None, n_jobs=None, *, self.verbose = verbose def __repr__(self): # noqa: D105 - repr_str = '<' + super(SlidingEstimator, self).__repr__() - if hasattr(self, 'estimators_'): + repr_str = "<" + super(SlidingEstimator, self).__repr__() + if hasattr(self, "estimators_"): repr_str = repr_str[:-1] - repr_str += ', fitted with %i estimators' % len(self.estimators_) - return repr_str + '>' + repr_str += ", fitted with %i estimators" % len(self.estimators_) + return repr_str + ">" def fit(self, X, y, **fit_params): """Fit a series of independent estimators to the dataset. 
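A minimal sketch of the decoding-over-time pattern this class implements (the toy shapes and the logistic-regression base estimator are assumptions):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler
    from mne.decoding import SlidingEstimator, cross_val_multiscore

    rng = np.random.RandomState(0)
    X = rng.randn(60, 5, 40)   # epochs x channels x time points
    y = rng.randint(0, 2, 60)  # binary labels

    clf = make_pipeline(StandardScaler(), LogisticRegression())
    slider = SlidingEstimator(clf, scoring="roc_auc", n_jobs=1)
    scores = cross_val_multiscore(slider, X, y, cv=3)  # (3, 40)
    mean_auc = scores.mean(axis=0)  # one AUC per time point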
@@ -74,16 +74,16 @@ def fit(self, X, y, **fit_params): """ self._check_Xy(X, y) parallel, p_func, n_jobs = parallel_func( - _sl_fit, self.n_jobs, max_jobs=X.shape[-1], verbose=False) + _sl_fit, self.n_jobs, max_jobs=X.shape[-1], verbose=False + ) self.estimators_ = list() self.fit_params = fit_params # For fitting, the parallelization is across estimators. - context = _create_progressbar_context(self, X, 'Fitting') + context = _create_progressbar_context(self, X, "Fitting") with context as pb: estimators = parallel( - p_func(self.base_estimator, split, y, - pb.subset(pb_idx), **fit_params) + p_func(self.base_estimator, split, y, pb.subset(pb_idx), **fit_params) for pb_idx, split in array_split_idx(X, n_jobs, axis=-1) ) @@ -126,17 +126,17 @@ def _transform(self, X, method): self._check_Xy(X) method = _check_method(self.base_estimator, method) if X.shape[-1] != len(self.estimators_): - raise ValueError('The number of estimators does not match ' - 'X.shape[-1]') + raise ValueError("The number of estimators does not match " "X.shape[-1]") # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _sl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False) + _sl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False + ) X_splits = np.array_split(X, n_jobs, axis=-1) idx, est_splits = zip(*array_split_idx(self.estimators_, n_jobs)) - context = _create_progressbar_context(self, X, 'Transforming') + context = _create_progressbar_context(self, X, "Transforming") with context as pb: y_pred = parallel( p_func(est, x, method, pb.subset(pb_idx)) @@ -166,7 +166,7 @@ def transform(self, X): Xt : array, shape (n_samples, n_estimators) The transformed values generated by each estimator. """ # noqa: E501 - return self._transform(X, 'transform') + return self._transform(X, "transform") def predict(self, X): """Predict each data slice/task with a series of independent estimators. @@ -188,7 +188,7 @@ def predict(self, X): y_pred : array, shape (n_samples, n_estimators) | (n_samples, n_tasks, n_targets) Predicted values for each estimator/data slice. """ # noqa: E501 - return self._transform(X, 'predict') + return self._transform(X, "predict") def predict_proba(self, X): """Predict each data slice with a series of independent estimators. @@ -210,7 +210,7 @@ def predict_proba(self, X): y_pred : array, shape (n_samples, n_tasks, n_classes) Predicted probabilities for each estimator/data slice/task. """ # noqa: E501 - return self._transform(X, 'predict_proba') + return self._transform(X, "predict_proba") def decision_function(self, X): """Estimate distances of each data slice to the hyperplanes. @@ -233,15 +233,15 @@ def decision_function(self, X): ----- This requires base_estimator to have a ``decision_function`` method. """ # noqa: E501 - return self._transform(X, 'decision_function') + return self._transform(X, "decision_function") def _check_Xy(self, X, y=None): """Aux. function to check input data.""" if y is not None: if len(X) != len(y) or len(y) < 1: - raise ValueError('X and y must have the same length.') + raise ValueError("X and y must have the same length.") if X.ndim < 3: - raise ValueError('X must have at least 3 dimensions.') + raise ValueError("X must have at least 3 dimensions.") def score(self, X, y): """Score each estimator on each task. 
@@ -270,8 +270,7 @@ def score(self, X, y): self._check_Xy(X) if X.shape[-1] != len(self.estimators_): - raise ValueError('The number of estimators does not match ' - 'X.shape[-1]') + raise ValueError("The number of estimators does not match " "X.shape[-1]") scoring = check_scoring(self.base_estimator, self.scoring) y = _fix_auc(scoring, y) @@ -279,21 +278,25 @@ def score(self, X, y): # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _sl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False) + _sl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False + ) X_splits = np.array_split(X, n_jobs, axis=-1) est_splits = np.array_split(self.estimators_, n_jobs) - score = parallel(p_func(est, scoring, x, y) - for (est, x) in zip(est_splits, X_splits)) + score = parallel( + p_func(est, scoring, x, y) for (est, x) in zip(est_splits, X_splits) + ) score = np.concatenate(score, axis=0) return score @property def classes_(self): - if not hasattr(self.estimators_[0], 'classes_'): - raise AttributeError('classes_ attribute available only if ' - 'base_estimator has it, and estimator %s does' - ' not' % (self.estimators_[0],)) + if not hasattr(self.estimators_[0], "classes_"): + raise AttributeError( + "classes_ attribute available only if " + "base_estimator has it, and estimator %s does" + " not" % (self.estimators_[0],) + ) return self.estimators_[0].classes_ @@ -322,6 +325,7 @@ def _sl_fit(estimator, X, y, pb, **fit_params): The fitted estimators. """ from sklearn.base import clone + estimators_ = list() for ii in range(X.shape[-1]): est = clone(estimator) @@ -410,10 +414,10 @@ def _check_method(estimator, method): If method == 'transform' and estimator does not have 'transform', use 'predict' instead. """ - if method == 'transform' and not hasattr(estimator, 'transform'): - method = 'predict' + if method == "transform" and not hasattr(estimator, "transform"): + method = "predict" if not hasattr(estimator, method): - ValueError('base_estimator does not have `%s` method.' % method) + ValueError("base_estimator does not have `%s` method." % method) return method @@ -435,9 +439,9 @@ class GeneralizingEstimator(SlidingEstimator): def __repr__(self): # noqa: D105 repr_str = super(GeneralizingEstimator, self).__repr__() - if hasattr(self, 'estimators_'): + if hasattr(self, "estimators_"): repr_str = repr_str[:-1] - repr_str += ', fitted with %i estimators>' % len(self.estimators_) + repr_str += ", fitted with %i estimators>" % len(self.estimators_) return repr_str def _transform(self, X, method): @@ -446,14 +450,16 @@ def _transform(self, X, method): method = _check_method(self.base_estimator, method) parallel, p_func, n_jobs = parallel_func( - _gl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False) + _gl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False + ) - context = _create_progressbar_context(self, X, 'Transforming') + context = _create_progressbar_context(self, X, "Transforming") with context as pb: y_pred = parallel( p_func(self.estimators_, x_split, method, pb.subset(pb_idx)) for pb_idx, x_split in array_split_idx( - X, n_jobs, axis=-1, n_per_split=len(self.estimators_)) + X, n_jobs, axis=-1, n_per_split=len(self.estimators_) + ) ) y_pred = np.concatenate(y_pred, axis=2) @@ -475,7 +481,7 @@ def transform(self, X): Xt : array, shape (n_samples, n_estimators, n_slices) The transformed values generated by each estimator. 
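The temporal-generalization variant scores every train-time/test-time pair rather than the diagonal only; a sketch under the same kind of toy assumptions:

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from mne.decoding import GeneralizingEstimator

    rng = np.random.RandomState(0)
    X = rng.randn(60, 5, 40)   # epochs x channels x time points
    y = rng.randint(0, 2, 60)

    gen = GeneralizingEstimator(LogisticRegression(), scoring="roc_auc",
                                n_jobs=1)
    gen.fit(X[:40], y[:40])
    scores = gen.score(X[40:], y[40:])  # (40, 40): train time x test time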
""" - return self._transform(X, 'transform') + return self._transform(X, "transform") def predict(self, X): """Predict each data slice with all possible estimators. @@ -493,7 +499,7 @@ def predict(self, X): y_pred : array, shape (n_samples, n_estimators, n_slices) | (n_samples, n_estimators, n_slices, n_targets) The predicted values for each estimator. """ # noqa: E501 - return self._transform(X, 'predict') + return self._transform(X, "predict") def predict_proba(self, X): """Estimate probabilistic estimates of each data slice with all possible estimators. @@ -515,7 +521,7 @@ def predict_proba(self, X): ----- This requires ``base_estimator`` to have a ``predict_proba`` method. """ # noqa: E501 - return self._transform(X, 'predict_proba') + return self._transform(X, "predict_proba") def decision_function(self, X): """Estimate distances of each data slice to all hyperplanes. @@ -539,7 +545,7 @@ def decision_function(self, X): This requires ``base_estimator`` to have a ``decision_function`` method. """ # noqa: E501 - return self._transform(X, 'decision_function') + return self._transform(X, "decision_function") def score(self, X, y): """Score each of the estimators on the tested dimensions. @@ -565,16 +571,18 @@ def score(self, X, y): # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _gl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False) + _gl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False + ) scoring = check_scoring(self.base_estimator, self.scoring) y = _fix_auc(scoring, y) - context = _create_progressbar_context(self, X, 'Scoring') + context = _create_progressbar_context(self, X, "Scoring") with context as pb: score = parallel( p_func(self.estimators_, scoring, x, y, pb.subset(pb_idx)) for pb_idx, x in array_split_idx( - X, n_jobs, axis=-1, n_per_split=len(self.estimators_)) + X, n_jobs, axis=-1, n_per_split=len(self.estimators_) + ) ) score = np.concatenate(score, axis=1) @@ -628,8 +636,7 @@ def _gl_init_pred(y_pred, X, n_train): """Aux. function to GeneralizingEstimator to initialize y_pred.""" n_sample, n_iter = X.shape[0], X.shape[-1] if y_pred.ndim == 3: - y_pred = np.zeros((n_sample, n_train, n_iter, y_pred.shape[-1]), - y_pred.dtype) + y_pred = np.zeros((n_sample, n_train, n_iter, y_pred.shape[-1]), y_pred.dtype) else: y_pred = np.zeros((n_sample, n_train, n_iter), y_pred.dtype) return y_pred @@ -679,30 +686,34 @@ def _gl_score(estimators, scoring, X, y, pb): def _fix_auc(scoring, y): from sklearn.preprocessing import LabelEncoder + # This fixes sklearn's inability to compute roc_auc when y not in [0, 1] # scikit-learn/scikit-learn#6874 if scoring is not None: - score_func = getattr(scoring, '_score_func', None) - kwargs = getattr(scoring, '_kwargs', {}) - if (getattr(score_func, '__name__', '') == 'roc_auc_score' and - kwargs.get('multi_class', 'raise') == 'raise'): + score_func = getattr(scoring, "_score_func", None) + kwargs = getattr(scoring, "_kwargs", {}) + if ( + getattr(score_func, "__name__", "") == "roc_auc_score" + and kwargs.get("multi_class", "raise") == "raise" + ): if np.ndim(y) != 1 or len(set(y)) != 2: - raise ValueError('roc_auc scoring can only be computed for ' - 'two-class problems.') + raise ValueError( + "roc_auc scoring can only be computed for " "two-class problems." 
+ ) y = LabelEncoder().fit_transform(y) return y def _create_progressbar_context(inst, X, message): """Create a progress bar taking into account ``inst.verbose``.""" - multiply = (len(inst.estimators_) - if isinstance(inst, GeneralizingEstimator) else 1) + multiply = len(inst.estimators_) if isinstance(inst, GeneralizingEstimator) else 1 n_steps = X.shape[-1] * max(1, multiply) - mesg = f'{message} {inst.__class__.__name__}' + mesg = f"{message} {inst.__class__.__name__}" - which_tqdm = 'off' if not _check_verbose(inst.verbose) else None - context = ProgressBar(n_steps, mesg=mesg, position=inst.position, - which_tqdm=which_tqdm) + which_tqdm = "off" if not _check_verbose(inst.verbose) else None + context = ProgressBar( + n_steps, mesg=mesg, position=inst.position, which_tqdm=which_tqdm + ) return context diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 4739264f544..8b747e4e350 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -13,8 +13,13 @@ from ..rank import compute_rank from ..time_frequency import psd_array_welch from ..utils import ( - fill_doc, logger, _check_option, _time_mask, _validate_type, - _verbose_safe_false) + fill_doc, + logger, + _check_option, + _time_mask, + _validate_type, + _verbose_safe_false, +) @fill_doc @@ -84,52 +89,66 @@ class SSD(BaseEstimator, TransformerMixin): .. footbibliography:: """ - def __init__(self, info, filt_params_signal, filt_params_noise, - reg=None, n_components=None, picks=None, - sort_by_spectral_ratio=True, return_filtered=False, - n_fft=None, cov_method_params=None, rank=None): + def __init__( + self, + info, + filt_params_signal, + filt_params_noise, + reg=None, + n_components=None, + picks=None, + sort_by_spectral_ratio=True, + return_filtered=False, + n_fft=None, + cov_method_params=None, + rank=None, + ): """Initialize instance.""" dicts = {"signal": filt_params_signal, "noise": filt_params_noise} - for param, dd in [('l', 0), ('h', 0), ('l', 1), ('h', 1)]: - key = ('signal', 'noise')[dd] - if param + '_freq' not in dicts[key]: + for param, dd in [("l", 0), ("h", 0), ("l", 1), ("h", 1)]: + key = ("signal", "noise")[dd] + if param + "_freq" not in dicts[key]: raise ValueError( - '%s must be defined in filter parameters for %s' - % (param + '_freq', key)) - val = dicts[key][param + '_freq'] + "%s must be defined in filter parameters for %s" + % (param + "_freq", key) + ) + val = dicts[key][param + "_freq"] if not isinstance(val, (int, float)): - _validate_type(val, ('numeric',), f'{key} {param}_freq') + _validate_type(val, ("numeric",), f"{key} {param}_freq") # check freq bands - if (filt_params_noise['l_freq'] > filt_params_signal['l_freq'] or - filt_params_signal['h_freq'] > filt_params_noise['h_freq']): - raise ValueError('Wrongly specified frequency bands!\n' - 'The signal band-pass must be within the noise ' - 'band-pass!') - self.picks_ = _picks_to_idx(info, picks, none='data', exclude='bads') + if ( + filt_params_noise["l_freq"] > filt_params_signal["l_freq"] + or filt_params_signal["h_freq"] > filt_params_noise["h_freq"] + ): + raise ValueError( + "Wrongly specified frequency bands!\n" + "The signal band-pass must be within the noise " + "band-pass!" + ) + self.picks_ = _picks_to_idx(info, picks, none="data", exclude="bads") del picks ch_types = _get_channel_types(info, picks=self.picks_, unique=True) if len(ch_types) > 1: - raise ValueError('At this point SSD only supports fitting ' - 'single channel types. 
Your info has %i types' % - (len(ch_types))) + raise ValueError( + "At this point SSD only supports fitting " + "single channel types. Your info has %i types" % (len(ch_types)) + ) self.info = info - self.freqs_signal = (filt_params_signal['l_freq'], - filt_params_signal['h_freq']) - self.freqs_noise = (filt_params_noise['l_freq'], - filt_params_noise['h_freq']) + self.freqs_signal = (filt_params_signal["l_freq"], filt_params_signal["h_freq"]) + self.freqs_noise = (filt_params_noise["l_freq"], filt_params_noise["h_freq"]) self.filt_params_signal = filt_params_signal self.filt_params_noise = filt_params_noise # check if boolean if not isinstance(sort_by_spectral_ratio, (bool)): - raise ValueError('sort_by_spectral_ratio must be boolean') + raise ValueError("sort_by_spectral_ratio must be boolean") self.sort_by_spectral_ratio = sort_by_spectral_ratio if n_fft is None: - self.n_fft = int(self.info['sfreq']) + self.n_fft = int(self.info["sfreq"]) else: self.n_fft = int(n_fft) # check if boolean if not isinstance(return_filtered, (bool)): - raise ValueError('return_filtered must be boolean') + raise ValueError("return_filtered must be boolean") self.return_filtered = return_filtered self.reg = reg self.n_components = n_components @@ -138,13 +157,14 @@ def __init__(self, info, filt_params_signal, filt_params_noise, def _check_X(self, X): """Check input data.""" - _validate_type(X, np.ndarray, 'X') - _check_option('X.ndim', X.ndim, (2, 3)) + _validate_type(X, np.ndarray, "X") + _check_option("X.ndim", X.ndim, (2, 3)) n_chan = X.shape[-2] - if n_chan != self.info['nchan']: - raise ValueError('Info must match the input data.' - 'Found %i channels but expected %i.' % - (n_chan, self.info['nchan'])) + if n_chan != self.info["nchan"]: + raise ValueError( + "Info must match the input data." + "Found %i channels but expected %i." % (n_chan, self.info["nchan"]) + ) def fit(self, X, y=None): """Estimate the SSD decomposition on raw or epoched data. @@ -164,13 +184,12 @@ def fit(self, X, y=None): Returns the modified instance. 
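A rough usage sketch (the montage-free `create_info`, the band edges, and the transition widths are illustrative assumptions; real inputs would come from measured data):

    import numpy as np
    import mne
    from mne.decoding import SSD

    rng = np.random.RandomState(0)
    info = mne.create_info(8, sfreq=250.0, ch_types="eeg")
    X = rng.randn(10, 8, 2000)  # epochs x channels x time points

    ssd = SSD(
        info,
        filt_params_signal=dict(l_freq=9, h_freq=12,
                                l_trans_bandwidth=1, h_trans_bandwidth=1),
        filt_params_noise=dict(l_freq=8, h_freq=13,
                               l_trans_bandwidth=1, h_trans_bandwidth=1),
        n_components=2,
    )
    sources = ssd.fit_transform(X)  # (10, 2, 2000), sorted by spectral ratio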
""" from scipy import linalg + self._check_X(X) X_aux = X[..., self.picks_, :] - X_signal = filter_data( - X_aux, self.info['sfreq'], **self.filt_params_signal) - X_noise = filter_data( - X_aux, self.info['sfreq'], **self.filt_params_noise) + X_signal = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal) + X_noise = filter_data(X_aux, self.info["sfreq"], **self.filt_params_noise) X_noise -= X_signal if X.ndim == 3: X_signal = np.hstack(X_signal) @@ -178,15 +197,24 @@ def fit(self, X, y=None): # prevent rank change when computing cov with rank='full' cov_signal = _regularized_covariance( - X_signal, reg=self.reg, method_params=self.cov_method_params, - rank='full', info=self.info) + X_signal, + reg=self.reg, + method_params=self.cov_method_params, + rank="full", + info=self.info, + ) cov_noise = _regularized_covariance( - X_noise, reg=self.reg, method_params=self.cov_method_params, - rank='full', info=self.info) + X_noise, + reg=self.reg, + method_params=self.cov_method_params, + rank="full", + info=self.info, + ) # project cov to rank subspace - cov_signal, cov_noise, rank_proj = (_dimensionality_reduction( - cov_signal, cov_noise, self.info, self.rank)) + cov_signal, cov_noise, rank_proj = _dimensionality_reduction( + cov_signal, cov_noise, self.info, self.rank + ) eigvals_, eigvects_ = linalg.eigh(cov_signal, cov_noise) # sort in descending order @@ -204,7 +232,7 @@ def fit(self, X, y=None): if self.sort_by_spectral_ratio: _, sorter_spec = self.get_spectral_ratio(ssd_sources=X_ssd) self.sorter_spec = sorter_spec - logger.info('Done.') + logger.info("Done.") return self def transform(self, X): @@ -224,16 +252,15 @@ def transform(self, X): """ self._check_X(X) if self.filters_ is None: - raise RuntimeError('No filters available. Please first call fit') + raise RuntimeError("No filters available. Please first call fit") if self.return_filtered: X_aux = X[..., self.picks_, :] - X = filter_data(X_aux, self.info['sfreq'], - **self.filt_params_signal) + X = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal) X_ssd = self.filters_.T @ X[..., self.picks_, :] if X.ndim == 2: - X_ssd = X_ssd[self.sorter_spec][:self.n_components] + X_ssd = X_ssd[self.sorter_spec][: self.n_components] else: - X_ssd = X_ssd[:, self.sorter_spec, :][:, :self.n_components, :] + X_ssd = X_ssd[:, self.sorter_spec, :][:, : self.n_components, :] return X_ssd def get_spectral_ratio(self, ssd_sources): @@ -259,7 +286,8 @@ def get_spectral_ratio(self, ssd_sources): .. footbibliography:: """ psd, freqs = psd_array_welch( - ssd_sources, sfreq=self.info['sfreq'], n_fft=self.n_fft) + ssd_sources, sfreq=self.info["sfreq"], n_fft=self.n_fft + ) sig_idx = _time_mask(freqs, *self.freqs_signal) noise_idx = _time_mask(freqs, *self.freqs_noise) if psd.ndim == 3: @@ -275,7 +303,7 @@ def get_spectral_ratio(self, ssd_sources): def inverse_transform(self): """Not implemented yet.""" - raise NotImplementedError('inverse_transform is not yet available.') + raise NotImplementedError("inverse_transform is not yet available.") def apply(self, X): """Remove selected components from the signal. @@ -301,7 +329,7 @@ def apply(self, X): The processed data. 
""" X_ssd = self.transform(X) - pick_patterns = self.patterns_[self.sorter_spec][:self.n_components].T + pick_patterns = self.patterns_[self.sorter_spec][: self.n_components].T X = pick_patterns @ X_ssd return X @@ -309,17 +337,40 @@ def apply(self, X): def _dimensionality_reduction(cov_signal, cov_noise, info, rank): """Perform dimensionality reduction on the covariance matrices.""" from scipy import linalg + n_channels = cov_signal.shape[0] # find ranks of covariance matrices - rank_signal = list(compute_rank( - Covariance(cov_signal, info.ch_names, list(), list(), 0, - verbose=_verbose_safe_false()), - rank, _handle_default('scalings_cov_rank', None), info).values())[0] - rank_noise = list(compute_rank( - Covariance(cov_noise, info.ch_names, list(), list(), 0, - verbose=_verbose_safe_false()), - rank, _handle_default('scalings_cov_rank', None), info).values())[0] + rank_signal = list( + compute_rank( + Covariance( + cov_signal, + info.ch_names, + list(), + list(), + 0, + verbose=_verbose_safe_false(), + ), + rank, + _handle_default("scalings_cov_rank", None), + info, + ).values() + )[0] + rank_noise = list( + compute_rank( + Covariance( + cov_noise, + info.ch_names, + list(), + list(), + 0, + verbose=_verbose_safe_false(), + ), + rank, + _handle_default("scalings_cov_rank", None), + info, + ).values() + )[0] rank = np.min([rank_signal, rank_noise]) # should be identical if rank < n_channels: @@ -330,13 +381,18 @@ def _dimensionality_reduction(cov_signal, cov_noise, info, rank): eigvects = eigvects[:, ix] # compute rank subspace projection matrix rank_proj = np.matmul( - eigvects[:, :rank], np.eye(rank) * (eigvals[:rank]**-0.5)) + eigvects[:, :rank], np.eye(rank) * (eigvals[:rank] ** -0.5) + ) logger.info( - 'Projecting covariance of %i channels to %i rank subspace' - % (n_channels, rank,)) + "Projecting covariance of %i channels to %i rank subspace" + % ( + n_channels, + rank, + ) + ) else: rank_proj = np.eye(n_channels) - logger.info('Preserving covariance rank (%i)' % (rank,)) + logger.info("Preserving covariance rank (%i)" % (rank,)) # project covariance matrices to rank subspace cov_signal = np.matmul(rank_proj.T, np.matmul(cov_signal, rank_proj)) diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index 43f8b08d097..c7773a217d4 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -4,18 +4,27 @@ # License: BSD-3-Clause import numpy as np -from numpy.testing import (assert_array_equal, assert_array_almost_equal, - assert_equal, assert_allclose, assert_array_less) +from numpy.testing import ( + assert_array_equal, + assert_array_almost_equal, + assert_equal, + assert_allclose, + assert_array_less, +) import pytest from mne import create_info, EpochsArray from mne.fixes import is_regressor, is_classifier from mne.utils import requires_sklearn -from mne.decoding.base import (_get_inverse_funcs, LinearModel, get_coef, - cross_val_multiscore, BaseEstimator) +from mne.decoding.base import ( + _get_inverse_funcs, + LinearModel, + get_coef, + cross_val_multiscore, + BaseEstimator, +) from mne.decoding.search_light import SlidingEstimator -from mne.decoding import (Scaler, TransformerMixin, Vectorizer, - GeneralizingEstimator) +from mne.decoding import Scaler, TransformerMixin, Vectorizer, GeneralizingEstimator def _make_data(n_samples=1000, n_features=5, n_targets=3): @@ -43,7 +52,7 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3): # Define Y latent factors np.random.seed(0) cov_Y = np.eye(n_targets) * 10 + 
np.random.rand(n_targets, n_targets) - cov_Y = (cov_Y + cov_Y.T) / 2. + cov_Y = (cov_Y + cov_Y.T) / 2.0 mean_Y = np.random.rand(n_targets) Y = np.random.multivariate_normal(mean_Y, cov_Y, size=n_samples) @@ -68,19 +77,21 @@ def test_get_coef(): from sklearn.model_selection import GridSearchCV lm_classification = LinearModel() - assert (is_classifier(lm_classification)) + assert is_classifier(lm_classification) lm_regression = LinearModel(Ridge()) - assert (is_regressor(lm_regression)) + assert is_regressor(lm_regression) - parameters = {'kernel': ['linear'], 'C': [1, 10]} + parameters = {"kernel": ["linear"], "C": [1, 10]} lm_gs_classification = LinearModel( - GridSearchCV(svm.SVC(), parameters, cv=2, refit=True, n_jobs=None)) - assert (is_classifier(lm_gs_classification)) + GridSearchCV(svm.SVC(), parameters, cv=2, refit=True, n_jobs=None) + ) + assert is_classifier(lm_gs_classification) lm_gs_regression = LinearModel( - GridSearchCV(svm.SVR(), parameters, cv=2, refit=True, n_jobs=None)) - assert (is_regressor(lm_gs_regression)) + GridSearchCV(svm.SVR(), parameters, cv=2, refit=True, n_jobs=None) + ) + assert is_regressor(lm_gs_regression) # Define a classifier, an invertible transformer and an non-invertible one. @@ -113,14 +124,15 @@ def inverse_transform(self, X): for expected_n, est in good_estimators: est.fit(X, y) - assert (expected_n == len(_get_inverse_funcs(est))) + assert expected_n == len(_get_inverse_funcs(est)) bad_estimators = [ Clf(), # no preprocessing Inv(), # final estimator isn't classifier make_pipeline(NoInv(), Clf()), # first step isn't invertible - make_pipeline(Inv(), make_pipeline( - Inv(), NoInv()), Clf()), # nested step isn't invertible + make_pipeline( + Inv(), make_pipeline(Inv(), NoInv()), Clf() + ), # nested step isn't invertible ] for est in bad_estimators: est.fit(X, y) @@ -129,11 +141,12 @@ def inverse_transform(self, X): # II. Test get coef for classification/regression estimators and pipelines rng = np.random.RandomState(0) - for clf in (lm_regression, - lm_gs_classification, - make_pipeline(StandardScaler(), lm_classification), - make_pipeline(StandardScaler(), lm_gs_regression)): - + for clf in ( + lm_regression, + lm_gs_classification, + make_pipeline(StandardScaler(), lm_classification), + make_pipeline(StandardScaler(), lm_gs_regression), + ): # generate some categorical/continuous data # according to the type of estimator. 
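
The filters_/patterns_ pair exercised throughout test_get_coef follows Haufe et al. (2014): a backward-model filter (the weights a decoder applies to the data) is converted into a forward-model pattern by multiplying with the data covariance. A minimal standalone sketch of that relationship, independent of the code in this patch (all variable names below are illustrative, not MNE API):

    import numpy as np

    rng = np.random.RandomState(0)
    X = rng.randn(1000, 3)  # (n_samples, n_features)
    y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.randn(1000)

    # Backward model: least-squares filter w, applied to X to predict y
    w, *_ = np.linalg.lstsq(X - X.mean(0), y - y.mean(), rcond=None)

    # Forward model: the pattern describes how the decoded component
    # projects back into sensor space (Haufe et al. 2014)
    a = np.cov(X.T) @ w / np.var(X @ w)

The test's assertion that filters[0] != patterns[0] holds precisely because of the covariance multiplication above.
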
if is_classifier(clf): @@ -147,16 +160,16 @@ def inverse_transform(self, X): clf.fit(X, y) # Retrieve final linear model - filters = get_coef(clf, 'filters_', False) - if hasattr(clf, 'steps'): - if hasattr(clf.steps[-1][-1].model, 'best_estimator_'): + filters = get_coef(clf, "filters_", False) + if hasattr(clf, "steps"): + if hasattr(clf.steps[-1][-1].model, "best_estimator_"): # Linear Model with GridSearchCV coefs = clf.steps[-1][-1].model.best_estimator_.coef_ else: # Standard Linear Model coefs = clf.steps[-1][-1].model.coef_ else: - if hasattr(clf.model, 'best_estimator_'): + if hasattr(clf.model, "best_estimator_"): # Linear Model with GridSearchCV coefs = clf.model.best_estimator_.coef_ else: @@ -165,20 +178,19 @@ def inverse_transform(self, X): if coefs.ndim == 2 and coefs.shape[0] == 1: coefs = coefs[0] assert_array_equal(filters, coefs) - patterns = get_coef(clf, 'patterns_', False) - assert (filters[0] != patterns[0]) + patterns = get_coef(clf, "patterns_", False) + assert filters[0] != patterns[0] n_chans = X.shape[1] assert_array_equal(filters.shape, patterns.shape, [n_chans, n_chans]) # Inverse transform linear model - filters_inv = get_coef(clf, 'filters_', True) - assert (filters[0] != filters_inv[0]) - patterns_inv = get_coef(clf, 'patterns_', True) - assert (patterns[0] != patterns_inv[0]) + filters_inv = get_coef(clf, "filters_", True) + assert filters[0] != filters_inv[0] + patterns_inv = get_coef(clf, "patterns_", True) + assert patterns[0] != patterns_inv[0] class _Noop(BaseEstimator, TransformerMixin): - def fit(self, X, y=None): return self @@ -189,15 +201,19 @@ def transform(self, X): @requires_sklearn -@pytest.mark.parametrize('inverse', (True, False)) -@pytest.mark.parametrize('Scale, kwargs', [ - (Scaler, dict(info=None, scalings='mean')), - (_Noop, dict()), -]) +@pytest.mark.parametrize("inverse", (True, False)) +@pytest.mark.parametrize( + "Scale, kwargs", + [ + (Scaler, dict(info=None, scalings="mean")), + (_Noop, dict()), + ], +) def test_get_coef_inverse_transform(inverse, Scale, kwargs): """Test get_coef with and without inverse_transform.""" from sklearn.linear_model import Ridge from sklearn.pipeline import make_pipeline + lm_regression = LinearModel(Ridge()) X, y, A = _make_data(n_samples=1000, n_features=3, n_targets=1) # Check with search_light and combination of preprocessing ending with sl: @@ -208,29 +224,29 @@ def test_get_coef_inverse_transform(inverse, Scale, kwargs): X = np.transpose([X, -X], [1, 2, 0]) # invert X across 2 time samples clf = make_pipeline(Scale(**kwargs), slider) clf.fit(X, y) - patterns = get_coef(clf, 'patterns_', inverse) - filters = get_coef(clf, 'filters_', inverse) + patterns = get_coef(clf, "patterns_", inverse) + filters = get_coef(clf, "filters_", inverse) assert_array_equal(filters.shape, patterns.shape, X.shape[1:]) # the two time samples get inverted patterns assert_equal(patterns[0, 0], -patterns[0, 1]) for t in [0, 1]: filters_t = get_coef( - clf.named_steps['slidingestimator'].estimators_[t], - 'filters_', False) + clf.named_steps["slidingestimator"].estimators_[t], "filters_", False + ) if Scale is _Noop: assert_array_equal(filters_t, filters[:, t]) @requires_sklearn -@pytest.mark.parametrize('n_features', [1, 5]) -@pytest.mark.parametrize('n_targets', [1, 3]) +@pytest.mark.parametrize("n_features", [1, 5]) +@pytest.mark.parametrize("n_targets", [1, 3]) def test_get_coef_multiclass(n_features, n_targets): """Test get_coef on multiclass problems.""" # Check patterns with more than 1 regressor from 
sklearn.linear_model import LinearRegression, Ridge from sklearn.pipeline import make_pipeline - X, Y, A = _make_data( - n_samples=30000, n_features=n_features, n_targets=n_targets) + + X, Y, A = _make_data(n_samples=30000, n_features=n_features, n_targets=n_targets) lm = LinearModel(LinearRegression()).fit(X, Y) assert_array_equal(lm.filters_.shape, lm.patterns_.shape) if n_targets == 1: @@ -245,22 +261,22 @@ def test_get_coef_multiclass(n_features, n_targets): clf.fit(X, Y) if n_features > 1 and n_targets > 1: assert_allclose(A, lm.patterns_.T, atol=2e-2) - coef = get_coef(clf, 'patterns_', inverse_transform=True) + coef = get_coef(clf, "patterns_", inverse_transform=True) assert_allclose(lm.patterns_, coef, atol=1e-5) # With epochs, scaler, and vectorizer (typical use case) X_epo = X.reshape(X.shape + (1,)) - info = create_info(n_features, 1000., 'eeg') + info = create_info(n_features, 1000.0, "eeg") lm = LinearModel(Ridge(alpha=1)) clf = make_pipeline( - Scaler(info, scalings=dict(eeg=1.)), # XXX adding this step breaks + Scaler(info, scalings=dict(eeg=1.0)), # XXX adding this step breaks Vectorizer(), lm, ) clf.fit(X_epo, Y) if n_features > 1 and n_targets > 1: assert_allclose(A, lm.patterns_.T, atol=2e-2) - coef = get_coef(clf, 'patterns_', inverse_transform=True) + coef = get_coef(clf, "patterns_", inverse_transform=True) lm_patterns_ = lm.patterns_[..., np.newaxis] assert_allclose(lm_patterns_, coef, atol=1e-5) @@ -269,31 +285,36 @@ def test_get_coef_multiclass(n_features, n_targets): @requires_sklearn -@pytest.mark.parametrize('n_classes, n_channels, n_times', [ - (4, 10, 2), - (4, 3, 2), - (3, 2, 1), - (3, 1, 2), -]) +@pytest.mark.parametrize( + "n_classes, n_channels, n_times", + [ + (4, 10, 2), + (4, 3, 2), + (3, 2, 1), + (3, 1, 2), + ], +) def test_get_coef_multiclass_full(n_classes, n_channels, n_times): """Test a full example with pattern extraction.""" from sklearn.pipeline import make_pipeline from sklearn.linear_model import LogisticRegression from sklearn.model_selection import StratifiedKFold + data = np.zeros((10 * n_classes, n_channels, n_times)) # Make only the first channel informative for ii in range(n_classes): - data[ii * 10:(ii + 1) * 10, 0] = ii + data[ii * 10 : (ii + 1) * 10, 0] = ii events = np.zeros((len(data), 3), int) events[:, 0] = np.arange(len(events)) events[:, 2] = data[:, 0, 0] - info = create_info(n_channels, 1000., 'eeg') + info = create_info(n_channels, 1000.0, "eeg") epochs = EpochsArray(data, info, events, tmin=0) clf = make_pipeline( - Scaler(epochs.info), Vectorizer(), - LinearModel(LogisticRegression(random_state=0, multi_class='ovr')), + Scaler(epochs.info), + Vectorizer(), + LinearModel(LogisticRegression(random_state=0, multi_class="ovr")), ) - scorer = 'roc_auc_ovr_weighted' + scorer = "roc_auc_ovr_weighted" time_gen = GeneralizingEstimator(clf, scorer, verbose=True) X = epochs.get_data() y = epochs.events[:, 2] @@ -306,9 +327,9 @@ def test_get_coef_multiclass_full(n_classes, n_channels, n_times): assert scores.shape == want assert_array_less(0.8, scores) clf.fit(X, y) - patterns = get_coef(clf, 'patterns_', inverse_transform=True) + patterns = get_coef(clf, "patterns_", inverse_transform=True) assert patterns.shape == (n_classes, n_channels, n_times) - assert_allclose(patterns[:, 1:], 0., atol=1e-7) # no other channels useful + assert_allclose(patterns[:, 1:], 0.0, atol=1e-7) # no other channels useful @requires_sklearn @@ -316,6 +337,7 @@ def test_linearmodel(): """Test LinearModel class for computing filters and patterns.""" # check 
categorical target fit in standard linear model from sklearn.linear_model import LinearRegression + rng = np.random.RandomState(0) clf = LinearModel() n, n_features = 20, 3 @@ -331,9 +353,11 @@ def test_linearmodel(): # check categorical target fit in standard linear model with GridSearchCV from sklearn import svm from sklearn.model_selection import GridSearchCV - parameters = {'kernel': ['linear'], 'C': [1, 10]} + + parameters = {"kernel": ["linear"], "C": [1, 10]} clf = LinearModel( - GridSearchCV(svm.SVC(), parameters, cv=2, refit=True, n_jobs=None)) + GridSearchCV(svm.SVC(), parameters, cv=2, refit=True, n_jobs=None) + ) clf.fit(X, y) assert_equal(clf.filters_.shape, (n_features,)) assert_equal(clf.patterns_.shape, (n_features,)) @@ -345,10 +369,11 @@ def test_linearmodel(): n_targets = 1 Y = rng.rand(n, n_targets) clf = LinearModel( - GridSearchCV(svm.SVR(), parameters, cv=2, refit=True, n_jobs=None)) + GridSearchCV(svm.SVR(), parameters, cv=2, refit=True, n_jobs=None) + ) clf.fit(X, y) - assert_equal(clf.filters_.shape, (n_features, )) - assert_equal(clf.patterns_.shape, (n_features, )) + assert_equal(clf.filters_.shape, (n_features,)) + assert_equal(clf.patterns_.shape, (n_features,)) with pytest.raises(ValueError): wrong_y = rng.rand(n, n_features, 99) clf.fit(X, wrong_y) @@ -371,20 +396,21 @@ def test_cross_val_multiscore(): from sklearn.model_selection import KFold, StratifiedKFold, cross_val_score from sklearn.linear_model import LogisticRegression, LinearRegression - logreg = LogisticRegression(solver='liblinear', random_state=0) + logreg = LogisticRegression(solver="liblinear", random_state=0) # compare to cross-val-score X = np.random.rand(20, 3) y = np.arange(20) % 2 cv = KFold(2, random_state=0, shuffle=True) clf = logreg - assert_array_equal(cross_val_score(clf, X, y, cv=cv), - cross_val_multiscore(clf, X, y, cv=cv)) + assert_array_equal( + cross_val_score(clf, X, y, cv=cv), cross_val_multiscore(clf, X, y, cv=cv) + ) # Test with search light X = np.random.rand(20, 4, 3) y = np.arange(20) % 2 - clf = SlidingEstimator(logreg, scoring='accuracy') + clf = SlidingEstimator(logreg, scoring="accuracy") scores_acc = cross_val_multiscore(clf, X, y, cv=cv) assert_array_equal(np.shape(scores_acc), [2, 3]) @@ -399,9 +425,8 @@ def test_cross_val_multiscore(): # raise an error if scoring is defined at cross-val-score level and # search light, because search light does not return a 1-dimensional # prediction. 
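
The restriction described in the comment above is shape-based: sklearn scorers expect one prediction per sample, while a sliding estimator returns one prediction per sample per time point. A rough standalone illustration, with shapes matching this test (nothing here is MNE API; the per-time-point loop only mimics what a sliding estimator does internally):

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    rng = np.random.RandomState(0)
    X = rng.rand(20, 4, 3)  # (n_epochs, n_channels, n_times)
    y = np.arange(20) % 2

    # One classifier per time point: stacked predictions come back as
    # (n_epochs, n_times), i.e. 2-D, which breaks scorers that expect a
    # 1-D prediction vector at the cross-validation level.
    preds = np.stack(
        [LogisticRegression().fit(X[..., t], y).predict(X[..., t])
         for t in range(X.shape[-1])],
        axis=1,
    )
    assert preds.shape == (20, 3)
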
- pytest.raises(ValueError, cross_val_multiscore, clf, X, y, cv=cv, - scoring='roc_auc') - clf = SlidingEstimator(logreg, scoring='roc_auc') + pytest.raises(ValueError, cross_val_multiscore, clf, X, y, cv=cv, scoring="roc_auc") + clf = SlidingEstimator(logreg, scoring="roc_auc") scores_auc = cross_val_multiscore(clf, X, y, cv=cv, n_jobs=None) scores_auc_manual = list() for train, test in cv.split(X, y): diff --git a/mne/decoding/tests/test_csp.py b/mne/decoding/tests/test_csp.py index 6945a812cf7..a505a6c7bdc 100644 --- a/mne/decoding/tests/test_csp.py +++ b/mne/decoding/tests/test_csp.py @@ -9,8 +9,7 @@ import numpy as np import pytest -from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_equal) +from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_equal from mne import io, Epochs, read_events, pick_types from mne.decoding.csp import CSP, _ajd_pham, SPoC @@ -46,23 +45,39 @@ def simulate_data(target, n_trials=100, n_channels=10, random_state=42): return X, mixing_mat -def deterministic_toy_data(classes=('class_a', 'class_b')): +def deterministic_toy_data(classes=("class_a", "class_b")): """Generate a small deterministic toy data set. Four independent sources are modulated by the target class and mixed into signal space. """ - sources_a = np.array([[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], - dtype=float) * 2 - 1 - - sources_b = np.array([[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], - dtype=float) * 2 - 1 + sources_a = ( + np.array( + [ + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], + ], + dtype=float, + ) + * 2 + - 1 + ) + + sources_b = ( + np.array( + [ + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], + ], + dtype=float, + ) + * 2 + - 1 + ) sources_a[0, :] *= 1 sources_a[1, :] *= 2 @@ -70,10 +85,14 @@ def deterministic_toy_data(classes=('class_a', 'class_b')): sources_b[2, :] *= 3 sources_b[3, :] *= 4 - mixing = np.array([[1.0, 0.8, 0.6, 0.4], - [0.8, 1.0, 0.8, 0.6], - [0.6, 0.8, 1.0, 0.8], - [0.4, 0.6, 0.8, 1.0]]) + mixing = np.array( + [ + [1.0, 0.8, 0.6, 0.4], + [0.8, 1.0, 0.8, 0.6], + [0.6, 0.8, 1.0, 0.8], + [0.4, 0.6, 0.8, 1.0], + ] + ) x_class_a = mixing @ sources_a x_class_b = mixing @ sources_b @@ -89,28 +108,38 @@ def test_csp(): """Test Common Spatial Patterns algorithm on epochs.""" raw = io.read_raw_fif(raw_fname, preload=False) events = read_events(event_name) - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[2:12:3] # subselect channels -> disable proj! 
raw.add_proj([], remove_existing=True) - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True, proj=False) + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + preload=True, + proj=False, + ) epochs_data = epochs.get_data() n_channels = epochs_data.shape[1] y = epochs.events[:, -1] # Init - pytest.raises(ValueError, CSP, n_components='foo', norm_trace=False) - for reg in ['foo', -0.1, 1.1]: + pytest.raises(ValueError, CSP, n_components="foo", norm_trace=False) + for reg in ["foo", -0.1, 1.1]: csp = CSP(reg=reg, norm_trace=False) pytest.raises(ValueError, csp.fit, epochs_data, epochs.events[:, -1]) - for reg in ['oas', 'ledoit_wolf', 0, 0.5, 1.]: + for reg in ["oas", "ledoit_wolf", 0, 0.5, 1.0]: CSP(reg=reg, norm_trace=False) - for cov_est in ['foo', None]: + for cov_est in ["foo", None]: pytest.raises(ValueError, CSP, cov_est=cov_est, norm_trace=False) - with pytest.raises(TypeError, match='instance of bool'): - CSP(norm_trace='foo') - for cov_est in ['concat', 'epoch']: + with pytest.raises(TypeError, match="instance of bool"): + CSP(norm_trace="foo") + for cov_est in ["concat", "epoch"]: CSP(cov_est=cov_est, norm_trace=False) n_components = 3 @@ -125,33 +154,40 @@ def test_csp(): # Transform X = csp.fit_transform(epochs_data, y) sources = csp.transform(epochs_data) - assert (sources.shape[1] == n_components) - assert (csp.filters_.shape == (n_channels, n_channels)) - assert (csp.patterns_.shape == (n_channels, n_channels)) + assert sources.shape[1] == n_components + assert csp.filters_.shape == (n_channels, n_channels) + assert csp.patterns_.shape == (n_channels, n_channels) assert_array_almost_equal(sources, X) # Test data exception - pytest.raises(ValueError, csp.fit, epochs_data, - np.zeros_like(epochs.events)) + pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) pytest.raises(ValueError, csp.fit, epochs, y) pytest.raises(ValueError, csp.transform, epochs) # Test plots - epochs.pick_types(meg='mag') - cmap = ('RdBu', True) + epochs.pick_types(meg="mag") + cmap = ("RdBu", True) components = np.arange(n_components) for plot in (csp.plot_patterns, csp.plot_filters): plot(epochs.info, components=components, res=12, show=False, cmap=cmap) # Test with more than 2 classes - epochs = Epochs(raw, events, tmin=tmin, tmax=tmax, picks=picks, - event_id=dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4), - baseline=(None, 0), proj=False, preload=True) + epochs = Epochs( + raw, + events, + tmin=tmin, + tmax=tmax, + picks=picks, + event_id=dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4), + baseline=(None, 0), + proj=False, + preload=True, + ) epochs_data = epochs.get_data() n_channels = epochs_data.shape[1] n_channels = epochs_data.shape[1] - for cov_est in ['concat', 'epoch']: + for cov_est in ["concat", "epoch"]: csp = CSP(n_components=n_components, cov_est=cov_est, norm_trace=False) csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data) assert_equal(len(csp._classes), 4) @@ -160,31 +196,31 @@ def test_csp(): # Test average power transform n_components = 2 - assert (csp.transform_into == 'average_power') + assert csp.transform_into == "average_power" feature_shape = [len(epochs_data), n_components] X_trans = dict() for log in (None, True, False): csp = CSP(n_components=n_components, log=log, norm_trace=False) - assert (csp.log is log) + assert csp.log is log Xt = csp.fit_transform(epochs_data, epochs.events[:, 2]) assert_array_equal(Xt.shape, feature_shape) X_trans[str(log)] = Xt # 
log=None => log=True - assert_array_almost_equal(X_trans['None'], X_trans['True']) + assert_array_almost_equal(X_trans["None"], X_trans["True"]) # Different normalization return different transform - assert (np.sum((X_trans['True'] - X_trans['False']) ** 2) > 1.) + assert np.sum((X_trans["True"] - X_trans["False"]) ** 2) > 1.0 # Check wrong inputs - pytest.raises(ValueError, CSP, transform_into='average_power', log='foo') + pytest.raises(ValueError, CSP, transform_into="average_power", log="foo") # Test csp space transform - csp = CSP(transform_into='csp_space', norm_trace=False) - assert (csp.transform_into == 'csp_space') - for log in ('foo', True, False): - pytest.raises(ValueError, CSP, transform_into='csp_space', log=log, - norm_trace=False) + csp = CSP(transform_into="csp_space", norm_trace=False) + assert csp.transform_into == "csp_space" + for log in ("foo", True, False): + pytest.raises( + ValueError, CSP, transform_into="csp_space", log=log, norm_trace=False + ) n_components = 2 - csp = CSP(n_components=n_components, transform_into='csp_space', - norm_trace=False) + csp = CSP(n_components=n_components, transform_into="csp_space", norm_trace=False) Xt = csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data) feature_shape = [len(epochs_data), n_components, epochs_data.shape[2]] assert_array_equal(Xt.shape, feature_shape) @@ -193,7 +229,7 @@ def test_csp(): y = np.array([100] * 50 + [1] * 50) X, A = simulate_data(y) - for cov_est in ['concat', 'epoch']: + for cov_est in ["concat", "epoch"]: # fit csp csp = CSP(n_components=1, cov_est=cov_est, norm_trace=False) csp.fit(X, y) @@ -214,36 +250,35 @@ def test_regularized_csp(): """Test Common Spatial Patterns algorithm using regularized covariance.""" raw = io.read_raw_fif(raw_fname) events = read_events(event_name) - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[1:13:3] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True) + epochs = Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True + ) epochs_data = epochs.get_data() n_channels = epochs_data.shape[1] n_components = 3 - reg_cov = [None, 0.05, 'ledoit_wolf', 'oas'] + reg_cov = [None, 0.05, "ledoit_wolf", "oas"] for reg in reg_cov: - csp = CSP(n_components=n_components, reg=reg, norm_trace=False, - rank=None) + csp = CSP(n_components=n_components, reg=reg, norm_trace=False, rank=None) csp.fit(epochs_data, epochs.events[:, -1]) y = epochs.events[:, -1] X = csp.fit_transform(epochs_data, y) - assert (csp.filters_.shape == (n_channels, n_channels)) - assert (csp.patterns_.shape == (n_channels, n_channels)) - assert_array_almost_equal(csp.fit(epochs_data, y). 
- transform(epochs_data), X) + assert csp.filters_.shape == (n_channels, n_channels) + assert csp.patterns_.shape == (n_channels, n_channels) + assert_array_almost_equal(csp.fit(epochs_data, y).transform(epochs_data), X) # test init exception - pytest.raises(ValueError, csp.fit, epochs_data, - np.zeros_like(epochs.events)) + pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) pytest.raises(ValueError, csp.fit, epochs, y) pytest.raises(ValueError, csp.transform, epochs) csp.n_components = n_components sources = csp.transform(epochs_data) - assert (sources.shape[1] == n_components) + assert sources.shape[1] == n_components @requires_sklearn @@ -251,11 +286,12 @@ def test_csp_pipeline(): """Test if CSP works in a pipeline.""" from sklearn.svm import SVC from sklearn.pipeline import Pipeline + csp = CSP(reg=1, norm_trace=False) svc = SVC() pipe = Pipeline([("CSP", csp), ("SVC", svc)]) pipe.set_params(CSP__reg=0.2) - assert (pipe.get_params()["CSP__reg"] == 0.2) + assert pipe.get_params()["CSP__reg"] == 0.2 def test_ajd(): @@ -267,15 +303,17 @@ def test_ajd(): seed = np.random.RandomState(0) diags = 2.0 + 0.1 * seed.randn(n_times, n_channels) A = 2 * seed.rand(n_channels, n_channels) - 1 - A /= np.atleast_2d(np.sqrt(np.sum(A ** 2, 1))).T + A /= np.atleast_2d(np.sqrt(np.sum(A**2, 1))).T covmats = np.empty((n_times, n_channels, n_channels)) for i in range(n_times): covmats[i] = np.dot(np.dot(A, np.diag(diags[i])), A.T) V, D = _ajd_pham(covmats) # Results obtained with original matlab implementation - V_matlab = [[-3.507280775058041, -5.498189967306344, 7.720624541198574], - [0.694689013234610, 0.775690358505945, -1.162043086446043], - [-0.592603135588066, -0.598996925696260, 1.009550086271192]] + V_matlab = [ + [-3.507280775058041, -5.498189967306344, 7.720624541198574], + [0.694689013234610, 0.775690358505945, -1.162043086446043], + [-0.592603135588066, -0.598996925696260, 1.009550086271192], + ] assert_array_almost_equal(V, V_matlab) @@ -288,7 +326,7 @@ def test_spoc(): spoc.fit(X, y) Xt = spoc.transform(X) assert_array_equal(Xt.shape, [10, 4]) - spoc = SPoC(n_components=4, transform_into='csp_space') + spoc = SPoC(n_components=4, transform_into="csp_space") spoc.fit(X, y) Xt = spoc.transform(X) assert_array_equal(Xt.shape, [10, 4, 20]) @@ -299,7 +337,7 @@ def test_spoc(): pytest.raises(ValueError, spoc.fit, X, y * 0) # Check that doesn't take CSP-spcific input - pytest.raises(TypeError, SPoC, cov_est='epoch') + pytest.raises(TypeError, SPoC, cov_est="epoch") # Check mixing matrix on simulated data rs = np.random.RandomState(42) @@ -322,33 +360,32 @@ def test_spoc(): def test_csp_twoclass_symmetry(): """Test that CSP is symmetric when swapping classes.""" - x, y = deterministic_toy_data(['class_a', 'class_b']) - csp = CSP(norm_trace=False, transform_into='average_power', log=True) + x, y = deterministic_toy_data(["class_a", "class_b"]) + csp = CSP(norm_trace=False, transform_into="average_power", log=True) log_power = csp.fit_transform(x, y) log_power_ratio_ab = log_power[0] - log_power[1] - x, y = deterministic_toy_data(['class_b', 'class_a']) - csp = CSP(norm_trace=False, transform_into='average_power', log=True) + x, y = deterministic_toy_data(["class_b", "class_a"]) + csp = CSP(norm_trace=False, transform_into="average_power", log=True) log_power = csp.fit_transform(x, y) log_power_ratio_ba = log_power[0] - log_power[1] - assert_array_almost_equal(log_power_ratio_ab, - log_power_ratio_ba) + assert_array_almost_equal(log_power_ratio_ab, log_power_ratio_ba) def 
test_csp_component_ordering(): """Test that CSP component ordering works as expected.""" - x, y = deterministic_toy_data(['class_a', 'class_b']) + x, y = deterministic_toy_data(["class_a", "class_b"]) - pytest.raises(ValueError, CSP, component_order='invalid') + pytest.raises(ValueError, CSP, component_order="invalid") # component_order='alternate' only works with two classes - csp = CSP(component_order='alternate') + csp = CSP(component_order="alternate") with pytest.raises(ValueError): - csp.fit(np.zeros((3, 0, 0)), ['a', 'b', 'c']) + csp.fit(np.zeros((3, 0, 0)), ["a", "b", "c"]) - p_alt = CSP(component_order='alternate').fit(x, y).patterns_ - p_mut = CSP(component_order='mutual_info').fit(x, y).patterns_ + p_alt = CSP(component_order="alternate").fit(x, y).patterns_ + p_mut = CSP(component_order="mutual_info").fit(x, y).patterns_ # This permutation of p_alt and p_mut is explained by the particular # eigenvalues of the toy data: [0.06, 0.1, 0.5, 0.8]. diff --git a/mne/decoding/tests/test_ems.py b/mne/decoding/tests/test_ems.py index b24ebdd75aa..aaeea7c28f7 100644 --- a/mne/decoding/tests/test_ems.py +++ b/mne/decoding/tests/test_ems.py @@ -23,44 +23,55 @@ def test_ems(): """Test event-matched spatial filters.""" from sklearn.model_selection import StratifiedKFold + raw = io.read_raw_fif(raw_fname, preload=False) # create unequal number of events events = read_events(event_name) events[-2, 2] = 3 - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[1:13:3] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True) - pytest.raises(ValueError, compute_ems, epochs, ['aud_l', 'vis_l']) + epochs = Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True + ) + pytest.raises(ValueError, compute_ems, epochs, ["aud_l", "vis_l"]) epochs.equalize_event_counts(epochs.event_id) - pytest.raises(KeyError, compute_ems, epochs, ['blah', 'hahah']) + pytest.raises(KeyError, compute_ems, epochs, ["blah", "hahah"]) surrogates, filters, conditions = compute_ems(epochs) assert_equal(list(set(conditions)), [1, 3]) events = read_events(event_name) event_id2 = dict(aud_l=1, aud_r=2, vis_l=3) - epochs = Epochs(raw, events, event_id2, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True) + epochs = Epochs( + raw, + events, + event_id2, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + preload=True, + ) epochs.equalize_event_counts(epochs.event_id) - n_expected = sum([len(epochs[k]) for k in ['aud_l', 'vis_l']]) + n_expected = sum([len(epochs[k]) for k in ["aud_l", "vis_l"]]) pytest.raises(ValueError, compute_ems, epochs) - surrogates, filters, conditions = compute_ems(epochs, ['aud_r', 'vis_l']) + surrogates, filters, conditions = compute_ems(epochs, ["aud_r", "vis_l"]) assert_equal(n_expected, len(surrogates)) assert_equal(n_expected, len(conditions)) assert_equal(list(set(conditions)), [2, 3]) # test compute_ems cv - epochs = epochs['aud_r', 'vis_l'] + epochs = epochs["aud_r", "vis_l"] epochs.equalize_event_counts(epochs.event_id) cv = StratifiedKFold(n_splits=3) compute_ems(epochs, cv=cv) compute_ems(epochs, cv=2) - pytest.raises(ValueError, compute_ems, epochs, cv='foo') + pytest.raises(ValueError, compute_ems, epochs, cv="foo") pytest.raises(ValueError, compute_ems, epochs, cv=len(epochs) + 1) raw.close() @@ -70,13 +81,13 @@ def test_ems(): X = X / np.std(X) # X scaled 
outside cv in compute_ems Xt, coefs = list(), list() ems = EMS() - assert_equal(ems.__repr__(), '') + assert_equal(ems.__repr__(), "") # manual leave-one-out to avoid sklearn version problem for test in range(len(y)): train = np.setdiff1d(range(len(y)), np.atleast_1d(test)) ems.fit(X[train], y[train]) coefs.append(ems.filters_) Xt.append(ems.transform(X[[test]])) - assert_equal(ems.__repr__(), '') + assert_equal(ems.__repr__(), "") assert_array_almost_equal(filters, np.mean(coefs, axis=0)) assert_array_almost_equal(surrogates, np.vstack(Xt)) diff --git a/mne/decoding/tests/test_receptive_field.py b/mne/decoding/tests/test_receptive_field.py index c5d62fb4c63..9a993b43669 100644 --- a/mne/decoding/tests/test_receptive_field.py +++ b/mne/decoding/tests/test_receptive_field.py @@ -12,10 +12,13 @@ from mne.utils import requires_sklearn from mne.decoding import ReceptiveField, TimeDelayingRidge -from mne.decoding.receptive_field import (_delay_time_series, _SCORERS, - _times_to_delays, _delays_to_slice) -from mne.decoding.time_delaying_ridge import (_compute_reg_neighbors, - _compute_corrs) +from mne.decoding.receptive_field import ( + _delay_time_series, + _SCORERS, + _times_to_delays, + _delays_to_slice, +) +from mne.decoding.time_delaying_ridge import _compute_reg_neighbors, _compute_corrs data_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" @@ -26,30 +29,53 @@ event_id = dict(aud_l=1, vis_l=3) # Loading raw data -n_jobs_test = (1, 'cuda') +n_jobs_test = (1, "cuda") def test_compute_reg_neighbors(): """Test fast calculation of laplacian regularizer.""" for reg_type in ( - ('ridge', 'ridge'), - ('ridge', 'laplacian'), - ('laplacian', 'ridge'), - ('laplacian', 'laplacian')): + ("ridge", "ridge"), + ("ridge", "laplacian"), + ("laplacian", "ridge"), + ("laplacian", "laplacian"), + ): for n_ch_x, n_delays in ( - (1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (4, 1), - (2, 2), (2, 3), (3, 2), (3, 3), - (2, 4), (4, 2), (3, 4), (4, 3), (4, 4), - (5, 4), (4, 5), (5, 5), - (20, 9), (9, 20)): + (1, 1), + (1, 2), + (2, 1), + (1, 3), + (3, 1), + (1, 4), + (4, 1), + (2, 2), + (2, 3), + (3, 2), + (3, 3), + (2, 4), + (4, 2), + (3, 4), + (4, 3), + (4, 4), + (5, 4), + (4, 5), + (5, 5), + (20, 9), + (9, 20), + ): for normed in (True, False): reg_direct = _compute_reg_neighbors( - n_ch_x, n_delays, reg_type, 'direct', normed=normed) + n_ch_x, n_delays, reg_type, "direct", normed=normed + ) reg_csgraph = _compute_reg_neighbors( - n_ch_x, n_delays, reg_type, 'csgraph', normed=normed) + n_ch_x, n_delays, reg_type, "csgraph", normed=normed + ) assert_allclose( - reg_direct, reg_csgraph, atol=1e-7, - err_msg='%s: %s' % (reg_type, (n_ch_x, n_delays))) + reg_direct, + reg_csgraph, + atol=1e-7, + err_msg="%s: %s" % (reg_type, (n_ch_x, n_delays)), + ) @requires_sklearn @@ -57,19 +83,20 @@ def test_rank_deficiency(): """Test signals that are rank deficient.""" # See GH#4253 from sklearn.linear_model import Ridge + N = 256 - fs = 1. 
+ fs = 1.0 tmin, tmax = -50, 100 reg = 0.1 rng = np.random.RandomState(0) eeg = rng.randn(N, 1) eeg *= 100 eeg = rfft(eeg, axis=0) - eeg[N // 4:] = 0 # rank-deficient lowpass + eeg[N // 4 :] = 0 # rank-deficient lowpass eeg = irfft(eeg, axis=0) win = np.hanning(N // 8) win /= win.mean() - y = np.apply_along_axis(np.convolve, 0, eeg, win, mode='same') + y = np.apply_along_axis(np.convolve, 0, eeg, win, mode="same") y += rng.randn(*y.shape) * 100 for est in (Ridge(reg), reg): @@ -101,14 +128,15 @@ def test_time_delay(): ((-2, 0), 1), ((-2, -1), 1), ((-2, -1), 1), - ((0, .2), 10), - ((-.1, .1), 10)] + ((0, 0.2), 10), + ((-0.1, 0.1), 10), + ] for (tmin, tmax), isfreq in test_tlims: # sfreq must be int/float - with pytest.raises(TypeError, match='`sfreq` must be an instance of'): + with pytest.raises(TypeError, match="`sfreq` must be an instance of"): _delay_time_series(X, tmin, tmax, sfreq=[1]) # Delays must be int/float - with pytest.raises(TypeError, match='.*complex.*'): + with pytest.raises(TypeError, match=".*complex.*"): _delay_time_series(X, np.complex128(tmin), tmax, 1) # Make sure swapaxes works start, stop = int(round(tmin * isfreq)), int(round(tmax * isfreq)) + 1 @@ -128,34 +156,36 @@ def test_time_delay(): del_zero = int(round(-tmin * isfreq)) for ii in range(-2, 3): idx = del_zero + ii - err_msg = '[%s,%s] (%s): %s %s' % (tmin, tmax, isfreq, ii, idx) + err_msg = "[%s,%s] (%s): %s %s" % (tmin, tmax, isfreq, ii, idx) if 0 <= idx < X_delayed.shape[-1]: if ii == 0: - assert_array_equal(X_delayed[:, :, idx], X, - err_msg=err_msg) + assert_array_equal(X_delayed[:, :, idx], X, err_msg=err_msg) elif ii < 0: # negative delay - assert_array_equal(X_delayed[:ii, :, idx], X[-ii:, :], - err_msg=err_msg) - assert_array_equal(X_delayed[ii:, :, idx], 0.) + assert_array_equal( + X_delayed[:ii, :, idx], X[-ii:, :], err_msg=err_msg + ) + assert_array_equal(X_delayed[ii:, :, idx], 0.0) else: - assert_array_equal(X_delayed[ii:, :, idx], X[:-ii, :], - err_msg=err_msg) - assert_array_equal(X_delayed[:ii, :, idx], 0.) 
+ assert_array_equal( + X_delayed[ii:, :, idx], X[:-ii, :], err_msg=err_msg + ) + assert_array_equal(X_delayed[:ii, :, idx], 0.0) @pytest.mark.slowtest # slow on Azure -@pytest.mark.parametrize('n_jobs', n_jobs_test) +@pytest.mark.parametrize("n_jobs", n_jobs_test) @requires_sklearn def test_receptive_field_basic(n_jobs): """Test model prep and fitting.""" from sklearn.linear_model import Ridge + # Make sure estimator pulling works mod = Ridge() rng = np.random.RandomState(1337) # Test the receptive field model # Define parameters for the model and simulate inputs + weights - tmin, tmax = -10., 0 + tmin, tmax = -10.0, 0 n_feats = 3 rng = np.random.RandomState(0) X = rng.randn(10000, n_feats) @@ -163,82 +193,83 @@ def test_receptive_field_basic(n_jobs): # Delay inputs and cut off first 4 values since they'll be cut in the fit X_del = np.concatenate( - _delay_time_series(X, tmin, tmax, 1.).transpose(2, 0, 1), axis=1) + _delay_time_series(X, tmin, tmax, 1.0).transpose(2, 0, 1), axis=1 + ) y = np.dot(X_del, w) # Fit the model and test values - feature_names = ['feature_%i' % ii for ii in [0, 1, 2]] - rf = ReceptiveField(tmin, tmax, 1, feature_names, estimator=mod, - patterns=True) + feature_names = ["feature_%i" % ii for ii in [0, 1, 2]] + rf = ReceptiveField(tmin, tmax, 1, feature_names, estimator=mod, patterns=True) rf.fit(X, y) assert_array_equal(rf.delays_, np.arange(tmin, tmax + 1)) y_pred = rf.predict(X) assert_allclose(y[rf.valid_samples_], y_pred[rf.valid_samples_], atol=1e-2) scores = rf.score(X, y) - assert scores > .99 + assert scores > 0.99 assert_allclose(rf.coef_.T.ravel(), w, atol=1e-3) # Make sure different input shapes work - rf.fit(X[:, np.newaxis:], y[:, np.newaxis]) + rf.fit(X[:, np.newaxis :], y[:, np.newaxis]) rf.fit(X, y[:, np.newaxis]) - with pytest.raises(ValueError, match='If X has 3 .* y must have 2 or 3'): + with pytest.raises(ValueError, match="If X has 3 .* y must have 2 or 3"): rf.fit(X[..., np.newaxis], y) - with pytest.raises(ValueError, match='X must be shape'): + with pytest.raises(ValueError, match="X must be shape"): rf.fit(X[:, 0], y) - with pytest.raises(ValueError, match='X and y do not have the same n_epo'): - rf.fit(X[:, np.newaxis], np.tile(y[:, np.newaxis, np.newaxis], - [1, 2, 1])) - with pytest.raises(ValueError, match='X and y do not have the same n_tim'): + with pytest.raises(ValueError, match="X and y do not have the same n_epo"): + rf.fit(X[:, np.newaxis], np.tile(y[:, np.newaxis, np.newaxis], [1, 2, 1])) + with pytest.raises(ValueError, match="X and y do not have the same n_tim"): rf.fit(X, y[:-2]) - with pytest.raises(ValueError, match='n_features in X does not match'): + with pytest.raises(ValueError, match="n_features in X does not match"): rf.fit(X[:, :1], y) # auto-naming features - feature_names = ['feature_%s' % ii for ii in [0, 1, 2]] - rf = ReceptiveField(tmin, tmax, 1, estimator=mod, - feature_names=feature_names) + feature_names = ["feature_%s" % ii for ii in [0, 1, 2]] + rf = ReceptiveField(tmin, tmax, 1, estimator=mod, feature_names=feature_names) assert_equal(rf.feature_names, feature_names) rf = ReceptiveField(tmin, tmax, 1, estimator=mod) rf.fit(X, y) assert_equal(rf.feature_names, None) # Float becomes ridge - rf = ReceptiveField(tmin, tmax, 1, ['one', 'two', 'three'], estimator=0) + rf = ReceptiveField(tmin, tmax, 1, ["one", "two", "three"], estimator=0) str(rf) # repr works before fit rf.fit(X, y) assert isinstance(rf.estimator_, TimeDelayingRidge) str(rf) # repr works after fit - rf = ReceptiveField(tmin, tmax, 1, ['one'], 
estimator=0) + rf = ReceptiveField(tmin, tmax, 1, ["one"], estimator=0) rf.fit(X[:, [0]], y) str(rf) # repr with one feature # Should only accept estimators or floats - with pytest.raises(ValueError, match='`estimator` must be a float or'): - ReceptiveField(tmin, tmax, 1, estimator='foo').fit(X, y) - with pytest.raises(ValueError, match='`estimator` must be a float or'): + with pytest.raises(ValueError, match="`estimator` must be a float or"): + ReceptiveField(tmin, tmax, 1, estimator="foo").fit(X, y) + with pytest.raises(ValueError, match="`estimator` must be a float or"): ReceptiveField(tmin, tmax, 1, estimator=np.array([1, 2, 3])).fit(X, y) - with pytest.raises(ValueError, match='tmin .* must be at most tmax'): + with pytest.raises(ValueError, match="tmin .* must be at most tmax"): ReceptiveField(5, 4, 1).fit(X, y) # scorers for key, val in _SCORERS.items(): - rf = ReceptiveField(tmin, tmax, 1, ['one'], - estimator=0, scoring=key, patterns=True) + rf = ReceptiveField( + tmin, tmax, 1, ["one"], estimator=0, scoring=key, patterns=True + ) rf.fit(X[:, [0]], y) y_pred = rf.predict(X[:, [0]]).T.ravel()[:, np.newaxis] - assert_allclose(val(y[:, np.newaxis], y_pred, - multioutput='raw_values'), - rf.score(X[:, [0]], y), rtol=1e-2) - with pytest.raises(ValueError, match='inputs must be shape'): - _SCORERS['corrcoef'](y.ravel(), y_pred, multioutput='raw_values') + assert_allclose( + val(y[:, np.newaxis], y_pred, multioutput="raw_values"), + rf.score(X[:, [0]], y), + rtol=1e-2, + ) + with pytest.raises(ValueError, match="inputs must be shape"): + _SCORERS["corrcoef"](y.ravel(), y_pred, multioutput="raw_values") # Need correct scorers - with pytest.raises(ValueError, match='scoring must be one of'): - ReceptiveField(tmin, tmax, 1., scoring='foo').fit(X, y) + with pytest.raises(ValueError, match="scoring must be one of"): + ReceptiveField(tmin, tmax, 1.0, scoring="foo").fit(X, y) -@pytest.mark.parametrize('n_jobs', n_jobs_test) +@pytest.mark.parametrize("n_jobs", n_jobs_test) def test_time_delaying_fast_calc(n_jobs): """Test time delaying and fast calculations.""" X = np.array([[1, 2, 3], [5, 7, 11]]).T # all negative smin, smax = 1, 2 - X_del = _delay_time_series(X, smin, smax, 1.) + X_del = _delay_time_series(X, smin, smax, 1.0) # (n_times, n_features, n_delays) -> (n_times, n_features * n_delays) X_del.shape = (X.shape[0], -1) expected = np.array([[0, 1, 2], [0, 0, 1], [0, 5, 7], [0, 0, 5]]).T @@ -250,30 +281,32 @@ def test_time_delaying_fast_calc(n_jobs): assert_allclose(x_xt, expected) # all positive smin, smax = -2, -1 - X_del = _delay_time_series(X, smin, smax, 1.) + X_del = _delay_time_series(X, smin, smax, 1.0) X_del.shape = (X.shape[0], -1) expected = np.array([[3, 0, 0], [2, 3, 0], [11, 0, 0], [7, 11, 0]]).T assert_allclose(X_del, expected) Xt_X = np.dot(X_del.T, X_del) - expected = [[9, 6, 33, 21], [6, 13, 22, 47], - [33, 22, 121, 77], [21, 47, 77, 170]] + expected = [[9, 6, 33, 21], [6, 13, 22, 47], [33, 22, 121, 77], [21, 47, 77, 170]] assert_allclose(Xt_X, expected) x_xt = _compute_corrs(X, np.zeros((X.shape[0], 1)), smin, smax + 1)[0] assert_allclose(x_xt, expected) # both sides smin, smax = -1, 1 - X_del = _delay_time_series(X, smin, smax, 1.) 
+ X_del = _delay_time_series(X, smin, smax, 1.0) X_del.shape = (X.shape[0], -1) - expected = np.array([[2, 3, 0], [1, 2, 3], [0, 1, 2], - [7, 11, 0], [5, 7, 11], [0, 5, 7]]).T + expected = np.array( + [[2, 3, 0], [1, 2, 3], [0, 1, 2], [7, 11, 0], [5, 7, 11], [0, 5, 7]] + ).T assert_allclose(X_del, expected) Xt_X = np.dot(X_del.T, X_del) - expected = [[13, 8, 3, 47, 31, 15], - [8, 14, 8, 29, 52, 31], - [3, 8, 5, 11, 29, 19], - [47, 29, 11, 170, 112, 55], - [31, 52, 29, 112, 195, 112], - [15, 31, 19, 55, 112, 74]] + expected = [ + [13, 8, 3, 47, 31, 15], + [8, 14, 8, 29, 52, 31], + [3, 8, 5, 11, 29, 19], + [47, 29, 11, 170, 112, 55], + [31, 52, 29, 112, 195, 112], + [15, 31, 19, 55, 112, 74], + ] assert_allclose(Xt_X, expected) x_xt = _compute_corrs(X, np.zeros((X.shape[0], 1)), smin, smax + 1)[0] assert_allclose(x_xt, expected) @@ -281,10 +314,9 @@ def test_time_delaying_fast_calc(n_jobs): # slightly harder to get the non-Toeplitz correction correct X = np.array([[1, 2, 3, 5]]).T smin, smax = 0, 3 - X_del = _delay_time_series(X, smin, smax, 1.) + X_del = _delay_time_series(X, smin, smax, 1.0) X_del.shape = (X.shape[0], -1) - expected = np.array([[1, 2, 3, 5], [0, 1, 2, 3], - [0, 0, 1, 2], [0, 0, 0, 1]]).T + expected = np.array([[1, 2, 3, 5], [0, 1, 2, 3], [0, 0, 1, 2], [0, 0, 0, 1]]).T assert_allclose(X_del, expected) Xt_X = np.dot(X_del.T, X_del) expected = [[39, 23, 13, 5], [23, 14, 8, 3], [13, 8, 5, 2], [5, 3, 2, 1]] @@ -295,18 +327,23 @@ def test_time_delaying_fast_calc(n_jobs): # even worse X = np.array([[1, 2, 3], [5, 7, 11]]).T smin, smax = 0, 2 - X_del = _delay_time_series(X, smin, smax, 1.) + X_del = _delay_time_series(X, smin, smax, 1.0) X_del.shape = (X.shape[0], -1) - expected = np.array([[1, 2, 3], [0, 1, 2], [0, 0, 1], - [5, 7, 11], [0, 5, 7], [0, 0, 5]]).T + expected = np.array( + [[1, 2, 3], [0, 1, 2], [0, 0, 1], [5, 7, 11], [0, 5, 7], [0, 0, 5]] + ).T assert_allclose(X_del, expected) Xt_X = np.dot(X_del.T, X_del) - expected = np.array([[14, 8, 3, 52, 31, 15], - [8, 5, 2, 29, 19, 10], - [3, 2, 1, 11, 7, 5], - [52, 29, 11, 195, 112, 55], - [31, 19, 7, 112, 74, 35], - [15, 10, 5, 55, 35, 25]]) + expected = np.array( + [ + [14, 8, 3, 52, 31, 15], + [8, 5, 2, 29, 19, 10], + [3, 2, 1, 11, 7, 5], + [52, 29, 11, 195, 112, 55], + [31, 19, 7, 112, 74, 35], + [15, 10, 5, 55, 35, 25], + ] + ) assert_allclose(Xt_X, expected) x_xt = _compute_corrs(X, np.zeros((X.shape[0], 1)), smin, smax + 1)[0] assert_allclose(x_xt, expected) @@ -323,10 +360,10 @@ def test_time_delaying_fast_calc(n_jobs): for ii in range(X.shape[1]): kernel = rng.randn(smax - smin + 1) kernel -= np.mean(kernel) - y[:, ii % y.shape[-1]] = np.convolve(X[:, ii], kernel, 'same') + y[:, ii % y.shape[-1]] = np.convolve(X[:, ii], kernel, "same") x_xt, x_yt, n_ch_x, _, _ = _compute_corrs(X, y, smin, smax + 1) - X_del = _delay_time_series(X, smin, smax, 1., fill_mean=False) - x_yt_true = einsum('tfd,to->ofd', X_del, y) + X_del = _delay_time_series(X, smin, smax, 1.0, fill_mean=False) + x_yt_true = einsum("tfd,to->ofd", X_del, y) x_yt_true = np.reshape(x_yt_true, (x_yt_true.shape[0], -1)).T assert_allclose(x_yt, x_yt_true, atol=1e-7, err_msg=(smin, smax)) X_del.shape = (X.shape[0], -1) @@ -334,11 +371,12 @@ def test_time_delaying_fast_calc(n_jobs): assert_allclose(x_xt, x_xt_true, atol=1e-7, err_msg=(smin, smax)) -@pytest.mark.parametrize('n_jobs', n_jobs_test) +@pytest.mark.parametrize("n_jobs", n_jobs_test) @requires_sklearn def test_receptive_field_1d(n_jobs): """Test that the fast solving works like Ridge.""" from 
sklearn.linear_model import Ridge + rng = np.random.RandomState(0) x = rng.randn(500, 1) for delay in range(-2, 3): @@ -356,22 +394,26 @@ def test_receptive_field_1d(n_jobs): y.shape = (y.shape[0],) + (1,) * (ndim - 1) for slim in slims: smin, smax = slim - lap = TimeDelayingRidge(smin, smax, 1., 0.1, 'laplacian', - fit_intercept=False, n_jobs=n_jobs) - for estimator in (Ridge(alpha=0.), Ridge(alpha=0.1), 0., 0.1, - lap): + lap = TimeDelayingRidge( + smin, + smax, + 1.0, + 0.1, + "laplacian", + fit_intercept=False, + n_jobs=n_jobs, + ) + for estimator in (Ridge(alpha=0.0), Ridge(alpha=0.1), 0.0, 0.1, lap): for offset in (-100, 0, 100): - model = ReceptiveField(smin, smax, 1., - estimator=estimator, - n_jobs=n_jobs) + model = ReceptiveField( + smin, smax, 1.0, estimator=estimator, n_jobs=n_jobs + ) use_x = x + offset model.fit(use_x, y) if estimator is lap: continue # these checks are too stringent - assert_allclose(model.estimator_.intercept_, -offset, - atol=1e-1) - assert_array_equal(model.delays_, - np.arange(smin, smax + 1)) + assert_allclose(model.estimator_.intercept_, -offset, atol=1e-1) + assert_array_equal(model.delays_, np.arange(smin, smax + 1)) expected = (model.delays_ == delay).astype(float) expected = expected[np.newaxis] # features if y.ndim == 2: @@ -383,16 +425,19 @@ def test_receptive_field_1d(n_jobs): assert stop - start >= 495 assert_allclose( model.predict(use_x)[model.valid_samples_], - y[model.valid_samples_], atol=1e-2) + y[model.valid_samples_], + atol=1e-2, + ) score = np.mean(model.score(use_x, y)) assert score > 0.9999 -@pytest.mark.parametrize('n_jobs', n_jobs_test) +@pytest.mark.parametrize("n_jobs", n_jobs_test) @requires_sklearn def test_receptive_field_nd(n_jobs): """Test multidimensional support.""" from sklearn.linear_model import Ridge + # multidimensional rng = np.random.RandomState(3) x = rng.randn(1000, 3) @@ -407,55 +452,57 @@ def test_receptive_field_nd(n_jobs): x -= np.mean(x, axis=0) x_off = x + 1e3 expected = [ - [[0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 4, 0], - [0, 0, 2, 0, 0, 0]], - [[0, 0, 0, -3, 0, 0], - [0, -1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 4, 0], [0, 0, 2, 0, 0, 0]], + [[0, 0, 0, -3, 0, 0], [0, -1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]], ] - tdr_l = TimeDelayingRidge(smin, smax, 1., 0.1, 'laplacian', n_jobs=n_jobs) - tdr_nc = TimeDelayingRidge(smin, smax, 1., 0.1, n_jobs=n_jobs, - edge_correction=False) - for estimator, atol in zip((Ridge(alpha=0.), 0., 0.01, tdr_l, tdr_nc), - (1e-3, 1e-3, 1e-3, 5e-3, 5e-2)): - model = ReceptiveField(smin, smax, 1., - estimator=estimator) + tdr_l = TimeDelayingRidge(smin, smax, 1.0, 0.1, "laplacian", n_jobs=n_jobs) + tdr_nc = TimeDelayingRidge( + smin, smax, 1.0, 0.1, n_jobs=n_jobs, edge_correction=False + ) + for estimator, atol in zip( + (Ridge(alpha=0.0), 0.0, 0.01, tdr_l, tdr_nc), (1e-3, 1e-3, 1e-3, 5e-3, 5e-2) + ): + model = ReceptiveField(smin, smax, 1.0, estimator=estimator) model.fit(x, y) - assert_array_equal(model.delays_, - np.arange(smin, smax + 1)) + assert_array_equal(model.delays_, np.arange(smin, smax + 1)) assert_allclose(model.coef_, expected, atol=atol) - tdr = TimeDelayingRidge(smin, smax, 1., 0.01, reg_type='foo', - n_jobs=n_jobs) - model = ReceptiveField(smin, smax, 1., estimator=tdr) - with pytest.raises(ValueError, match='reg_type entries must be one of'): + tdr = TimeDelayingRidge(smin, smax, 1.0, 0.01, reg_type="foo", n_jobs=n_jobs) + model = ReceptiveField(smin, smax, 1.0, estimator=tdr) + with pytest.raises(ValueError, match="reg_type entries must 
be one of"): model.fit(x, y) - tdr = TimeDelayingRidge(smin, smax, 1., 0.01, reg_type=['laplacian'], - n_jobs=n_jobs) - model = ReceptiveField(smin, smax, 1., estimator=tdr) - with pytest.raises(ValueError, match='reg_type must have two elements'): + tdr = TimeDelayingRidge( + smin, smax, 1.0, 0.01, reg_type=["laplacian"], n_jobs=n_jobs + ) + model = ReceptiveField(smin, smax, 1.0, estimator=tdr) + with pytest.raises(ValueError, match="reg_type must have two elements"): model.fit(x, y) model = ReceptiveField(smin, smax, 1, estimator=tdr, fit_intercept=False) - with pytest.raises(ValueError, match='fit_intercept'): + with pytest.raises(ValueError, match="fit_intercept"): model.fit(x, y) # Now check the intercept_ - tdr = TimeDelayingRidge(smin, smax, 1., 0., n_jobs=n_jobs) - tdr_no = TimeDelayingRidge(smin, smax, 1., 0., fit_intercept=False, - n_jobs=n_jobs) - for estimator in (Ridge(alpha=0.), tdr, - Ridge(alpha=0., fit_intercept=False), tdr_no): + tdr = TimeDelayingRidge(smin, smax, 1.0, 0.0, n_jobs=n_jobs) + tdr_no = TimeDelayingRidge(smin, smax, 1.0, 0.0, fit_intercept=False, n_jobs=n_jobs) + for estimator in ( + Ridge(alpha=0.0), + tdr, + Ridge(alpha=0.0, fit_intercept=False), + tdr_no, + ): # first with no intercept in the data - model = ReceptiveField(smin, smax, 1., estimator=estimator) + model = ReceptiveField(smin, smax, 1.0, estimator=estimator) model.fit(x, y) - assert_allclose(model.estimator_.intercept_, 0., atol=1e-7, - err_msg=repr(estimator)) - assert_allclose(model.coef_, expected, atol=1e-3, - err_msg=repr(estimator)) + assert_allclose( + model.estimator_.intercept_, 0.0, atol=1e-7, err_msg=repr(estimator) + ) + assert_allclose(model.coef_, expected, atol=1e-3, err_msg=repr(estimator)) y_pred = model.predict(x) - assert_allclose(y_pred[model.valid_samples_], - y[model.valid_samples_], - atol=1e-2, err_msg=repr(estimator)) + assert_allclose( + y_pred[model.valid_samples_], + y[model.valid_samples_], + atol=1e-2, + err_msg=repr(estimator), + ) score = np.mean(model.score(x, y)) assert score > 0.9999 @@ -466,12 +513,14 @@ def test_receptive_field_nd(n_jobs): itol = 0.5 ctol = 5e-4 else: - val = itol = 0. - ctol = 2. 
- assert_allclose(model.estimator_.intercept_, val, atol=itol, - err_msg=repr(estimator)) - assert_allclose(model.coef_, expected, atol=ctol, rtol=ctol, - err_msg=repr(estimator)) + val = itol = 0.0 + ctol = 2.0 + assert_allclose( + model.estimator_.intercept_, val, atol=itol, err_msg=repr(estimator) + ) + assert_allclose( + model.coef_, expected, atol=ctol, rtol=ctol, err_msg=repr(estimator) + ) if estimator.fit_intercept: ptol = 1e-2 stol = 0.999999 @@ -479,13 +528,14 @@ def test_receptive_field_nd(n_jobs): ptol = 10 stol = 0.6 y_pred = model.predict(x_off)[model.valid_samples_] - assert_allclose(y_pred, y[model.valid_samples_], - atol=ptol, err_msg=repr(estimator)) + assert_allclose( + y_pred, y[model.valid_samples_], atol=ptol, err_msg=repr(estimator) + ) score = np.mean(model.score(x_off, y)) assert score > stol, estimator - model = ReceptiveField(smin, smax, 1., fit_intercept=False) + model = ReceptiveField(smin, smax, 1.0, fit_intercept=False) model.fit(x_off, y) - assert_allclose(model.estimator_.intercept_, 0., atol=1e-7) + assert_allclose(model.estimator_.intercept_, 0.0, atol=1e-7) score = np.mean(model.score(x_off, y)) assert score > 0.6 @@ -496,7 +546,8 @@ def _make_data(n_feats, n_targets, n_samples, tmin, tmax): w = rng.randn(int((tmax - tmin) + 1) * n_feats, n_targets) # Delay inputs X_del = np.concatenate( - _delay_time_series(X, tmin, tmax, 1.).transpose(2, 0, 1), axis=1) + _delay_time_series(X, tmin, tmax, 1.0).transpose(2, 0, 1), axis=1 + ) y = np.dot(X_del, w) return X, y @@ -506,25 +557,25 @@ def test_inverse_coef(): """Test inverse coefficients computation.""" from sklearn.linear_model import Ridge - tmin, tmax = 0., 10. + tmin, tmax = 0.0, 10.0 n_feats, n_targets, n_samples = 3, 2, 1000 n_delays = int((tmax - tmin) + 1) # Check coefficient dims, for all estimator types X, y = _make_data(n_feats, n_targets, n_samples, tmin, tmax) - tdr = TimeDelayingRidge(tmin, tmax, 1., 0.1, 'laplacian') - for estimator in (0., 0.01, Ridge(alpha=0.), tdr): - rf = ReceptiveField(tmin, tmax, 1., estimator=estimator, - patterns=True) + tdr = TimeDelayingRidge(tmin, tmax, 1.0, 0.1, "laplacian") + for estimator in (0.0, 0.01, Ridge(alpha=0.0), tdr): + rf = ReceptiveField(tmin, tmax, 1.0, estimator=estimator, patterns=True) rf.fit(X, y) - inv_rf = ReceptiveField(tmin, tmax, 1., estimator=estimator, - patterns=True) + inv_rf = ReceptiveField(tmin, tmax, 1.0, estimator=estimator, patterns=True) inv_rf.fit(y, X) - assert_array_equal(rf.coef_.shape, rf.patterns_.shape, - (n_targets, n_feats, n_delays)) - assert_array_equal(inv_rf.coef_.shape, inv_rf.patterns_.shape, - (n_feats, n_targets, n_delays)) + assert_array_equal( + rf.coef_.shape, rf.patterns_.shape, (n_targets, n_feats, n_delays) + ) + assert_array_equal( + inv_rf.coef_.shape, inv_rf.patterns_.shape, (n_feats, n_targets, n_delays) + ) # we should have np.dot(patterns.T,coef) ~ np.eye(n) c0 = rf.coef_.reshape(n_targets, n_feats * n_delays) @@ -536,10 +587,12 @@ def test_inverse_coef(): def test_linalg_warning(): """Test that warnings are issued when no regularization is applied.""" from sklearn.linear_model import Ridge + n_feats, n_targets, n_samples = 5, 60, 50 X, y = _make_data(n_feats, n_targets, n_samples, tmin, tmax) - for estimator in (0., Ridge(alpha=0.)): - rf = ReceptiveField(tmin, tmax, 1., estimator=estimator) - with pytest.warns((RuntimeWarning, UserWarning), - match='[Singular|scipy.linalg.solve]'): + for estimator in (0.0, Ridge(alpha=0.0)): + rf = ReceptiveField(tmin, tmax, 1.0, estimator=estimator) + with 
pytest.warns(
+            (RuntimeWarning, UserWarning), match="[Singular|scipy.linalg.solve]"
+        ):
             rf.fit(y, X)
diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py
index 1bc4f1e1e9a..a531d7b668e 100644
--- a/mne/decoding/tests/test_search_light.py
+++ b/mne/decoding/tests/test_search_light.py
@@ -31,25 +31,25 @@ def test_search_light():
     from sklearn.linear_model import Ridge, LogisticRegression
     from sklearn.pipeline import make_pipeline
     from sklearn.metrics import roc_auc_score, make_scorer
+
     with _record_warnings():  # NumPy module import
         from sklearn.ensemble import BaggingClassifier
     from sklearn.base import is_classifier

-    logreg = LogisticRegression(solver='liblinear', multi_class='ovr',
-                                random_state=0)
+    logreg = LogisticRegression(solver="liblinear", multi_class="ovr", random_state=0)

     X, y = make_data()
     n_epochs, _, n_time = X.shape
     # init
-    pytest.raises(ValueError, SlidingEstimator, 'foo')
+    pytest.raises(ValueError, SlidingEstimator, "foo")
     sl = SlidingEstimator(Ridge())
-    assert (not is_classifier(sl))
-    sl = SlidingEstimator(LogisticRegression(solver='liblinear'))
-    assert (is_classifier(sl))
+    assert not is_classifier(sl)
+    sl = SlidingEstimator(LogisticRegression(solver="liblinear"))
+    assert is_classifier(sl)
     # fit
-    assert_equal(sl.__repr__()[:18], '<SlidingEstimator(')
+    assert_equal(sl.__repr__()[:18], "<SlidingEstimator(")
     sl.fit(X, y)
-    assert_equal(sl.__repr__()[-28:], ', fitted with 10 estimators>')
+    assert_equal(sl.__repr__()[-28:], ", fitted with 10 estimators>")
     pytest.raises(ValueError, sl.fit, X[1:], y)
     pytest.raises(ValueError, sl.fit, X[:, :, 0], y)
     sl.fit(X, y, sample_weight=np.ones_like(y))
@@ -57,38 +57,37 @@ def test_search_light():
     # transforms
     pytest.raises(ValueError, sl.predict, X[:, :, :2])
     y_pred = sl.predict(X)
-    assert (y_pred.dtype == int)
+    assert y_pred.dtype == int
     assert_array_equal(y_pred.shape, [n_epochs, n_time])
     y_proba = sl.predict_proba(X)
-    assert (y_proba.dtype == float)
+    assert y_proba.dtype == float
     assert_array_equal(y_proba.shape, [n_epochs, n_time, 2])

     # score
     score = sl.score(X, y)
     assert_array_equal(score.shape, [n_time])
-    assert (np.sum(np.abs(score)) != 0)
-    assert (score.dtype == float)
+    assert np.sum(np.abs(score)) != 0
+    assert score.dtype == float

     sl = SlidingEstimator(logreg)
     assert_equal(sl.scoring, None)

     # Scoring method
-    for scoring in ['foo', 999]:
+    for scoring in ["foo", 999]:
         sl = SlidingEstimator(logreg, scoring=scoring)
         sl.fit(X, y)
         pytest.raises((ValueError, TypeError), sl.score, X, y)

     # Check sklearn's roc_auc fix: scikit-learn/scikit-learn#6874
     # -- 3 class problem
-    sl = SlidingEstimator(logreg, scoring='roc_auc')
+    sl = SlidingEstimator(logreg, scoring="roc_auc")
     y = np.arange(len(X)) % 3
     sl.fit(X, y)
-    with pytest.raises(ValueError, match='for two-class'):
+    with pytest.raises(ValueError, match="for two-class"):
         sl.score(X, y)
     # But check that valid ones should work with new enough sklearn
-    if 'multi_class' in signature(roc_auc_score).parameters:
-        scoring = make_scorer(
-            roc_auc_score, needs_proba=True, multi_class='ovo')
+    if "multi_class" in signature(roc_auc_score).parameters:
+        scoring = make_scorer(roc_auc_score, needs_proba=True, multi_class="ovo")
         sl = SlidingEstimator(logreg, scoring=scoring)
         sl.fit(X, y)
         sl.score(X, y)  # smoke test
@@ -97,8 +96,10 @@ def test_search_light():
     y = np.arange(len(X)) % 2 + 1
     sl.fit(X, y)
     score = sl.score(X, y)
-    assert_array_equal(score, [roc_auc_score(y - 1, _y_pred - 1)
-                               for _y_pred in sl.decision_function(X).T])
+    assert_array_equal(
+        score,
+        [roc_auc_score(y - 1, _y_pred - 1) for _y_pred in sl.decision_function(X).T],
+    )
     y = np.arange(len(X)) % 2

     # Cannot pass a metric as a scoring parameter
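
The hunks in this file repeat the transformations black applies across the whole patch: quote normalization to double quotes, completion of bare float literals (2. becomes 2.0), removal of redundant parentheses around assert expressions, spaces around ** only when an operand is more than a simple name or literal, spaces inside slices whose bounds are expressions ([: self.n_components]), and reflowing of over-long calls. An illustrative before/after pair (an invented example, not taken from this diff):

    # before
    result = compute_things(alpha=0.1, beta='auto', gamma=n_samples ** 0.5,
                            delta=values[:n_keep])

    # after black
    result = compute_things(
        alpha=0.1, beta="auto", gamma=n_samples**0.5, delta=values[:n_keep]
    )

When even the single indented argument line exceeds 88 columns, black splits to one argument per line and adds the magic trailing comma, which is the form seen in most hunks here.
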
     # Cannot pass a metric as a scoring parameter
@@ -107,22 +108,23 @@ def test_search_light():
     pytest.raises(ValueError, sl1.score, X, y)
 
     # Now use string as scoring
-    sl1 = SlidingEstimator(logreg, scoring='roc_auc')
+    sl1 = SlidingEstimator(logreg, scoring="roc_auc")
     sl1.fit(X, y)
     rng = np.random.RandomState(0)
     X = rng.randn(*X.shape)  # randomize X to avoid AUCs in [0, 1]
     score_sl = sl1.score(X, y)
     assert_array_equal(score_sl.shape, [n_time])
-    assert (score_sl.dtype == float)
+    assert score_sl.dtype == float
 
     # Check that scoring was applied adequately
     scoring = make_scorer(roc_auc_score, needs_threshold=True)
-    score_manual = [scoring(est, x, y) for est, x in zip(
-        sl1.estimators_, X.transpose(2, 0, 1))]
+    score_manual = [
+        scoring(est, x, y) for est, x in zip(sl1.estimators_, X.transpose(2, 0, 1))
+    ]
     assert_array_equal(score_manual, score_sl)
 
     # n_jobs
-    sl = SlidingEstimator(logreg, n_jobs=None, scoring='roc_auc')
+    sl = SlidingEstimator(logreg, n_jobs=None, scoring="roc_auc")
     score_1job = sl.fit(X, y).score(X, y)
     sl.n_jobs = 2
     score_njobs = sl.fit(X, y).score(X, y)
@@ -139,10 +141,9 @@ def transform(self, X):
             return super(_LogRegTransformer, self).predict_proba(X)[..., 1]
 
     logreg_transformer = _LogRegTransformer(
-        random_state=0, multi_class='ovr', solver='liblinear'
+        random_state=0, multi_class="ovr", solver="liblinear"
     )
-    pipe = make_pipeline(SlidingEstimator(logreg_transformer),
-                         logreg)
+    pipe = make_pipeline(SlidingEstimator(logreg_transformer), logreg)
     pipe.fit(X, y)
     pipe.predict(X)
 
@@ -151,8 +152,7 @@ def transform(self, X):
     y = np.arange(10) % 2
     y_preds = list()
     for n_jobs in [1, 2]:
-        pipe = SlidingEstimator(
-            make_pipeline(Vectorizer(), logreg), n_jobs=n_jobs)
+        pipe = SlidingEstimator(make_pipeline(Vectorizer(), logreg), n_jobs=n_jobs)
         y_preds.append(pipe.fit(X, y).predict(X))
         features_shape = pipe.estimators_[0].steps[0][1].features_shape_
         assert_array_equal(features_shape, [3, 4])
@@ -164,7 +164,7 @@ def transform(self, X):
         pipe = SlidingEstimator(BaggingClassifier(None, 2), n_jobs=n_jobs)
         pipe.fit(X, y)
         pipe.score(X, y)
-        assert (isinstance(pipe.estimators_[0], BaggingClassifier))
+        assert isinstance(pipe.estimators_[0], BaggingClassifier)
 
 
 @requires_sklearn
@@ -174,24 +174,23 @@ def test_generalization_light():
     from sklearn.linear_model import LogisticRegression
     from sklearn.metrics import roc_auc_score
 
-    logreg = LogisticRegression(solver='liblinear', multi_class='ovr',
-                                random_state=0)
+    logreg = LogisticRegression(solver="liblinear", multi_class="ovr", random_state=0)
 
     X, y = make_data()
     n_epochs, _, n_time = X.shape
     # fit
     gl = GeneralizingEstimator(logreg)
-    assert_equal(repr(gl)[:23], '<GeneralizingEstimator(')
+    assert_equal(repr(gl)[:23], "<GeneralizingEstimator(")
     gl.fit(X, y)
-    assert_equal(gl.__repr__()[-28:], ', fitted with 10 estimators>')
+    assert_equal(gl.__repr__()[-28:], ", fitted with 10 estimators>")
 
     # transforms
     y_pred = gl.predict(X)
     assert_array_equal(y_pred.shape, [n_epochs, n_time, n_time])
-    assert (y_pred.dtype == int)
+    assert y_pred.dtype == int
     y_proba = gl.predict_proba(X)
-    assert (y_proba.dtype == float)
+    assert y_proba.dtype == float
     assert_array_equal(y_proba.shape, [n_epochs, n_time, n_time, 2])
 
     # transform to different datasize
@@ -201,23 +200,23 @@ def test_generalization_light():
     # score
     score = gl.score(X[:, :, :3], y)
     assert_array_equal(score.shape, [n_time, 3])
-    assert (np.sum(np.abs(score)) != 0)
-    assert (score.dtype == float)
+    assert np.sum(np.abs(score)) != 0
+    assert score.dtype == float
 
-    gl = GeneralizingEstimator(logreg, scoring='roc_auc')
+    gl = GeneralizingEstimator(logreg, scoring="roc_auc")
     gl.fit(X, y)
     score = gl.score(X, y)
     auc = roc_auc_score(y, gl.estimators_[0].predict_proba(X[..., 
1]) assert_equal(score[0, 0], auc) - for scoring in ['foo', 999]: + for scoring in ["foo", 999]: gl = GeneralizingEstimator(logreg, scoring=scoring) gl.fit(X, y) pytest.raises((ValueError, TypeError), gl.score, X, y) # Check sklearn's roc_auc fix: scikit-learn/scikit-learn#6874 # -- 3 class problem - gl = GeneralizingEstimator(logreg, scoring='roc_auc') + gl = GeneralizingEstimator(logreg, scoring="roc_auc") y = np.arange(len(X)) % 3 gl.fit(X, y) pytest.raises(ValueError, gl.score, X, y) @@ -225,8 +224,10 @@ def test_generalization_light(): y = np.arange(len(X)) % 2 + 1 gl.fit(X, y) score = gl.score(X, y) - manual_score = [[roc_auc_score(y - 1, _y_pred) for _y_pred in _y_preds] - for _y_preds in gl.decision_function(X).transpose(1, 2, 0)] + manual_score = [ + [roc_auc_score(y - 1, _y_pred) for _y_pred in _y_preds] + for _y_preds in gl.decision_function(X).transpose(1, 2, 0) + ] assert_array_equal(score, manual_score) # n_jobs @@ -246,8 +247,7 @@ def test_generalization_light(): y = np.arange(10) % 2 y_preds = list() for n_jobs in [1, 2]: - pipe = GeneralizingEstimator( - make_pipeline(Vectorizer(), logreg), n_jobs=n_jobs) + pipe = GeneralizingEstimator(make_pipeline(Vectorizer(), logreg), n_jobs=n_jobs) y_preds.append(pipe.fit(X, y).predict(X)) features_shape = pipe.estimators_[0].steps[0][1].features_shape_ assert_array_equal(features_shape, [3, 4]) @@ -255,8 +255,9 @@ def test_generalization_light(): @requires_sklearn -@pytest.mark.parametrize('n_jobs, verbose', - [(1, False), (2, False), (1, True), (2, 'info')]) +@pytest.mark.parametrize( + "n_jobs, verbose", [(1, False), (2, False), (1, True), (2, "info")] +) def test_verbose_arg(capsys, n_jobs, verbose): """Test controlling output with the ``verbose`` argument.""" from sklearn.svm import SVC @@ -267,15 +268,14 @@ def test_verbose_arg(capsys, n_jobs, verbose): # shows progress bar and prints other messages to the console with use_log_level(True): for estimator_object in [SlidingEstimator, GeneralizingEstimator]: - estimator = estimator_object( - clf, n_jobs=n_jobs, verbose=verbose) + estimator = estimator_object(clf, n_jobs=n_jobs, verbose=verbose) estimator = estimator.fit(X, y) estimator.score(X, y) estimator.predict(X) stdout, stderr = capsys.readouterr() if isinstance(verbose, bool) and not verbose: - assert all(channel == '' for channel in (stdout, stderr)) + assert all(channel == "" for channel in (stdout, stderr)) else: assert any(len(channel) > 0 for channel in (stdout, stderr)) @@ -287,6 +287,7 @@ def test_cross_val_predict(): from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn.base import BaseEstimator, clone from sklearn.model_selection import cross_val_predict + rng = np.random.RandomState(42) X = rng.randn(10, 1, 3) y = rng.randint(0, 2, 10) @@ -309,7 +310,7 @@ def predict_proba(self, X): with pytest.raises(AttributeError, match="classes_ attribute"): estimator = SlidingEstimator(Classifier()) - cross_val_predict(estimator, X, y, method='predict_proba', cv=2) + cross_val_predict(estimator, X, y, method="predict_proba", cv=2) estimator = SlidingEstimator(LinearDiscriminantAnalysis()) - cross_val_predict(estimator, X, y, method='predict_proba', cv=2) + cross_val_predict(estimator, X, y, method="predict_proba", cv=2) diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 8ba7657b660..4f674242fd8 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from numpy.testing import 
(assert_array_almost_equal, assert_array_equal)
+from numpy.testing import assert_array_almost_equal, assert_array_equal
 from mne import io
 from mne.time_frequency import psd_array_welch
 from mne.decoding.ssd import SSD
@@ -18,9 +18,16 @@
 freqs_noise = 8, 13
 
 
-def simulate_data(freqs_sig=[9, 12], n_trials=100, n_channels=20,
-                  n_samples=500, samples_per_second=250,
-                  n_components=5, SNR=0.05, random_state=42):
+def simulate_data(
+    freqs_sig=[9, 12],
+    n_trials=100,
+    n_channels=20,
+    n_samples=500,
+    samples_per_second=250,
+    n_components=5,
+    SNR=0.05,
+    random_state=42,
+):
     """Simulate data according to an instantaneous mixing model.
 
     Data are simulated in the statistical source space, where n=n_components
@@ -28,9 +35,13 @@ def simulate_data(freqs_sig=[9, 12], n_trials=100, n_channels=20,
     """
     rng = np.random.RandomState(random_state)
 
-    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
-                              l_trans_bandwidth=1, h_trans_bandwidth=1,
-                              fir_design='firwin')
+    filt_params_signal = dict(
+        l_freq=freqs_sig[0],
+        h_freq=freqs_sig[1],
+        l_trans_bandwidth=1,
+        h_trans_bandwidth=1,
+        fir_design="firwin",
+    )
 
     # generate an orthogonal mixing matrix
     mixing_mat = np.linalg.svd(rng.randn(n_channels, n_channels))[0]
@@ -44,8 +55,8 @@ def simulate_data(freqs_sig=[9, 12], n_trials=100, n_channels=20,
     X_s = np.dot(mixing_mat[:, :n_components], S_s.T).T
     X_n = np.dot(mixing_mat[:, n_components:], S_n.T).T
     # add noise
-    X_s = X_s / np.linalg.norm(X_s, 'fro')
-    X_n = X_n / np.linalg.norm(X_n, 'fro')
+    X_s = X_s / np.linalg.norm(X_s, "fro")
+    X_n = X_n / np.linalg.norm(X_n, "fro")
     X = SNR * X_s + (1 - SNR) * X_n
     X = X.T
     S = S.T
@@ -58,75 +69,98 @@ def test_ssd():
     X, A, S = simulate_data()
     sf = 250
     n_channels = X.shape[0]
-    info = create_info(ch_names=n_channels, sfreq=sf, ch_types='eeg')
+    info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg")
     n_components_true = 5
 
     # Init
-    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
-                              l_trans_bandwidth=1, h_trans_bandwidth=1)
-    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
-                             l_trans_bandwidth=1, h_trans_bandwidth=1)
+    filt_params_signal = dict(
+        l_freq=freqs_sig[0],
+        h_freq=freqs_sig[1],
+        l_trans_bandwidth=1,
+        h_trans_bandwidth=1,
+    )
+    filt_params_noise = dict(
+        l_freq=freqs_noise[0],
+        h_freq=freqs_noise[1],
+        l_trans_bandwidth=1,
+        h_trans_bandwidth=1,
+    )
     ssd = SSD(info, filt_params_signal, filt_params_noise)
     # freq not an int
-    freq = 'foo'
-    filt_params_signal = dict(l_freq=freq, h_freq=freqs_sig[1],
-                              l_trans_bandwidth=1, h_trans_bandwidth=1)
-    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
-                             l_trans_bandwidth=1, h_trans_bandwidth=1)
-    with pytest.raises(TypeError, match='must be an instance '):
+    freq = "foo"
+    filt_params_signal = dict(
+        l_freq=freq, h_freq=freqs_sig[1], l_trans_bandwidth=1, h_trans_bandwidth=1
+    )
+    filt_params_noise = dict(
+        l_freq=freqs_noise[0],
+        h_freq=freqs_noise[1],
+        l_trans_bandwidth=1,
+        h_trans_bandwidth=1,
+    )
+    with pytest.raises(TypeError, match="must be an instance "):
         ssd = SSD(info, filt_params_signal, filt_params_noise)
 
     # Wrongly specified noise band
     freq = 2
-    filt_params_signal = dict(l_freq=freq, h_freq=freqs_sig[1],
-                              l_trans_bandwidth=1, h_trans_bandwidth=1)
-    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
-                             l_trans_bandwidth=1, h_trans_bandwidth=1)
-    with pytest.raises(ValueError, match='Wrongly specified '):
+    filt_params_signal = dict(
+        l_freq=freq, h_freq=freqs_sig[1], l_trans_bandwidth=1, 
h_trans_bandwidth=1 + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + with pytest.raises(ValueError, match="Wrongly specified "): ssd = SSD(info, filt_params_signal, filt_params_noise) # filt param no dict filt_params_signal = freqs_sig filt_params_noise = freqs_noise - with pytest.raises(ValueError, match='must be defined'): + with pytest.raises(ValueError, match="must be defined"): ssd = SSD(info, filt_params_signal, filt_params_noise) # Data type - filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1], - l_trans_bandwidth=1, h_trans_bandwidth=1) - filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1], - l_trans_bandwidth=1, h_trans_bandwidth=1) + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) ssd = SSD(info, filt_params_signal, filt_params_noise) raw = io.RawArray(X, info) pytest.raises(TypeError, ssd.fit, raw) # check non-boolean return_filtered - with pytest.raises(ValueError, match='return_filtered'): - ssd = SSD(info, filt_params_signal, filt_params_noise, - return_filtered=0) + with pytest.raises(ValueError, match="return_filtered"): + ssd = SSD(info, filt_params_signal, filt_params_noise, return_filtered=0) # check non-boolean sort_by_spectral_ratio - with pytest.raises(ValueError, match='sort_by_spectral_ratio'): - ssd = SSD(info, filt_params_signal, filt_params_noise, - sort_by_spectral_ratio=0) + with pytest.raises(ValueError, match="sort_by_spectral_ratio"): + ssd = SSD(info, filt_params_signal, filt_params_noise, sort_by_spectral_ratio=0) # More than 1 channel type - ch_types = np.reshape([['mag'] * 10, ['eeg'] * 10], n_channels) + ch_types = np.reshape([["mag"] * 10, ["eeg"] * 10], n_channels) info_2 = create_info(ch_names=n_channels, sfreq=sf, ch_types=ch_types) - with pytest.raises(ValueError, match='At this point SSD'): + with pytest.raises(ValueError, match="At this point SSD"): ssd = SSD(info_2, filt_params_signal, filt_params_noise) # Number of channels - info_3 = create_info(ch_names=n_channels + 1, sfreq=sf, ch_types='eeg') + info_3 = create_info(ch_names=n_channels + 1, sfreq=sf, ch_types="eeg") ssd = SSD(info_3, filt_params_signal, filt_params_noise) pytest.raises(ValueError, ssd.fit, X) # Fit n_components = 10 - ssd = SSD(info, filt_params_signal, filt_params_noise, - n_components=n_components) + ssd = SSD(info, filt_params_signal, filt_params_noise, n_components=n_components) # Call transform before fit pytest.raises(AttributeError, ssd.transform, X) @@ -134,28 +168,43 @@ def test_ssd(): # Check outputs ssd.fit(X) - assert (ssd.filters_.shape == (n_channels, n_channels)) - assert (ssd.patterns_.shape == (n_channels, n_channels)) + assert ssd.filters_.shape == (n_channels, n_channels) + assert ssd.patterns_.shape == (n_channels, n_channels) # Transform X_ssd = ssd.fit_transform(X) - assert (X_ssd.shape[0] == n_components) + assert X_ssd.shape[0] == n_components # back and forward - ssd = SSD(info, filt_params_signal, filt_params_noise, - n_components=None, sort_by_spectral_ratio=False) + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + n_components=None, + sort_by_spectral_ratio=False, + ) ssd.fit(X) X_denoised = ssd.apply(X) assert_array_almost_equal(X_denoised, X) # denoised by low-rank-factorization - ssd = SSD(info, 
filt_params_signal, filt_params_noise,
-              n_components=n_components, sort_by_spectral_ratio=True)
+    ssd = SSD(
+        info,
+        filt_params_signal,
+        filt_params_noise,
+        n_components=n_components,
+        sort_by_spectral_ratio=True,
+    )
     ssd.fit(X)
     X_denoised = ssd.apply(X)
-    assert (np.linalg.matrix_rank(X_denoised) == n_components)
+    assert np.linalg.matrix_rank(X_denoised) == n_components
 
     # Power ratio ordering
-    ssd = SSD(info, filt_params_signal, filt_params_noise,
-              n_components=None, sort_by_spectral_ratio=False)
+    ssd = SSD(
+        info,
+        filt_params_signal,
+        filt_params_noise,
+        n_components=None,
+        sort_by_spectral_ratio=False,
+    )
     ssd.fit(X)
     spec_ratio, sorter_spec = ssd.get_spectral_ratio(ssd.transform(X))
     # since we know that the number of true components is 5, the relative
@@ -165,12 +214,25 @@ def test_ssd():
     # Check detected peaks
     # fit ssd
     n_components = n_components_true
-    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
-                              l_trans_bandwidth=1, h_trans_bandwidth=1)
-    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
-                             l_trans_bandwidth=1, h_trans_bandwidth=1)
-    ssd = SSD(info, filt_params_signal, filt_params_noise,
-              n_components=n_components, sort_by_spectral_ratio=False)
+    filt_params_signal = dict(
+        l_freq=freqs_sig[0],
+        h_freq=freqs_sig[1],
+        l_trans_bandwidth=1,
+        h_trans_bandwidth=1,
+    )
+    filt_params_noise = dict(
+        l_freq=freqs_noise[0],
+        h_freq=freqs_noise[1],
+        l_trans_bandwidth=1,
+        h_trans_bandwidth=1,
+    )
+    ssd = SSD(
+        info,
+        filt_params_signal,
+        filt_params_noise,
+        n_components=n_components,
+        sort_by_spectral_ratio=False,
+    )
     ssd.fit(X)
 
     out = ssd.transform(X)
@@ -197,7 +259,7 @@ def test_ssd_epoched_data():
     X, A, S = simulate_data(n_trials=100, n_channels=20, n_samples=500)
     sf = 250
     n_channels = X.shape[0]
-    info = create_info(ch_names=n_channels, sfreq=sf, ch_types='eeg')
+    info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg")
     n_components_true = 5
 
     # Build epochs as sliding windows over the continuous raw file
@@ -206,10 +268,18 @@ def test_ssd_epoched_data():
     X_e = np.reshape(X, (100, 20, 500))
 
     # Fit
-    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
-                              l_trans_bandwidth=4, h_trans_bandwidth=4)
-    filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1],
-                             l_trans_bandwidth=4, h_trans_bandwidth=4)
+    filt_params_signal = dict(
+        l_freq=freqs_sig[0],
+        h_freq=freqs_sig[1],
+        l_trans_bandwidth=4,
+        h_trans_bandwidth=4,
+    )
+    filt_params_noise = dict(
+        l_freq=freqs_noise[0],
+        h_freq=freqs_noise[1],
+        l_trans_bandwidth=4,
+        h_trans_bandwidth=4,
+    )
 
     # ssd on epochs
     ssd_e = SSD(info, filt_params_signal, filt_params_noise)
@@ -221,34 +291,44 @@ def test_ssd_epoched_data():
     # Check if the first 5 components are the same for both
     _, sorter_spec_e = ssd_e.get_spectral_ratio(ssd_e.transform(X_e))
     _, sorter_spec = ssd.get_spectral_ratio(ssd.transform(X))
-    assert_array_equal(sorter_spec_e[:n_components_true],
-                       sorter_spec[:n_components_true])
+    assert_array_equal(
+        sorter_spec_e[:n_components_true], sorter_spec[:n_components_true]
+    )
 
 
 @requires_sklearn
 def test_ssd_pipeline():
     """Test if SSD works in a pipeline."""
     from sklearn.pipeline import Pipeline
+
     sf = 250
     X, A, S = simulate_data(n_trials=100, n_channels=20, n_samples=500)
     X_e = np.reshape(X, (100, 20, 500))
     # define binary random output
     y = np.random.randint(2, size=100)
 
-    info = create_info(ch_names=20, sfreq=sf, ch_types='eeg')
-
-    filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1],
-                              l_trans_bandwidth=4, 
h_trans_bandwidth=4) - filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1], - l_trans_bandwidth=4, h_trans_bandwidth=4) + info = create_info(ch_names=20, sfreq=sf, ch_types="eeg") + + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=4, + h_trans_bandwidth=4, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=4, + h_trans_bandwidth=4, + ) ssd = SSD(info, filt_params_signal, filt_params_noise) csp = CSP() - pipe = Pipeline([('SSD', ssd), ('CSP', csp)]) + pipe = Pipeline([("SSD", ssd), ("CSP", csp)]) pipe.set_params(SSD__n_components=5) pipe.set_params(CSP__n_components=2) out = pipe.fit_transform(X_e, y) - assert (out.shape == (100, 2)) - assert (pipe.get_params()['SSD__n_components'] == 5) + assert out.shape == (100, 2) + assert pipe.get_params()["SSD__n_components"] == 5 def test_sorting(): @@ -260,30 +340,53 @@ def test_sorting(): Xtr, Xte = X[:80], X[80:] sf = 250 n_channels = Xtr.shape[1] - info = create_info(ch_names=n_channels, sfreq=sf, ch_types='eeg') - - filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1], - l_trans_bandwidth=4, h_trans_bandwidth=4) - filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1], - l_trans_bandwidth=4, h_trans_bandwidth=4) + info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg") + + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=4, + h_trans_bandwidth=4, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=4, + h_trans_bandwidth=4, + ) # check sort_by_spectral_ratio set to False - ssd = SSD(info, filt_params_signal, filt_params_noise, - n_components=None, sort_by_spectral_ratio=False) + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + n_components=None, + sort_by_spectral_ratio=False, + ) ssd.fit(Xtr) _, sorter_tr = ssd.get_spectral_ratio(ssd.transform(Xtr)) _, sorter_te = ssd.get_spectral_ratio(ssd.transform(Xte)) assert any(sorter_tr != sorter_te) # check sort_by_spectral_ratio set to True - ssd = SSD(info, filt_params_signal, filt_params_noise, - n_components=None, sort_by_spectral_ratio=True) + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + n_components=None, + sort_by_spectral_ratio=True, + ) ssd.fit(Xtr) # check sorters sorter_in = ssd.sorter_spec - ssd = SSD(info, filt_params_signal, filt_params_noise, - n_components=None, sort_by_spectral_ratio=False) + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + n_components=None, + sort_by_spectral_ratio=False, + ) ssd.fit(Xtr) _, sorter_out = ssd.get_spectral_ratio(ssd.transform(Xtr)) @@ -297,44 +400,70 @@ def test_return_filtered(): X, _, _ = simulate_data(SNR=0.9, freqs_sig=[4, 13]) sf = 250 n_channels = X.shape[0] - info = create_info(ch_names=n_channels, sfreq=sf, ch_types='eeg') - - filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1], - l_trans_bandwidth=1, h_trans_bandwidth=1) - filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1], - l_trans_bandwidth=1, h_trans_bandwidth=1) + info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg") + + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) # return filtered to true - ssd = SSD(info, filt_params_signal, filt_params_noise, - 
sort_by_spectral_ratio=False, return_filtered=True) + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + sort_by_spectral_ratio=False, + return_filtered=True, + ) ssd.fit(X) out = ssd.transform(X) psd_out, freqs = psd_array_welch(out[0], sfreq=250, n_fft=250) freqs_up = int(freqs[psd_out > 0.5][0]), int(freqs[psd_out > 0.5][-1]) - assert (freqs_up == freqs_sig) + assert freqs_up == freqs_sig # return filtered to false - ssd = SSD(info, filt_params_signal, filt_params_noise, - sort_by_spectral_ratio=False, return_filtered=False) + ssd = SSD( + info, + filt_params_signal, + filt_params_noise, + sort_by_spectral_ratio=False, + return_filtered=False, + ) ssd.fit(X) out = ssd.transform(X) psd_out, freqs = psd_array_welch(out[0], sfreq=250, n_fft=250) freqs_up = int(freqs[psd_out > 0.5][0]), int(freqs[psd_out > 0.5][-1]) - assert (freqs_up != freqs_sig) + assert freqs_up != freqs_sig def test_non_full_rank_data(): """Test that the method works with non-full rank data.""" n_channels = 10 X, _, _ = simulate_data(SNR=0.9, freqs_sig=[4, 13], n_channels=n_channels) - info = create_info(ch_names=n_channels, sfreq=250, ch_types='eeg') - - filt_params_signal = dict(l_freq=freqs_sig[0], h_freq=freqs_sig[1], - l_trans_bandwidth=1, h_trans_bandwidth=1) - filt_params_noise = dict(l_freq=freqs_noise[0], h_freq=freqs_noise[1], - l_trans_bandwidth=1, h_trans_bandwidth=1) + info = create_info(ch_names=n_channels, sfreq=250, ch_types="eeg") + + filt_params_signal = dict( + l_freq=freqs_sig[0], + h_freq=freqs_sig[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) + filt_params_noise = dict( + l_freq=freqs_noise[0], + h_freq=freqs_noise[1], + l_trans_bandwidth=1, + h_trans_bandwidth=1, + ) # Make data non-full rank rank = 5 diff --git a/mne/decoding/tests/test_time_frequency.py b/mne/decoding/tests/test_time_frequency.py index 5fea1402e68..8d92c8fd72e 100644 --- a/mne/decoding/tests/test_time_frequency.py +++ b/mne/decoding/tests/test_time_frequency.py @@ -15,17 +15,18 @@ def test_timefrequency(): """Test TimeFrequency.""" from sklearn.base import clone + # Init n_freqs = 3 freqs = [20, 21, 22] tf = TimeFrequency(freqs, sfreq=100) - for output in ['avg_power', 'foo', None]: + for output in ["avg_power", "foo", None]: pytest.raises(ValueError, TimeFrequency, freqs, output=output) tf = clone(tf) # Clone estimator freqs_array = np.array(np.asarray(freqs)) - tf = TimeFrequency(freqs_array, 100, "morlet", freqs_array / 5.) 
+ tf = TimeFrequency(freqs_array, 100, "morlet", freqs_array / 5.0) clone(tf) # Fit diff --git a/mne/decoding/tests/test_transformer.py b/mne/decoding/tests/test_transformer.py index 3c53d7e2ca1..1884f926862 100644 --- a/mne/decoding/tests/test_transformer.py +++ b/mne/decoding/tests/test_transformer.py @@ -8,12 +8,22 @@ import numpy as np import pytest -from numpy.testing import (assert_array_equal, assert_array_almost_equal, - assert_allclose, assert_equal) +from numpy.testing import ( + assert_array_equal, + assert_array_almost_equal, + assert_allclose, + assert_equal, +) from mne import io, read_events, Epochs, pick_types -from mne.decoding import (Scaler, FilterEstimator, PSDEstimator, Vectorizer, - UnsupervisedSpatialFilter, TemporalFilter) +from mne.decoding import ( + Scaler, + FilterEstimator, + PSDEstimator, + Vectorizer, + UnsupervisedSpatialFilter, + TemporalFilter, +) from mne.defaults import DEFAULTS from mne.utils import requires_sklearn, check_version, use_log_level @@ -25,29 +35,34 @@ event_name = data_dir / "test-eve.fif" -@pytest.mark.parametrize('info, method', [ - (True, None), - (True, dict(mag=5, grad=10, eeg=20)), - (False, 'mean'), - (False, 'median'), -]) +@pytest.mark.parametrize( + "info, method", + [ + (True, None), + (True, dict(mag=5, grad=10, eeg=20)), + (False, "mean"), + (False, "median"), + ], +) def test_scaler(info, method): """Test methods of Scaler.""" raw = io.read_raw_fif(raw_fname) events = read_events(event_name) - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[1:13:3] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True) + epochs = Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True + ) epochs_data = epochs.get_data() y = epochs.events[:, -1] epochs_data_t = epochs_data.transpose([1, 0, 2]) - if method in ('mean', 'median'): - if not check_version('sklearn'): - with pytest.raises(ImportError, match='No module'): + if method in ("mean", "median"): + if not check_version("sklearn"): + with pytest.raises(ImportError, match="No module"): Scaler(info, method) return @@ -57,22 +72,28 @@ def test_scaler(info, method): X = scaler.fit_transform(epochs_data, y) assert_equal(X.shape, epochs_data.shape) if method is None or isinstance(method, dict): - sd = DEFAULTS['scalings'] if method is None else method + sd = DEFAULTS["scalings"] if method is None else method stds = np.zeros(len(picks)) - for key in ('mag', 'grad'): - stds[pick_types(epochs.info, meg=key)] = 1. / sd[key] - stds[pick_types(epochs.info, meg=False, eeg=True)] = 1. 
/ sd['eeg'] + for key in ("mag", "grad"): + stds[pick_types(epochs.info, meg=key)] = 1.0 / sd[key] + stds[pick_types(epochs.info, meg=False, eeg=True)] = 1.0 / sd["eeg"] means = np.zeros(len(epochs.ch_names)) - elif method == 'mean': + elif method == "mean": stds = np.array([np.std(ch_data) for ch_data in epochs_data_t]) means = np.array([np.mean(ch_data) for ch_data in epochs_data_t]) else: # median - percs = np.array([np.percentile(ch_data, [25, 50, 75]) - for ch_data in epochs_data_t]) + percs = np.array( + [np.percentile(ch_data, [25, 50, 75]) for ch_data in epochs_data_t] + ) stds = percs[:, 2] - percs[:, 0] means = percs[:, 1] - assert_allclose(X * stds[:, np.newaxis] + means[:, np.newaxis], - epochs_data, rtol=1e-12, atol=1e-20, err_msg=method) + assert_allclose( + X * stds[:, np.newaxis] + means[:, np.newaxis], + epochs_data, + rtol=1e-12, + atol=1e-20, + err_msg=method, + ) X2 = scaler.fit(epochs_data, y).transform(epochs_data) assert_array_equal(X, X2) @@ -85,8 +106,15 @@ def test_scaler(info, method): pytest.raises(ValueError, Scaler, None, None) pytest.raises(TypeError, scaler.fit, epochs, y) pytest.raises(TypeError, scaler.transform, epochs) - epochs_bad = Epochs(raw, events, event_id, 0, 0.01, baseline=None, - picks=np.arange(len(raw.ch_names))) # non-data chs + epochs_bad = Epochs( + raw, + events, + event_id, + 0, + 0.01, + baseline=None, + picks=np.arange(len(raw.ch_names)), + ) # non-data chs scaler = Scaler(epochs_bad.info, None) pytest.raises(ValueError, scaler.fit, epochs_bad.get_data(), y) @@ -95,34 +123,46 @@ def test_filterestimator(): """Test methods of FilterEstimator.""" raw = io.read_raw_fif(raw_fname) events = read_events(event_name) - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[1:13:3] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True) + epochs = Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True + ) epochs_data = epochs.get_data() # Add tests for different combinations of l_freq and h_freq filt = FilterEstimator(epochs.info, l_freq=40, h_freq=80) y = epochs.events[:, -1] X = filt.fit_transform(epochs_data, y) - assert (X.shape == epochs_data.shape) + assert X.shape == epochs_data.shape assert_array_equal(filt.fit(epochs_data, y).transform(epochs_data), X) - filt = FilterEstimator(epochs.info, l_freq=None, h_freq=40, - filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto') + filt = FilterEstimator( + epochs.info, + l_freq=None, + h_freq=40, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + ) y = epochs.events[:, -1] X = filt.fit_transform(epochs_data, y) filt = FilterEstimator(epochs.info, l_freq=1, h_freq=1) y = epochs.events[:, -1] - with pytest.warns(RuntimeWarning, match='longer than the signal'): + with pytest.warns(RuntimeWarning, match="longer than the signal"): pytest.raises(ValueError, filt.fit_transform, epochs_data, y) - filt = FilterEstimator(epochs.info, l_freq=40, h_freq=None, - filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto') + filt = FilterEstimator( + epochs.info, + l_freq=40, + h_freq=None, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + ) X = filt.fit_transform(epochs_data, y) # Test init exception @@ -134,17 +174,19 @@ def test_psdestimator(): """Test methods of PSDEstimator.""" raw = 
io.read_raw_fif(raw_fname) events = read_events(event_name) - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[1:13:3] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - baseline=(None, 0), preload=True) + epochs = Epochs( + raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True + ) epochs_data = epochs.get_data() psd = PSDEstimator(2 * np.pi, 0, np.inf) y = epochs.events[:, -1] X = psd.fit_transform(epochs_data, y) - assert (X.shape[0] == epochs_data.shape[0]) + assert X.shape[0] == epochs_data.shape[0] assert_array_equal(psd.fit(epochs_data, y).transform(epochs_data), X) # Test init exception @@ -166,15 +208,13 @@ def test_vectorizer(): assert_array_equal(vect.inverse_transform(result[1:]), data[1:]) # check with different shape - assert_equal(vect.fit_transform(np.random.rand(150, 18, 6, 3)).shape, - (150, 324)) + assert_equal(vect.fit_transform(np.random.rand(150, 18, 6, 3)).shape, (150, 324)) assert_equal(vect.fit_transform(data[1:]).shape, (149, 108)) # check if raised errors are working correctly vect.fit(np.random.rand(105, 12, 3)) pytest.raises(ValueError, vect.transform, np.random.rand(105, 12, 3, 1)) - pytest.raises(ValueError, vect.inverse_transform, - np.random.rand(102, 12, 12)) + pytest.raises(ValueError, vect.inverse_transform, np.random.rand(102, 12, 12)) @requires_sklearn @@ -182,13 +222,24 @@ def test_unsupervised_spatial_filter(): """Test unsupervised spatial filter.""" from sklearn.decomposition import PCA from sklearn.kernel_ridge import KernelRidge + raw = io.read_raw_fif(raw_fname) events = read_events(event_name) - picks = pick_types(raw.info, meg=True, stim=False, ecg=False, - eog=False, exclude='bads') + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) picks = picks[1:13:3] - epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, - preload=True, baseline=None, verbose=False) + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + preload=True, + baseline=None, + verbose=False, + ) # Test estimator pytest.raises(ValueError, UnsupervisedSpatialFilter, KernelRidge(2)) @@ -218,34 +269,39 @@ def test_temporal_filter(): X = np.random.rand(5, 5, 1200) # Test init test - values = (('10hz', None, 100., 'auto'), (5., '10hz', 100., 'auto'), - (10., 20., 5., 'auto'), (None, None, 100., '5hz')) + values = ( + ("10hz", None, 100.0, "auto"), + (5.0, "10hz", 100.0, "auto"), + (10.0, 20.0, 5.0, "auto"), + (None, None, 100.0, "5hz"), + ) for low, high, sf, ltrans in values: - filt = TemporalFilter(low, high, sf, ltrans, fir_design='firwin') + filt = TemporalFilter(low, high, sf, ltrans, fir_design="firwin") pytest.raises(ValueError, filt.fit_transform, X) # Add tests for different combinations of l_freq and h_freq - for low, high in ((5., 15.), (None, 15.), (5., None)): - filt = TemporalFilter(low, high, sfreq=100., fir_design='firwin') + for low, high in ((5.0, 15.0), (None, 15.0), (5.0, None)): + filt = TemporalFilter(low, high, sfreq=100.0, fir_design="firwin") Xt = filt.fit_transform(X) assert_array_equal(filt.fit_transform(X), Xt) - assert (X.shape == Xt.shape) + assert X.shape == Xt.shape # Test fit and transform numpy type check - with pytest.raises(ValueError, match='Data to be filtered must be'): + with pytest.raises(ValueError, match="Data to be filtered must be"): filt.transform([1, 2]) # Test with 
2 dimensional data array X = np.random.rand(101, 500) - filt = TemporalFilter(l_freq=25., h_freq=50., sfreq=1000., - filter_length=150, fir_design='firwin2') - with use_log_level('error'): # warning about transition bandwidth + filt = TemporalFilter( + l_freq=25.0, h_freq=50.0, sfreq=1000.0, filter_length=150, fir_design="firwin2" + ) + with use_log_level("error"): # warning about transition bandwidth assert_equal(filt.fit_transform(X).shape, X.shape) def test_bad_triage(): """Test for gh-10924.""" - filt = TemporalFilter(l_freq=8, h_freq=60, sfreq=160.) + filt = TemporalFilter(l_freq=8, h_freq=60, sfreq=160.0) # Used to fail with "ValueError: Effective band-stop frequency (135.0) is # too high (maximum based on Nyquist is 80.0)" filt.fit_transform(np.zeros((1, 1, 481))) diff --git a/mne/decoding/time_delaying_ridge.py b/mne/decoding/time_delaying_ridge.py index 2d3d13f1300..2299aa5d861 100644 --- a/mne/decoding/time_delaying_ridge.py +++ b/mne/decoding/time_delaying_ridge.py @@ -13,8 +13,9 @@ from ..utils import warn, ProgressBar, logger -def _compute_corrs(X, y, smin, smax, n_jobs=None, fit_intercept=False, - edge_correction=True): +def _compute_corrs( + X, y, smin, smax, n_jobs=None, fit_intercept=False, edge_correction=True +): """Compute auto- and cross-correlations.""" if fit_intercept: # We could do this in the Fourier domain, too, but it should @@ -27,7 +28,7 @@ def _compute_corrs(X, y, smin, smax, n_jobs=None, fit_intercept=False, X = X - X_offset y = y - y_offset else: - X_offset = y_offset = 0. + X_offset = y_offset = 0.0 if X.ndim == 2: assert y.ndim == 2 X = X[:, np.newaxis, :] @@ -41,7 +42,8 @@ def _compute_corrs(X, y, smin, smax, n_jobs=None, fit_intercept=False, n_fft = next_fast_len(2 * X.shape[0] - 1) _, cuda_dict = _setup_cuda_fft_multiply_repeated( - n_jobs, [1.], n_fft, 'correlation calculations') + n_jobs, [1.0], n_fft, "correlation calculations" + ) del n_jobs # only used to set as CUDA # create our Toeplitz indexer @@ -49,26 +51,27 @@ def _compute_corrs(X, y, smin, smax, n_jobs=None, fit_intercept=False, for ii in range(len_trf): ij[ii, ii:] = np.arange(len_trf - ii) x = np.arange(n_fft - 1, n_fft - len_trf + ii, -1) - ij[ii + 1:, ii] = x + ij[ii + 1 :, ii] = x x_xt = np.zeros([n_ch_x * len_trf] * 2) - x_y = np.zeros((len_trf, n_ch_x, n_ch_y), order='F') + x_y = np.zeros((len_trf, n_ch_x, n_ch_y), order="F") n = n_epochs * (n_ch_x * (n_ch_x + 1) // 2 + n_ch_x) - logger.info('Fitting %d epochs, %d channels' % (n_epochs, n_ch_x)) - pb = ProgressBar(n, mesg='Sample') + logger.info("Fitting %d epochs, %d channels" % (n_epochs, n_ch_x)) + pb = ProgressBar(n, mesg="Sample") count = 0 pb.update(count) for ei in range(n_epochs): this_X = X[:, ei, :] # XXX maybe this is what we should parallelize over CPUs at some point - X_fft = cuda_dict['rfft'](this_X, n=n_fft, axis=0) + X_fft = cuda_dict["rfft"](this_X, n=n_fft, axis=0) X_fft_conj = X_fft.conj() - y_fft = cuda_dict['rfft'](y[:, ei, :], n=n_fft, axis=0) + y_fft = cuda_dict["rfft"](y[:, ei, :], n=n_fft, axis=0) for ch0 in range(n_ch_x): for oi, ch1 in enumerate(range(ch0, n_ch_x)): - this_result = cuda_dict['irfft']( - X_fft[:, ch0] * X_fft_conj[:, ch1], n=n_fft, axis=0) + this_result = cuda_dict["irfft"]( + X_fft[:, ch0] * X_fft_conj[:, ch1], n=n_fft, axis=0 + ) # Our autocorrelation structure is a Toeplitz matrix, but # it's faster to create the Toeplitz ourselves than use # linalg.toeplitz. 
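            # (A Toeplitz matrix is constant along each diagonal, so each
            # autocorrelation block is fully determined by one lag vector and
            # can be scattered into matrix form with the index array ``ij``
            # precomputed above.)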
@@ -85,40 +88,43 @@ def _compute_corrs(X, y, smin, smax, n_jobs=None, fit_intercept=False, _edge_correct(this_result, this_X, smax, smin, ch0, ch1) # Store the results in our output matrix - x_xt[ch0 * len_trf:(ch0 + 1) * len_trf, - ch1 * len_trf:(ch1 + 1) * len_trf] += this_result + x_xt[ + ch0 * len_trf : (ch0 + 1) * len_trf, + ch1 * len_trf : (ch1 + 1) * len_trf, + ] += this_result if ch0 != ch1: - x_xt[ch1 * len_trf:(ch1 + 1) * len_trf, - ch0 * len_trf:(ch0 + 1) * len_trf] += this_result.T + x_xt[ + ch1 * len_trf : (ch1 + 1) * len_trf, + ch0 * len_trf : (ch0 + 1) * len_trf, + ] += this_result.T count += 1 pb.update(count) # compute the crosscorrelations - cc_temp = cuda_dict['irfft']( - y_fft * X_fft_conj[:, slice(ch0, ch0 + 1)], n=n_fft, axis=0) + cc_temp = cuda_dict["irfft"]( + y_fft * X_fft_conj[:, slice(ch0, ch0 + 1)], n=n_fft, axis=0 + ) if smin < 0 and smax >= 0: x_y[:-smin, ch0] += cc_temp[smin:] - x_y[len_trf - smax:, ch0] += cc_temp[:smax] + x_y[len_trf - smax :, ch0] += cc_temp[:smax] else: x_y[:, ch0] += cc_temp[smin:smax] count += 1 pb.update(count) - x_y = np.reshape(x_y, (n_ch_x * len_trf, n_ch_y), order='F') + x_y = np.reshape(x_y, (n_ch_x * len_trf, n_ch_y), order="F") return x_xt, x_y, n_ch_x, X_offset, y_offset @jit() def _edge_correct(this_result, this_X, smax, smin, ch0, ch1): if smax > 0: - tail = _toeplitz_dot(this_X[-1:-smax:-1, ch0], - this_X[-1:-smax:-1, ch1]) + tail = _toeplitz_dot(this_X[-1:-smax:-1, ch0], this_X[-1:-smax:-1, ch1]) if smin > 0: - tail = tail[smin - 1:, smin - 1:] - this_result[max(-smin + 1, 0):, max(-smin + 1, 0):] -= tail + tail = tail[smin - 1 :, smin - 1 :] + this_result[max(-smin + 1, 0) :, max(-smin + 1, 0) :] -= tail if smin < 0: - head = _toeplitz_dot(this_X[:-smin, ch0], - this_X[:-smin, ch1])[::-1, ::-1] + head = _toeplitz_dot(this_X[:-smin, ch0], this_X[:-smin, ch1])[::-1, ::-1] if smax < 0: head = head[:smax, :smax] this_result[:-smin, :-smin] -= head @@ -136,28 +142,28 @@ def _toeplitz_dot(a, b): assert a.shape == b.shape and a.ndim == 1 out = np.outer(a, b) for ii in range(1, len(a)): - out[ii, ii:] += out[ii - 1, ii - 1:-1] - out[ii + 1:, ii] += out[ii:-1, ii - 1] + out[ii, ii:] += out[ii - 1, ii - 1 : -1] + out[ii + 1 :, ii] += out[ii:-1, ii - 1] return out -def _compute_reg_neighbors(n_ch_x, n_delays, reg_type, method='direct', - normed=False): +def _compute_reg_neighbors(n_ch_x, n_delays, reg_type, method="direct", normed=False): """Compute regularization parameter from neighbors.""" from scipy import linalg from scipy.sparse.csgraph import laplacian - known_types = ('ridge', 'laplacian') + + known_types = ("ridge", "laplacian") if isinstance(reg_type, str): reg_type = (reg_type,) * 2 if len(reg_type) != 2: - raise ValueError('reg_type must have two elements, got %s' - % (len(reg_type),)) + raise ValueError("reg_type must have two elements, got %s" % (len(reg_type),)) for r in reg_type: if r not in known_types: - raise ValueError('reg_type entries must be one of %s, got %s' - % (known_types, r)) - reg_time = (reg_type[0] == 'laplacian' and n_delays > 1) - reg_chs = (reg_type[1] == 'laplacian' and n_ch_x > 1) + raise ValueError( + "reg_type entries must be one of %s, got %s" % (known_types, r) + ) + reg_time = reg_type[0] == "laplacian" and n_delays > 1 + reg_chs = reg_type[1] == "laplacian" and n_ch_x > 1 if not reg_time and not reg_chs: return np.eye(n_ch_x * n_delays) # regularize time @@ -166,7 +172,7 @@ def _compute_reg_neighbors(n_ch_x, n_delays, reg_type, method='direct', stride = n_delays + 1 reg.flat[1::stride] += -1 
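        # (together these strided writes build a tridiagonal
        # second-difference matrix, i.e. a 1-D graph Laplacian along the
        # delay axis: 2 on the interior diagonal, -1 on its neighbors)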
reg.flat[n_delays::stride] += -1 - reg.flat[n_delays + 1:-n_delays - 1:stride] += 1 + reg.flat[n_delays + 1 : -n_delays - 1 : stride] += 1 args = [reg] * n_ch_x reg = linalg.block_diag(*args) else: @@ -178,12 +184,12 @@ def _compute_reg_neighbors(n_ch_x, n_delays, reg_type, method='direct', row_offset = block * n_ch_x stride = n_delays * n_ch_x + 1 reg.flat[n_delays:-row_offset:stride] += -1 - reg.flat[n_delays + row_offset::stride] += 1 + reg.flat[n_delays + row_offset :: stride] += 1 reg.flat[row_offset:-n_delays:stride] += -1 - reg.flat[:-(n_delays + row_offset):stride] += 1 + reg.flat[: -(n_delays + row_offset) : stride] += 1 assert np.array_equal(reg[::-1, ::-1], reg) - if method == 'direct': + if method == "direct": if normed: norm = np.sqrt(np.diag(reg)) reg /= norm @@ -201,6 +207,7 @@ def _fit_corrs(x_xt, x_y, n_ch_x, reg_type, alpha, n_ch_in): """Fit the model using correlation matrices.""" # do the regularized solving from scipy import linalg + n_ch_out = x_y.shape[1] assert x_y.shape[0] % n_ch_x == 0 n_delays = x_y.shape[0] // n_ch_x @@ -211,11 +218,13 @@ def _fit_corrs(x_xt, x_y, n_ch_x, reg_type, alpha, n_ch_in): # Note: we must use overwrite_a=False in order to be able to # use the fall-back solution below in case a LinAlgError # is raised - w = linalg.solve(mat, x_y, overwrite_a=False, assume_a='pos') + w = linalg.solve(mat, x_y, overwrite_a=False, assume_a="pos") except np.linalg.LinAlgError: - warn('Singular matrix in solving dual problem. Using ' - 'least-squares solution instead.') - w = linalg.lstsq(mat, x_y, lapack_driver='gelsy')[0] + warn( + "Singular matrix in solving dual problem. Using " + "least-squares solution instead." + ) + w = linalg.lstsq(mat, x_y, lapack_driver="gelsy")[0] w = w.T.reshape([n_ch_out, n_ch_in, n_delays]) return w @@ -270,11 +279,19 @@ class TimeDelayingRidge(BaseEstimator): _estimator_type = "regressor" - def __init__(self, tmin, tmax, sfreq, alpha=0., reg_type='ridge', - fit_intercept=True, n_jobs=None, edge_correction=True): + def __init__( + self, + tmin, + tmax, + sfreq, + alpha=0.0, + reg_type="ridge", + fit_intercept=True, + n_jobs=None, + edge_correction=True, + ): if tmin > tmax: - raise ValueError('tmin must be <= tmax, got %s and %s' - % (tmin, tmax)) + raise ValueError("tmin must be <= tmax, got %s and %s" % (tmin, tmax)) self.tmin = float(tmin) self.tmax = float(tmax) self.sfreq = float(sfreq) @@ -317,15 +334,22 @@ def fit(self, X, y): # might want to allow people to do them separately (e.g., to test # different regularization parameters). self.cov_, x_y_, n_ch_x, X_offset, y_offset = _compute_corrs( - X, y, self._smin, self._smax, self.n_jobs, self.fit_intercept, - self.edge_correction) - self.coef_ = _fit_corrs(self.cov_, x_y_, n_ch_x, - self.reg_type, self.alpha, n_ch_x) + X, + y, + self._smin, + self._smax, + self.n_jobs, + self.fit_intercept, + self.edge_correction, + ) + self.coef_ = _fit_corrs( + self.cov_, x_y_, n_ch_x, self.reg_type, self.alpha, n_ch_x + ) # This is the sklearn formula from LinearModel (will be 0. for no fit) if self.fit_intercept: self.intercept_ = y_offset - np.dot(X_offset, self.coef_.sum(-1).T) else: - self.intercept_ = 0. 
+ self.intercept_ = 0.0 return self def predict(self, X): @@ -355,8 +379,8 @@ def predict(self, X): for oi in range(self.coef_.shape[0]): for fi in range(self.coef_.shape[1]): temp = fftconvolve(X[:, ei, fi], self.coef_[oi, fi]) - temp = temp[max(-smin, 0):][:len(out) - offset] - out[offset:len(temp) + offset, ei, oi] += temp + temp = temp[max(-smin, 0) :][: len(out) - offset] + out[offset : len(temp) + offset, ei, oi] += temp out += self.intercept_ if singleton: out = out[:, 0, :] diff --git a/mne/decoding/time_frequency.py b/mne/decoding/time_frequency.py index 330cc1ed5c8..d6ed4f6dd56 100644 --- a/mne/decoding/time_frequency.py +++ b/mne/decoding/time_frequency.py @@ -60,13 +60,22 @@ class TimeFrequency(TransformerMixin, BaseEstimator): """ @verbose - def __init__(self, freqs, sfreq=1.0, method='morlet', n_cycles=7.0, - time_bandwidth=None, use_fft=True, decim=1, output='complex', - n_jobs=1, verbose=None): # noqa: D102 + def __init__( + self, + freqs, + sfreq=1.0, + method="morlet", + n_cycles=7.0, + time_bandwidth=None, + use_fft=True, + decim=1, + output="complex", + n_jobs=1, + verbose=None, + ): # noqa: D102 """Init TimeFrequency transformer.""" # Check non-average output - output = _check_option('output', output, - ['complex', 'power', 'phase']) + output = _check_option("output", output, ["complex", "power", "phase"]) self.freqs = freqs self.sfreq = sfreq @@ -137,10 +146,20 @@ def transform(self, X): X = X[:, np.newaxis, :] # Compute time-frequency - Xt = _compute_tfr(X, self.freqs, self.sfreq, self.method, - self.n_cycles, True, self.time_bandwidth, - self.use_fft, self.decim, self.output, self.n_jobs, - self.verbose) + Xt = _compute_tfr( + X, + self.freqs, + self.sfreq, + self.method, + self.n_cycles, + True, + self.time_bandwidth, + self.use_fft, + self.decim, + self.output, + self.n_jobs, + self.verbose, + ) # Back to original shape if not shape: diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index b6faf66cf97..2d4316e768a 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -13,12 +13,11 @@ from ..filter import filter_data from ..time_frequency import psd_array_multitaper from ..utils import fill_doc, _check_option, _validate_type, verbose -from ..io.pick import (pick_info, _pick_data_channels, _picks_by_type, - _picks_to_idx) +from ..io.pick import pick_info, _pick_data_channels, _picks_by_type, _picks_to_idx from ..cov import _check_scalings_user -class _ConstantScaler(): +class _ConstantScaler: """Scale channel types using constant values.""" def __init__(self, info, scalings, do_scaling=True): @@ -28,15 +27,17 @@ def __init__(self, info, scalings, do_scaling=True): def fit(self, X, y=None): scalings = _check_scalings_user(self._scalings) - picks_by_type = _picks_by_type(pick_info( - self._info, _pick_data_channels(self._info, exclude=()))) + picks_by_type = _picks_by_type( + pick_info(self._info, _pick_data_channels(self._info, exclude=())) + ) std = np.ones(sum(len(p[1]) for p in picks_by_type)) if X.shape[1] != len(std): - raise ValueError('info had %d data channels but X has %d channels' - % (len(std), len(X))) + raise ValueError( + "info had %d data channels but X has %d channels" % (len(std), len(X)) + ) if self._do_scaling: # this is silly, but necessary for completeness for kind, picks in picks_by_type: - std[picks] = 1. 
/ scalings[kind]
+                std[picks] = 1.0 / scalings[kind]
         self.std_ = std
         self.mean_ = np.zeros_like(std)
         return self
@@ -101,31 +102,38 @@ class Scaler(TransformerMixin, BaseEstimator):
         if ``scalings`` is a dict or None).
     """
 
-    def __init__(self, info=None, scalings=None, with_mean=True,
-                 with_std=True):  # noqa: D102
+    def __init__(
+        self, info=None, scalings=None, with_mean=True, with_std=True
+    ):  # noqa: D102
         self.info = info
         self.with_mean = with_mean
         self.with_std = with_std
         self.scalings = scalings
 
         if not (scalings is None or isinstance(scalings, (dict, str))):
-            raise ValueError('scalings type should be dict, str, or None, '
-                             'got %s' % type(scalings))
+            raise ValueError(
+                "scalings type should be dict, str, or None, got %s" % type(scalings)
+            )
         if isinstance(scalings, str):
-            _check_option('scalings', scalings, ['mean', 'median'])
+            _check_option("scalings", scalings, ["mean", "median"])
         if scalings is None or isinstance(scalings, dict):
             if info is None:
-                raise ValueError('Need to specify "info" if scalings is'
-                                 '%s' % type(scalings))
+                raise ValueError(
+                    'Need to specify "info" if scalings is %s' % type(scalings)
+                )
             self._scaler = _ConstantScaler(info, scalings, self.with_std)
-        elif scalings == 'mean':
+        elif scalings == "mean":
             from sklearn.preprocessing import StandardScaler
+
             self._scaler = StandardScaler(
-                with_mean=self.with_mean, with_std=self.with_std)
+                with_mean=self.with_mean, with_std=self.with_std
+            )
         else:  # scalings == 'median':
             from sklearn.preprocessing import RobustScaler
+
             self._scaler = RobustScaler(
-                with_centering=self.with_mean, with_scaling=self.with_std)
+                with_centering=self.with_mean, with_scaling=self.with_std
+            )
 
     def fit(self, epochs_data, y=None):
         """Standardize data across channels.
@@ -142,7 +150,7 @@ def fit(self, epochs_data, y=None):
         self : instance of Scaler
             The modified instance.
         """
-        _validate_type(epochs_data, np.ndarray, 'epochs_data')
+        _validate_type(epochs_data, np.ndarray, "epochs_data")
         if epochs_data.ndim == 2:
             epochs_data = epochs_data[..., np.newaxis]
         assert epochs_data.ndim == 3, epochs_data.shape
@@ -167,14 +175,13 @@ def transform(self, epochs_data):
         This function makes a copy of the data before the operations and
         the memory usage may be large with big data.
         """
-        _validate_type(epochs_data, np.ndarray, 'epochs_data')
+        _validate_type(epochs_data, np.ndarray, "epochs_data")
         if epochs_data.ndim == 2:  # can happen with SlidingEstimator
             if self.info is not None:
-                assert len(self.info['ch_names']) == epochs_data.shape[1]
+                assert len(self.info["ch_names"]) == epochs_data.shape[1]
             epochs_data = epochs_data[..., np.newaxis]
         assert epochs_data.ndim == 3, epochs_data.shape
-        return _sklearn_reshape_apply(self._scaler.transform, True,
-                                      epochs_data)
+        return _sklearn_reshape_apply(self._scaler.transform, True, epochs_data)
 
     def fit_transform(self, epochs_data, y=None):
         """Fit to data, then transform it.
@@ -221,8 +228,7 @@ def inverse_transform(self, epochs_data):
             memory usage may be large with big data.
""" assert epochs_data.ndim == 3, epochs_data.shape - return _sklearn_reshape_apply(self._scaler.inverse_transform, True, - epochs_data) + return _sklearn_reshape_apply(self._scaler.inverse_transform, True, epochs_data) class Vectorizer(TransformerMixin): @@ -282,8 +288,7 @@ def transform(self, X): """ X = np.asarray(X) if X.shape[1:] != self.features_shape_: - raise ValueError("Shape of X used in fit and transform must be " - "same") + raise ValueError("Shape of X used in fit and transform must be " "same") return X.reshape(len(X), -1) def fit_transform(self, X, y=None): @@ -322,8 +327,9 @@ def inverse_transform(self, X): """ X = np.asarray(X) if X.ndim not in (2, 3): - raise ValueError("X should be of 2 or 3 dimensions but has shape " - "%s" % (X.shape,)) + raise ValueError( + "X should be of 2 or 3 dimensions but has shape " "%s" % (X.shape,) + ) return X.reshape(X.shape[:-1] + self.features_shape_) @@ -361,9 +367,19 @@ class PSDEstimator(TransformerMixin): """ @verbose - def __init__(self, sfreq=2 * np.pi, fmin=0, fmax=np.inf, bandwidth=None, - adaptive=False, low_bias=True, n_jobs=None, - normalization='length', *, verbose=None): # noqa: D102 + def __init__( + self, + sfreq=2 * np.pi, + fmin=0, + fmax=np.inf, + bandwidth=None, + adaptive=False, + low_bias=True, + n_jobs=None, + normalization="length", + *, + verbose=None + ): # noqa: D102 self.sfreq = sfreq self.fmin = fmin self.fmax = fmax @@ -389,8 +405,9 @@ def fit(self, epochs_data, y): The modified instance. """ if not isinstance(epochs_data, np.ndarray): - raise ValueError("epochs_data should be of type ndarray (got %s)." - % type(epochs_data)) + raise ValueError( + "epochs_data should be of type ndarray (got %s)." % type(epochs_data) + ) return self @@ -408,13 +425,20 @@ def transform(self, epochs_data): The computed PSD. """ if not isinstance(epochs_data, np.ndarray): - raise ValueError("epochs_data should be of type ndarray (got %s)." - % type(epochs_data)) + raise ValueError( + "epochs_data should be of type ndarray (got %s)." % type(epochs_data) + ) psd, _ = psd_array_multitaper( - epochs_data, sfreq=self.sfreq, fmin=self.fmin, fmax=self.fmax, - bandwidth=self.bandwidth, adaptive=self.adaptive, - low_bias=self.low_bias, normalization=self.normalization, - n_jobs=self.n_jobs) + epochs_data, + sfreq=self.sfreq, + fmin=self.fmin, + fmax=self.fmax, + bandwidth=self.bandwidth, + adaptive=self.adaptive, + low_bias=self.low_bias, + normalization=self.normalization, + n_jobs=self.n_jobs, + ) return psd @@ -469,10 +493,22 @@ class FilterEstimator(TransformerMixin): caution. """ - def __init__(self, info, l_freq, h_freq, picks=None, filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto', - n_jobs=None, method='fir', iir_params=None, - fir_design='firwin', *, verbose=None): # noqa: D102 + def __init__( + self, + info, + l_freq, + h_freq, + picks=None, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + n_jobs=None, + method="fir", + iir_params=None, + fir_design="firwin", + *, + verbose=None + ): # noqa: D102 self.info = info self.l_freq = l_freq self.h_freq = h_freq @@ -501,37 +537,39 @@ def fit(self, epochs_data, y): The modified instance. """ if not isinstance(epochs_data, np.ndarray): - raise ValueError("epochs_data should be of type ndarray (got %s)." - % type(epochs_data)) + raise ValueError( + "epochs_data should be of type ndarray (got %s)." 
% type(epochs_data) + ) if self.picks is None: - self.picks = pick_types(self.info, meg=True, eeg=True, - ref_meg=False, exclude=[]) + self.picks = pick_types( + self.info, meg=True, eeg=True, ref_meg=False, exclude=[] + ) if self.l_freq == 0: self.l_freq = None - if self.h_freq is not None and self.h_freq > (self.info['sfreq'] / 2.): + if self.h_freq is not None and self.h_freq > (self.info["sfreq"] / 2.0): self.h_freq = None if self.l_freq is not None and not isinstance(self.l_freq, float): self.l_freq = float(self.l_freq) if self.h_freq is not None and not isinstance(self.h_freq, float): self.h_freq = float(self.h_freq) - if self.info['lowpass'] is None or (self.h_freq is not None and - (self.l_freq is None or - self.l_freq < self.h_freq) and - self.h_freq < - self.info['lowpass']): + if self.info["lowpass"] is None or ( + self.h_freq is not None + and (self.l_freq is None or self.l_freq < self.h_freq) + and self.h_freq < self.info["lowpass"] + ): with self.info._unlock(): - self.info['lowpass'] = self.h_freq + self.info["lowpass"] = self.h_freq - if self.info['highpass'] is None or (self.l_freq is not None and - (self.h_freq is None or - self.l_freq < self.h_freq) and - self.l_freq > - self.info['highpass']): + if self.info["highpass"] is None or ( + self.l_freq is not None + and (self.h_freq is None or self.l_freq < self.h_freq) + and self.l_freq > self.info["highpass"] + ): with self.info._unlock(): - self.info['highpass'] = self.l_freq + self.info["highpass"] = self.l_freq return self @@ -549,15 +587,26 @@ def transform(self, epochs_data): The data after filtering. """ if not isinstance(epochs_data, np.ndarray): - raise ValueError("epochs_data should be of type ndarray (got %s)." - % type(epochs_data)) + raise ValueError( + "epochs_data should be of type ndarray (got %s)." 
% type(epochs_data)
+            )
         epochs_data = np.atleast_3d(epochs_data)
         return filter_data(
-            epochs_data, self.info['sfreq'], self.l_freq, self.h_freq,
-            self.picks, self.filter_length, self.l_trans_bandwidth,
-            self.h_trans_bandwidth, method=self.method,
-            iir_params=self.iir_params, n_jobs=self.n_jobs, copy=False,
-            fir_design=self.fir_design, verbose=False)
+            epochs_data,
+            self.info["sfreq"],
+            self.l_freq,
+            self.h_freq,
+            self.picks,
+            self.filter_length,
+            self.l_trans_bandwidth,
+            self.h_trans_bandwidth,
+            method=self.method,
+            iir_params=self.iir_params,
+            n_jobs=self.n_jobs,
+            copy=False,
+            fir_design=self.fir_design,
+            verbose=False,
+        )
 
 
 class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator):
@@ -574,14 +623,17 @@ class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator):
 
     def __init__(self, estimator, average=False):  # noqa: D102
         # XXX: Use _check_estimator #3381
-        for attr in ('fit', 'transform', 'fit_transform'):
+        for attr in ("fit", "transform", "fit_transform"):
             if not hasattr(estimator, attr):
-                raise ValueError('estimator must be a scikit-learn '
-                                 'transformer, missing %s method' % attr)
+                raise ValueError(
+                    "estimator must be a scikit-learn "
+                    "transformer, missing %s method" % attr
+                )
 
         if not isinstance(average, bool):
-            raise ValueError("average parameter must be of bool type, got "
-                             "%s instead" % type(bool))
+            raise ValueError(
+                "average parameter must be of bool type, got %s instead" % type(average)
+            )
 
         self.estimator = estimator
         self.average = average
@@ -606,8 +658,7 @@ def fit(self, X, y=None):
         else:
             n_epochs, n_channels, n_times = X.shape
             # trial as time samples
-            X = np.transpose(X, (1, 0, 2)).reshape((n_channels, n_epochs *
-                                                    n_times)).T
+            X = np.transpose(X, (1, 0, 2)).reshape((n_channels, n_epochs * n_times)).T
         self.estimator.fit(X)
         return self
@@ -641,7 +692,7 @@ def transform(self, X):
         X : array, shape (n_epochs, n_channels, n_times)
             The transformed data.
         """
-        return self._apply_method(X, 'transform')
+        return self._apply_method(X, "transform")
 
     def inverse_transform(self, X):
         """Inverse transform the data to its original space.
@@ -656,7 +707,7 @@ def inverse_transform(self, X):
         X : array, shape (n_epochs, n_channels, n_times)
             The transformed data.
         """
-        return self._apply_method(X, 'inverse_transform')
+        return self._apply_method(X, "inverse_transform")
 
     def _apply_method(self, X, method):
         """Vectorize time samples as trials, apply method and reshape back.
@@ -768,11 +819,22 @@ class TemporalFilter(TransformerMixin):
     """
 
     @verbose
-    def __init__(self, l_freq=None, h_freq=None, sfreq=1.0,
-                 filter_length='auto', l_trans_bandwidth='auto',
-                 h_trans_bandwidth='auto', n_jobs=None, method='fir',
-                 iir_params=None, fir_window='hamming', fir_design='firwin',
-                 *, verbose=None):  # noqa: D102
+    def __init__(
+        self,
+        l_freq=None,
+        h_freq=None,
+        sfreq=1.0,
+        filter_length="auto",
+        l_trans_bandwidth="auto",
+        h_trans_bandwidth="auto",
+        n_jobs=None,
+        method="fir",
+        iir_params=None,
+        fir_window="hamming",
+        fir_design="firwin",
+        *,
+        verbose=None
+    ):  # noqa: D102
         self.l_freq = l_freq
         self.h_freq = h_freq
         self.sfreq = sfreq
@@ -785,9 +847,10 @@ def __init__(self, l_freq=None, h_freq=None, sfreq=1.0,
         self.fir_window = fir_window
         self.fir_design = fir_design
 
-        if not isinstance(self.n_jobs, int) and self.n_jobs == 'cuda':
-            raise ValueError('n_jobs must be int or "cuda", got %s instead.'
-                             % type(self.n_jobs))
+        if not isinstance(self.n_jobs, int) and self.n_jobs != "cuda":
+            raise ValueError(
+                'n_jobs must be int or "cuda", got %s instead.' 
% type(self.n_jobs) + ) def fit(self, X, y=None): """Do nothing (for scikit-learn compatibility purposes). @@ -824,16 +887,26 @@ def transform(self, X): X = np.atleast_2d(X) if X.ndim > 3: - raise ValueError("Array must be of at max 3 dimensions instead " - "got %s dimensional matrix" % (X.ndim)) + raise ValueError( + "Array must be of at max 3 dimensions instead " + "got %s dimensional matrix" % (X.ndim) + ) shape = X.shape X = X.reshape(-1, shape[-1]) - X = filter_data(X, self.sfreq, self.l_freq, self.h_freq, - filter_length=self.filter_length, - l_trans_bandwidth=self.l_trans_bandwidth, - h_trans_bandwidth=self.h_trans_bandwidth, - n_jobs=self.n_jobs, method=self.method, - iir_params=self.iir_params, copy=False, - fir_window=self.fir_window, fir_design=self.fir_design) + X = filter_data( + X, + self.sfreq, + self.l_freq, + self.h_freq, + filter_length=self.filter_length, + l_trans_bandwidth=self.l_trans_bandwidth, + h_trans_bandwidth=self.h_trans_bandwidth, + n_jobs=self.n_jobs, + method=self.method, + iir_params=self.iir_params, + copy=False, + fir_window=self.fir_window, + fir_design=self.fir_design, + ) return X.reshape(shape) diff --git a/mne/defaults.py b/mne/defaults.py index 16b3b843406..498312caa15 100644 --- a/mne/defaults.py +++ b/mne/defaults.py @@ -7,81 +7,232 @@ from copy import deepcopy DEFAULTS = dict( - color=dict(mag='darkblue', grad='b', eeg='k', eog='k', ecg='m', emg='k', - ref_meg='steelblue', misc='k', stim='k', resp='k', chpi='k', - exci='k', ias='k', syst='k', seeg='saddlebrown', dbs='seagreen', - dipole='k', gof='k', bio='k', ecog='k', hbo='#AA3377', hbr='b', - fnirs_cw_amplitude='k', fnirs_fd_ac_amplitude='k', - fnirs_fd_phase='k', fnirs_od='k', csd='k', whitened='k', - gsr='#666633', temperature='#663333', - eyegaze='k', pupil='k'), - si_units=dict(mag='T', grad='T/m', eeg='V', eog='V', ecg='V', emg='V', - misc='AU', seeg='V', dbs='V', dipole='Am', gof='GOF', - bio='V', ecog='V', hbo='M', hbr='M', ref_meg='T', - fnirs_cw_amplitude='V', fnirs_fd_ac_amplitude='V', - fnirs_fd_phase='rad', fnirs_od='V', csd='V/m²', - whitened='Z', gsr='S', temperature='C', - eyegaze='AU', pupil='AU'), - units=dict(mag='fT', grad='fT/cm', eeg='µV', eog='µV', ecg='µV', emg='µV', - misc='AU', seeg='mV', dbs='µV', dipole='nAm', gof='GOF', - bio='µV', ecog='µV', hbo='µM', hbr='µM', ref_meg='fT', - fnirs_cw_amplitude='V', fnirs_fd_ac_amplitude='V', - fnirs_fd_phase='rad', fnirs_od='V', csd='mV/m²', - whitened='Z', gsr='S', temperature='C', - eyegaze='AU', pupil='AU'), + color=dict( + mag="darkblue", + grad="b", + eeg="k", + eog="k", + ecg="m", + emg="k", + ref_meg="steelblue", + misc="k", + stim="k", + resp="k", + chpi="k", + exci="k", + ias="k", + syst="k", + seeg="saddlebrown", + dbs="seagreen", + dipole="k", + gof="k", + bio="k", + ecog="k", + hbo="#AA3377", + hbr="b", + fnirs_cw_amplitude="k", + fnirs_fd_ac_amplitude="k", + fnirs_fd_phase="k", + fnirs_od="k", + csd="k", + whitened="k", + gsr="#666633", + temperature="#663333", + eyegaze="k", + pupil="k", + ), + si_units=dict( + mag="T", + grad="T/m", + eeg="V", + eog="V", + ecg="V", + emg="V", + misc="AU", + seeg="V", + dbs="V", + dipole="Am", + gof="GOF", + bio="V", + ecog="V", + hbo="M", + hbr="M", + ref_meg="T", + fnirs_cw_amplitude="V", + fnirs_fd_ac_amplitude="V", + fnirs_fd_phase="rad", + fnirs_od="V", + csd="V/m²", + whitened="Z", + gsr="S", + temperature="C", + eyegaze="AU", + pupil="AU", + ), + units=dict( + mag="fT", + grad="fT/cm", + eeg="µV", + eog="µV", + ecg="µV", + emg="µV", + misc="AU", + seeg="mV", + dbs="µV", + 
dipole="nAm", + gof="GOF", + bio="µV", + ecog="µV", + hbo="µM", + hbr="µM", + ref_meg="fT", + fnirs_cw_amplitude="V", + fnirs_fd_ac_amplitude="V", + fnirs_fd_phase="rad", + fnirs_od="V", + csd="mV/m²", + whitened="Z", + gsr="S", + temperature="C", + eyegaze="AU", + pupil="AU", + ), # scalings for the units - scalings=dict(mag=1e15, grad=1e13, eeg=1e6, eog=1e6, emg=1e6, ecg=1e6, - misc=1.0, seeg=1e3, dbs=1e6, ecog=1e6, dipole=1e9, gof=1.0, - bio=1e6, hbo=1e6, hbr=1e6, ref_meg=1e15, - fnirs_cw_amplitude=1.0, fnirs_fd_ac_amplitude=1.0, - fnirs_fd_phase=1., fnirs_od=1.0, csd=1e3, whitened=1., - gsr=1., temperature=1., eyegaze=1., pupil=1.), + scalings=dict( + mag=1e15, + grad=1e13, + eeg=1e6, + eog=1e6, + emg=1e6, + ecg=1e6, + misc=1.0, + seeg=1e3, + dbs=1e6, + ecog=1e6, + dipole=1e9, + gof=1.0, + bio=1e6, + hbo=1e6, + hbr=1e6, + ref_meg=1e15, + fnirs_cw_amplitude=1.0, + fnirs_fd_ac_amplitude=1.0, + fnirs_fd_phase=1.0, + fnirs_od=1.0, + csd=1e3, + whitened=1.0, + gsr=1.0, + temperature=1.0, + eyegaze=1.0, + pupil=1.0, + ), # rough guess for a good plot - scalings_plot_raw=dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, - ecg=5e-4, emg=1e-3, ref_meg=1e-12, misc='auto', - stim=1, resp=1, chpi=1e-4, exci=1, ias=1, syst=1, - seeg=1e-4, dbs=1e-4, bio=1e-6, ecog=1e-4, hbo=10e-6, - hbr=10e-6, whitened=10., fnirs_cw_amplitude=2e-2, - fnirs_fd_ac_amplitude=2e-2, fnirs_fd_phase=2e-1, - fnirs_od=2e-2, csd=200e-4, - dipole=1e-7, gof=1e2, - gsr=1., temperature=0.1, - eyegaze=3e-1, pupil=1e3), - scalings_cov_rank=dict(mag=1e12, grad=1e11, eeg=1e5, # ~100x scalings - seeg=1e1, dbs=1e4, ecog=1e4, hbo=1e4, hbr=1e4), - ylim=dict(mag=(-600., 600.), grad=(-200., 200.), eeg=(-200., 200.), - misc=(-5., 5.), seeg=(-20., 20.), dbs=(-200., 200.), - dipole=(-100., 100.), gof=(0., 1.), bio=(-500., 500.), - ecog=(-200., 200.), hbo=(0, 20), hbr=(0, 20), csd=(-50., 50.), - eyegaze=(0., 5000.), pupil=(0., 5000.)), - titles=dict(mag='Magnetometers', grad='Gradiometers', eeg='EEG', eog='EOG', - ecg='ECG', emg='EMG', misc='misc', seeg='sEEG', dbs='DBS', - bio='BIO', dipole='Dipole', ecog='ECoG', hbo='Oxyhemoglobin', - ref_meg='Reference Magnetometers', - fnirs_cw_amplitude='fNIRS (CW amplitude)', - fnirs_fd_ac_amplitude='fNIRS (FD AC amplitude)', - fnirs_fd_phase='fNIRS (FD phase)', - fnirs_od='fNIRS (OD)', hbr='Deoxyhemoglobin', - gof='Goodness of fit', csd='Current source density', - stim='Stimulus', gsr='Galvanic skin response', - temperature='Temperature', - eyegaze='Eye-tracking (Gaze position)', - pupil='Eye-tracking (Pupil size)', - ), - mask_params=dict(marker='o', - markerfacecolor='w', - markeredgecolor='k', - linewidth=0, - markeredgewidth=1, - markersize=4), + scalings_plot_raw=dict( + mag=1e-12, + grad=4e-11, + eeg=20e-6, + eog=150e-6, + ecg=5e-4, + emg=1e-3, + ref_meg=1e-12, + misc="auto", + stim=1, + resp=1, + chpi=1e-4, + exci=1, + ias=1, + syst=1, + seeg=1e-4, + dbs=1e-4, + bio=1e-6, + ecog=1e-4, + hbo=10e-6, + hbr=10e-6, + whitened=10.0, + fnirs_cw_amplitude=2e-2, + fnirs_fd_ac_amplitude=2e-2, + fnirs_fd_phase=2e-1, + fnirs_od=2e-2, + csd=200e-4, + dipole=1e-7, + gof=1e2, + gsr=1.0, + temperature=0.1, + eyegaze=3e-1, + pupil=1e3, + ), + scalings_cov_rank=dict( + mag=1e12, + grad=1e11, + eeg=1e5, # ~100x scalings + seeg=1e1, + dbs=1e4, + ecog=1e4, + hbo=1e4, + hbr=1e4, + ), + ylim=dict( + mag=(-600.0, 600.0), + grad=(-200.0, 200.0), + eeg=(-200.0, 200.0), + misc=(-5.0, 5.0), + seeg=(-20.0, 20.0), + dbs=(-200.0, 200.0), + dipole=(-100.0, 100.0), + gof=(0.0, 1.0), + bio=(-500.0, 500.0), + ecog=(-200.0, 
200.0), + hbo=(0, 20), + hbr=(0, 20), + csd=(-50.0, 50.0), + eyegaze=(0.0, 5000.0), + pupil=(0.0, 5000.0), + ), + titles=dict( + mag="Magnetometers", + grad="Gradiometers", + eeg="EEG", + eog="EOG", + ecg="ECG", + emg="EMG", + misc="misc", + seeg="sEEG", + dbs="DBS", + bio="BIO", + dipole="Dipole", + ecog="ECoG", + hbo="Oxyhemoglobin", + ref_meg="Reference Magnetometers", + fnirs_cw_amplitude="fNIRS (CW amplitude)", + fnirs_fd_ac_amplitude="fNIRS (FD AC amplitude)", + fnirs_fd_phase="fNIRS (FD phase)", + fnirs_od="fNIRS (OD)", + hbr="Deoxyhemoglobin", + gof="Goodness of fit", + csd="Current source density", + stim="Stimulus", + gsr="Galvanic skin response", + temperature="Temperature", + eyegaze="Eye-tracking (Gaze position)", + pupil="Eye-tracking (Pupil size)", + ), + mask_params=dict( + marker="o", + markerfacecolor="w", + markeredgecolor="k", + linewidth=0, + markeredgewidth=1, + markersize=4, + ), coreg=dict( mri_fid_opacity=1.0, dig_fid_opacity=1.0, - mri_fid_scale=5e-3, dig_fid_scale=8e-3, extra_scale=4e-3, - eeg_scale=4e-3, eegp_scale=20e-3, eegp_height=0.1, + eeg_scale=4e-3, + eegp_scale=20e-3, + eegp_height=0.1, ecog_scale=5e-3, seeg_scale=5e-3, dbs_scale=5e-3, @@ -89,49 +240,74 @@ source_scale=5e-3, detector_scale=5e-3, hpi_scale=4e-3, - head_color=(0.988, 0.89, 0.74), - hpi_color=(1., 0., 1.), - extra_color=(1., 1., 1.), - meg_color=(0., 0.25, 0.5), ref_meg_color=(0.5, 0.5, 0.5), + hpi_color=(1.0, 0.0, 1.0), + extra_color=(1.0, 1.0, 1.0), + meg_color=(0.0, 0.25, 0.5), + ref_meg_color=(0.5, 0.5, 0.5), helmet_color=(0.0, 0.0, 0.6), - eeg_color=(1., 0.596, 0.588), eegp_color=(0.839, 0.15, 0.16), - ecog_color=(1., 1., 1.), + eeg_color=(1.0, 0.596, 0.588), + eegp_color=(0.839, 0.15, 0.16), + ecog_color=(1.0, 1.0, 1.0), dbs_color=(0.82, 0.455, 0.659), - seeg_color=(1., 1., .3), - fnirs_color=(1., .647, 0.), - source_color=(1., .05, 0.), - detector_color=(.3, .15, .15), - lpa_color=(1., 0., 0.), - nasion_color=(0., 1., 0.), - rpa_color=(0., 0., 1.), + seeg_color=(1.0, 1.0, 0.3), + fnirs_color=(1.0, 0.647, 0.0), + source_color=(1.0, 0.05, 0.0), + detector_color=(0.3, 0.15, 0.15), + lpa_color=(1.0, 0.0, 0.0), + nasion_color=(0.0, 1.0, 0.0), + rpa_color=(0.0, 0.0, 1.0), ), noise_std=dict(grad=5e-13, mag=20e-15, eeg=0.2e-6), eloreta_options=dict(eps=1e-6, max_iter=20, force_equal=False), - depth_mne=dict(exp=0.8, limit=10., limit_depth_chs=True, - combine_xyz='spectral', allow_fixed_depth=False), - depth_sparse=dict(exp=0.8, limit=None, limit_depth_chs='whiten', - combine_xyz='fro', allow_fixed_depth=True), - interpolation_method=dict(eeg='spline', meg='MNE', fnirs='nearest'), + depth_mne=dict( + exp=0.8, + limit=10.0, + limit_depth_chs=True, + combine_xyz="spectral", + allow_fixed_depth=False, + ), + depth_sparse=dict( + exp=0.8, + limit=None, + limit_depth_chs="whiten", + combine_xyz="fro", + allow_fixed_depth=True, + ), + interpolation_method=dict(eeg="spline", meg="MNE", fnirs="nearest"), volume_options=dict( - alpha=None, resolution=1., surface_alpha=None, blending='mip', - silhouette_alpha=None, silhouette_linewidth=2.), - prefixes={'k': 1e-3, 'h': 1e-2, '': 1e0, 'd': 1e1, 'c': 1e2, 'm': 1e3, - 'µ': 1e6, 'u': 1e6, 'n': 1e9, 'p': 1e12, 'f': 1e15}, - transform_zooms=dict( - translation=None, rigid=None, affine=None, sdr=None), + alpha=None, + resolution=1.0, + surface_alpha=None, + blending="mip", + silhouette_alpha=None, + silhouette_linewidth=2.0, + ), + prefixes={ + "k": 1e-3, + "h": 1e-2, + "": 1e0, + "d": 1e1, + "c": 1e2, + "m": 1e3, + "µ": 1e6, + "u": 1e6, + "n": 1e9, + "p": 
1e12, + "f": 1e15, + }, + transform_zooms=dict(translation=None, rigid=None, affine=None, sdr=None), transform_niter=dict( translation=(10000, 1000, 100), rigid=(10000, 1000, 100), affine=(10000, 1000, 100), - sdr=(10, 10, 5)), + sdr=(10, 10, 5), + ), volume_label_indices=( # Left and middle 4, # Left-Lateral-Ventricle 5, # Left-Inf-Lat-Vent - 8, # Left-Cerebellum-Cortex - 10, # Left-Thalamus-Proper 11, # Left-Caudate 12, # Left-Putamen @@ -141,44 +317,32 @@ 16, # Brain-Stem 17, # Left-Hippocampus 18, # Left-Amygdala - 26, # Left-Accumbens-area - 28, # Left-VentralDC - # Right 43, # Right-Lateral-Ventricle 44, # Right-Inf-Lat-Vent - 47, # Right-Cerebellum-Cortex - 49, # Right-Thalamus-Proper 50, # Right-Caudate 51, # Right-Putamen 52, # Right-Pallidum 53, # Right-Hippocampus 54, # Right-Amygdala - 58, # Right-Accumbens-area - 60, # Right-VentralDC ), report_stc_plot_kwargs=dict( - views=('lateral', 'medial'), - hemi='split', - backend='pyvistaqt', + views=("lateral", "medial"), + hemi="split", + backend="pyvistaqt", time_viewer=False, show_traces=False, size=(450, 450), - background='white', + background="white", time_label=None, - add_data_kwargs={ - 'colorbar_kwargs': { - 'label_font_size': 12, - 'n_labels': 5 - } - } - ) + add_data_kwargs={"colorbar_kwargs": {"label_font_size": 12, "n_labels": 5}}, + ), ) @@ -201,6 +365,6 @@ def _handle_default(k, v=None): HEAD_SIZE_DEFAULT = 0.095 # in [m] -_BORDER_DEFAULT = 'mean' -_INTERPOLATION_DEFAULT = 'cubic' -_EXTRAPOLATE_DEFAULT = 'auto' +_BORDER_DEFAULT = "mean" +_INTERPOLATION_DEFAULT = "cubic" +_EXTRAPOLATE_DEFAULT = "auto" diff --git a/mne/dipole.py b/mne/dipole.py index 65fe90a39a3..6083b9bfbdd 100644 --- a/mne/dipole.py +++ b/mne/dipole.py @@ -23,21 +23,36 @@ from .transforms import _print_coord_trans, _coord_frame_name, apply_trans from .viz.evoked import _plot_evoked from ._freesurfer import head_to_mni, head_to_mri -from .forward._make_forward import (_get_trans, _setup_bem, - _prep_meg_channels, _prep_eeg_channels) -from .forward._compute_forward import (_compute_forwards_meeg, - _prep_field_computation) - -from .surface import (transform_surface_to, _compute_nearest, - _points_outside_surface) +from .forward._make_forward import ( + _get_trans, + _setup_bem, + _prep_meg_channels, + _prep_eeg_channels, +) +from .forward._compute_forward import _compute_forwards_meeg, _prep_field_computation + +from .surface import transform_surface_to, _compute_nearest, _points_outside_surface from .bem import _bem_find_surface, _bem_surf_name from .source_space import _make_volume_source_space, SourceSpaces from .parallel import parallel_func -from .utils import (logger, verbose, _time_mask, warn, _check_fname, - check_fname, _pl, fill_doc, _check_option, - _svd_lwork, _repeated_svd, _get_blas_funcs, _validate_type, - copy_function_doc_to_method_doc, TimeMixin, - _verbose_safe_false) +from .utils import ( + logger, + verbose, + _time_mask, + warn, + _check_fname, + check_fname, + _pl, + fill_doc, + _check_option, + _svd_lwork, + _repeated_svd, + _get_blas_funcs, + _validate_type, + copy_function_doc_to_method_doc, + TimeMixin, + _verbose_safe_false, +) from .viz import plot_dipole_locations @@ -101,9 +116,20 @@ class Dipole(TimeMixin): """ @verbose - def __init__(self, times, pos, amplitude, ori, gof, - name=None, conf=None, khi2=None, nfree=None, - *, verbose=None): # noqa: D102 + def __init__( + self, + times, + pos, + amplitude, + ori, + gof, + name=None, + conf=None, + khi2=None, + nfree=None, + *, + verbose=None, + ): # noqa: D102 
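+        # Expected shapes, inferred from usage elsewhere in this module (not
+        # validated here): times (n_dipoles,) in s; pos (n_dipoles, 3) in m,
+        # head frame; amplitude (n_dipoles,) in Am; ori (n_dipoles, 3) unit
+        # vectors; gof (n_dipoles,) in %.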
self._set_times(np.array(times)) self.pos = np.array(pos) self.amplitude = np.array(amplitude) @@ -168,11 +194,12 @@ def crop(self, tmin=None, tmax=None, include_tmax=True, verbose=None): """ sfreq = None if len(self.times) > 1: - sfreq = 1. / np.median(np.diff(self.times)) - mask = _time_mask(self.times, tmin, tmax, sfreq=sfreq, - include_tmax=include_tmax) + sfreq = 1.0 / np.median(np.diff(self.times)) + mask = _time_mask( + self.times, tmin, tmax, sfreq=sfreq, include_tmax=include_tmax + ) self._set_times(self.times[mask]) - for attr in ('pos', 'gof', 'amplitude', 'ori', 'khi2', 'nfree'): + for attr in ("pos", "gof", "amplitude", "ori", "khi2", "nfree"): if getattr(self, attr) is not None: setattr(self, attr, getattr(self, attr)[mask]) for key in self.conf.keys(): @@ -191,21 +218,53 @@ def copy(self): @verbose @copy_function_doc_to_method_doc(plot_dipole_locations) - def plot_locations(self, trans, subject, subjects_dir=None, - mode='orthoview', coord_frame='mri', idx='gof', - show_all=True, ax=None, block=False, show=True, - scale=None, color=None, *, highlight_color='r', - fig=None, title=None, head_source='seghead', - surf='pial', width=None, verbose=None): + def plot_locations( + self, + trans, + subject, + subjects_dir=None, + mode="orthoview", + coord_frame="mri", + idx="gof", + show_all=True, + ax=None, + block=False, + show=True, + scale=None, + color=None, + *, + highlight_color="r", + fig=None, + title=None, + head_source="seghead", + surf="pial", + width=None, + verbose=None, + ): return plot_dipole_locations( - self, trans, subject, subjects_dir, mode, coord_frame, idx, - show_all, ax, block, show, scale=scale, color=color, - highlight_color=highlight_color, fig=fig, title=title, - head_source=head_source, surf=surf, width=width) + self, + trans, + subject, + subjects_dir, + mode, + coord_frame, + idx, + show_all, + ax, + block, + show, + scale=scale, + color=color, + highlight_color=highlight_color, + fig=fig, + title=title, + head_source=head_source, + surf=surf, + width=width, + ) @verbose - def to_mni(self, subject, trans, subjects_dir=None, - verbose=None): + def to_mni(self, subject, trans, subjects_dir=None, verbose=None): """Convert dipole location from head to MNI coordinates. Parameters @@ -221,12 +280,12 @@ def to_mni(self, subject, trans, subjects_dir=None, The MNI coordinates (in mm) of pos. """ mri_head_t, trans = _get_trans(trans) - return head_to_mni(self.pos, subject, mri_head_t, - subjects_dir=subjects_dir, verbose=verbose) + return head_to_mni( + self.pos, subject, mri_head_t, subjects_dir=subjects_dir, verbose=verbose + ) @verbose - def to_mri(self, subject, trans, subjects_dir=None, - verbose=None): + def to_mri(self, subject, trans, subjects_dir=None, verbose=None): """Convert dipole location from head to MRI surface RAS coordinates. Parameters @@ -242,13 +301,24 @@ def to_mri(self, subject, trans, subjects_dir=None, The Freesurfer surface RAS coordinates (in mm) of pos. """ mri_head_t, trans = _get_trans(trans) - return head_to_mri(self.pos, subject, mri_head_t, - subjects_dir=subjects_dir, verbose=verbose, - kind='mri') + return head_to_mri( + self.pos, + subject, + mri_head_t, + subjects_dir=subjects_dir, + verbose=verbose, + kind="mri", + ) @verbose - def to_volume_labels(self, trans, subject='fsaverage', aseg='aparc+aseg', - subjects_dir=None, verbose=None): + def to_volume_labels( + self, + trans, + subject="fsaverage", + aseg="aparc+aseg", + subjects_dir=None, + verbose=None, + ): """Find an ROI in atlas for the dipole positions. 
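
        Positions are mapped from the head frame to FreeSurfer surface RAS
        via ``to_mri``, then into voxel space of the ``aseg`` atlas volume;
        the rounded voxel values are looked up in the FreeSurfer LUT, and
        values missing from the table are returned as ``'Unknown'``.
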
Parameters @@ -279,16 +349,15 @@ def to_volume_labels(self, trans, subject='fsaverage', aseg='aparc+aseg', lut = {v: k for k, v in lut_inv.items()} # transform to voxel space from head space - pos = self.to_mri(subject, trans, subjects_dir=subjects_dir, - verbose=verbose) + pos = self.to_mri(subject, trans, subjects_dir=subjects_dir, verbose=verbose) pos = apply_trans(mri_vox_t, pos) pos = np.rint(pos).astype(int) # Get voxel value and label from LUT - labels = [lut.get(aseg_data[tuple(coord)], 'Unknown') for coord in pos] + labels = [lut.get(aseg_data[tuple(coord)], "Unknown") for coord in pos] return labels - def plot_amplitudes(self, color='k', show=True): + def plot_amplitudes(self, color="k", show=True): """Plot the dipole amplitudes as a function of time. Parameters @@ -304,6 +373,7 @@ def plot_amplitudes(self, color='k', show=True): The figure object containing the plot. """ from .viz import plot_dipole_amplitudes + return plot_dipole_amplitudes([self], [color], show) def __getitem__(self, item): @@ -334,9 +404,16 @@ def __getitem__(self, item): selected_khi2 = self.khi2[item] if self.khi2 is not None else None selected_nfree = self.nfree[item] if self.nfree is not None else None return Dipole( - selected_times, selected_pos, selected_amplitude, selected_ori, - selected_gof, selected_name, selected_conf, selected_khi2, - selected_nfree) + selected_times, + selected_pos, + selected_amplitude, + selected_ori, + selected_gof, + selected_name, + selected_conf, + selected_khi2, + selected_nfree, + ) def __len__(self): """Return the number of dipoles. @@ -358,7 +435,7 @@ def __len__(self): def _read_dipole_fixed(fname): """Read a fixed dipole FIF file.""" - logger.info('Reading %s ...' % fname) + logger.info("Reading %s ..." % fname) info, nave, aspect_kind, comment, times, data, _ = _read_evoked(fname) return DipoleFixed(info, data, times, nave, aspect_kind, comment=comment) @@ -403,12 +480,13 @@ class DipoleFixed(TimeMixin): """ @verbose - def __init__(self, info, data, times, nave, aspect_kind, - comment='', *, verbose=None): # noqa: D102 + def __init__( + self, info, data, times, nave, aspect_kind, comment="", *, verbose=None + ): # noqa: D102 self.info = info self.nave = nave self._aspect_kind = aspect_kind - self.kind = _aspect_rev.get(aspect_kind, 'unknown') + self.kind = _aspect_rev.get(aspect_kind, "unknown") self.comment = comment self._set_times(np.array(times)) self.data = data @@ -438,7 +516,7 @@ def copy(self): @property def ch_names(self): """Channel names.""" - return self.info['ch_names'] + return self.info["ch_names"] @verbose def save(self, fname, verbose=None): @@ -452,12 +530,20 @@ def save(self, fname, verbose=None): dipole information in FIF format. %(verbose)s """ - check_fname(fname, 'DipoleFixed', ('-dip.fif', '-dip.fif.gz', - '_dip.fif', '_dip.fif.gz',), - ('.fif', '.fif.gz')) + check_fname( + fname, + "DipoleFixed", + ( + "-dip.fif", + "-dip.fif.gz", + "_dip.fif", + "_dip.fif.gz", + ), + (".fif", ".fif.gz"), + ) _write_evokeds(fname, self, check=False) - def plot(self, show=True, time_unit='s'): + def plot(self, show=True, time_unit="s"): """Plot dipole data. Parameters @@ -474,12 +560,27 @@ def plot(self, show=True, time_unit='s'): fig : instance of matplotlib.figure.Figure The figure containing the time courses. 
""" - return _plot_evoked(self, picks=None, exclude=(), unit=True, show=show, - ylim=None, xlim='tight', proj=False, hline=None, - units=None, scalings=None, titles=None, axes=None, - gfp=False, window_title=None, spatial_colors=False, - plot_type="butterfly", selectable=False, - time_unit=time_unit) + return _plot_evoked( + self, + picks=None, + exclude=(), + unit=True, + show=show, + ylim=None, + xlim="tight", + proj=False, + hline=None, + units=None, + scalings=None, + titles=None, + axes=None, + gfp=False, + window_title=None, + spatial_colors=False, + plot_type="butterfly", + selectable=False, + time_unit=time_unit, + ) # ############################################################################# @@ -509,7 +610,7 @@ def read_dipole(fname, verbose=None): .. versionchanged:: 0.20 Support for reading bdip (Xfit binary) format. """ - fname = _check_fname(fname, overwrite='read', must_exist=True) + fname = _check_fname(fname, overwrite="read", must_exist=True) if fname.suffix == ".fif" or fname.name.endswith(".fif.gz"): return _read_dipole_fixed(fname) elif fname.suffix == ".bdip": @@ -526,69 +627,96 @@ def _read_dipole_text(fname): # There is a bug in older np.loadtxt regarding skipping fields, # so just read the data ourselves (need to get name and header anyway) data = list() - with open(fname, 'r') as fid: + with open(fname, "r") as fid: for line in fid: - if not (line.startswith('%') or line.startswith('#')): + if not (line.startswith("%") or line.startswith("#")): need_header = False data.append(line.strip().split()) else: if need_header: def_line = line - if line.startswith('##') or line.startswith('%%'): + if line.startswith("##") or line.startswith("%%"): m = re.search('Name "(.*) dipoles"', line) if m: name = m.group(1) del line data = np.atleast_2d(np.array(data, float)) if def_line is None: - raise OSError('Dipole text file is missing field definition ' - 'comment, cannot parse %s' % (fname,)) + raise OSError( + "Dipole text file is missing field definition " + "comment, cannot parse %s" % (fname,) + ) # actually parse the fields - def_line = def_line.lstrip('%').lstrip('#').strip() + def_line = def_line.lstrip("%").lstrip("#").strip() # MNE writes it out differently than Elekta, let's standardize them... - fields = re.sub(r'([X|Y|Z] )\(mm\)', # "X (mm)", etc. - lambda match: match.group(1).strip() + '/mm', def_line) - fields = re.sub(r'\((.*?)\)', # "Q(nAm)", etc. - lambda match: '/' + match.group(1), fields) - fields = re.sub('(begin|end) ', # "begin" and "end" with no units - lambda match: match.group(1) + '/ms', fields) + fields = re.sub( + r"([X|Y|Z] )\(mm\)", # "X (mm)", etc. + lambda match: match.group(1).strip() + "/mm", + def_line, + ) + fields = re.sub( + r"\((.*?)\)", lambda match: "/" + match.group(1), fields # "Q(nAm)", etc. + ) + fields = re.sub( + "(begin|end) ", # "begin" and "end" with no units + lambda match: match.group(1) + "/ms", + fields, + ) fields = fields.lower().split() - required_fields = ('begin/ms', - 'x/mm', 'y/mm', 'z/mm', - 'q/nam', 'qx/nam', 'qy/nam', 'qz/nam', - 'g/%') - optional_fields = ('khi^2', 'free', # standard ones - # now the confidence fields (up to 5!) - 'vol/mm^3', 'depth/mm', 'long/mm', 'trans/mm', - 'qlong/nam', 'qtrans/nam') + required_fields = ( + "begin/ms", + "x/mm", + "y/mm", + "z/mm", + "q/nam", + "qx/nam", + "qy/nam", + "qz/nam", + "g/%", + ) + optional_fields = ( + "khi^2", + "free", # standard ones + # now the confidence fields (up to 5!) 
+ "vol/mm^3", + "depth/mm", + "long/mm", + "trans/mm", + "qlong/nam", + "qtrans/nam", + ) conf_scales = [1e-9, 1e-3, 1e-3, 1e-3, 1e-9, 1e-9] missing_fields = sorted(set(required_fields) - set(fields)) if len(missing_fields) > 0: - raise RuntimeError('Could not find necessary fields in header: %s' - % (missing_fields,)) + raise RuntimeError( + "Could not find necessary fields in header: %s" % (missing_fields,) + ) handled_fields = set(required_fields) | set(optional_fields) assert len(handled_fields) == len(required_fields) + len(optional_fields) - ignored_fields = sorted(set(fields) - - set(handled_fields) - - {'end/ms'}) + ignored_fields = sorted(set(fields) - set(handled_fields) - {"end/ms"}) if len(ignored_fields) > 0: - warn('Ignoring extra fields in dipole file: %s' % (ignored_fields,)) + warn("Ignoring extra fields in dipole file: %s" % (ignored_fields,)) if len(fields) != data.shape[1]: - raise OSError('More data fields (%s) found than data columns (%s): %s' - % (len(fields), data.shape[1], fields)) + raise OSError( + "More data fields (%s) found than data columns (%s): %s" + % (len(fields), data.shape[1], fields) + ) logger.info("%d dipole(s) found" % len(data)) - if 'end/ms' in fields: - if np.diff(data[:, [fields.index('begin/ms'), - fields.index('end/ms')]], 1, -1).any(): - warn('begin and end fields differed, but only begin will be used ' - 'to store time values') + if "end/ms" in fields: + if np.diff( + data[:, [fields.index("begin/ms"), fields.index("end/ms")]], 1, -1 + ).any(): + warn( + "begin and end fields differed, but only begin will be used " + "to store time values" + ) # Find the correct column in our data array, then scale to proper units idx = [fields.index(field) for field in required_fields] assert len(idx) >= 9 - times = data[:, idx[0]] / 1000. + times = data[:, idx[0]] / 1000.0 pos = 1e-3 * data[:, idx[1:4]] # put data in meters amplitude = data[:, idx[4]] norm = amplitude.copy() @@ -605,36 +733,39 @@ def _read_dipole_text(fname): conf = dict() for field, scale in zip(optional_fields[2:], conf_scales): # confidence if field in fields: - conf[field.split('/')[0]] = scale * data[:, fields.index(field)] + conf[field.split("/")[0]] = scale * data[:, fields.index(field)] return Dipole(times, pos, amplitude, ori, gof, name, conf, khi2, nfree) def _write_dipole_text(fname, dip): - fmt = ' %7.1f %7.1f %8.2f %8.2f %8.2f %8.3f %8.3f %8.3f %8.3f %6.2f' - header = ('# begin end X (mm) Y (mm) Z (mm)' - ' Q(nAm) Qx(nAm) Qy(nAm) Qz(nAm) g/%') - t = dip.times[:, np.newaxis] * 1000. 
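+    # Columns written below, matching ``header``: begin/end times in ms,
+    # position in mm, |Q| and its x/y/z components in nAm, and goodness of
+    # fit in %; khi^2/free and any confidence columns are appended after.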
+ fmt = " %7.1f %7.1f %8.2f %8.2f %8.2f %8.3f %8.3f %8.3f %8.3f %6.2f" + header = ( + "# begin end X (mm) Y (mm) Z (mm)" + " Q(nAm) Qx(nAm) Qy(nAm) Qz(nAm) g/%" + ) + t = dip.times[:, np.newaxis] * 1000.0 gof = dip.gof[:, np.newaxis] amp = 1e9 * dip.amplitude[:, np.newaxis] out = (t, t, dip.pos / 1e-3, amp, dip.ori * amp, gof) # optional fields - fmts = dict(khi2=(' khi^2', ' %8.1f', 1.), - nfree=(' free', ' %5d', 1), - vol=(' vol/mm^3', ' %9.3f', 1e9), - depth=(' depth/mm', ' %9.3f', 1e3), - long=(' long/mm', ' %8.3f', 1e3), - trans=(' trans/mm', ' %9.3f', 1e3), - qlong=(' Qlong/nAm', ' %10.3f', 1e9), - qtrans=(' Qtrans/nAm', ' %11.3f', 1e9), - ) - for key in ('khi2', 'nfree'): + fmts = dict( + khi2=(" khi^2", " %8.1f", 1.0), + nfree=(" free", " %5d", 1), + vol=(" vol/mm^3", " %9.3f", 1e9), + depth=(" depth/mm", " %9.3f", 1e3), + long=(" long/mm", " %8.3f", 1e3), + trans=(" trans/mm", " %9.3f", 1e3), + qlong=(" Qlong/nAm", " %10.3f", 1e9), + qtrans=(" Qtrans/nAm", " %11.3f", 1e9), + ) + for key in ("khi2", "nfree"): data = getattr(dip, key) if data is not None: header += fmts[key][0] fmt += fmts[key][1] out += (data[:, np.newaxis] * fmts[key][2],) - for key in ('vol', 'depth', 'long', 'trans', 'qlong', 'qtrans'): + for key in ("vol", "depth", "long", "trans", "qlong", "qtrans"): data = dip.conf.get(key) if data is not None: header += fmts[key][0] @@ -643,22 +774,23 @@ def _write_dipole_text(fname, dip): out = np.concatenate(out, axis=-1) # NB CoordinateSystem is hard-coded as Head here - with open(fname, 'wb') as fid: - fid.write('# CoordinateSystem "Head"\n'.encode('utf-8')) - fid.write((header + '\n').encode('utf-8')) + with open(fname, "wb") as fid: + fid.write('# CoordinateSystem "Head"\n'.encode("utf-8")) + fid.write((header + "\n").encode("utf-8")) np.savetxt(fid, out, fmt=fmt) if dip.name is not None: - fid.write(('## Name "%s dipoles" Style "Dipoles"' - % dip.name).encode('utf-8')) + fid.write( + ('## Name "%s dipoles" Style "Dipoles"' % dip.name).encode("utf-8") + ) -_BDIP_ERROR_KEYS = ('depth', 'long', 'trans', 'qlong', 'qtrans') +_BDIP_ERROR_KEYS = ("depth", "long", "trans", "qlong", "qtrans") def _read_dipole_bdip(fname): name = None nfree = None - with open(fname, 'rb') as fid: + with open(fname, "rb") as fid: # Which dipole in a multi-dipole set times = list() pos = list() @@ -669,75 +801,77 @@ def _read_dipole_bdip(fname): khi2 = list() has_errors = None while True: - num = np.frombuffer(fid.read(4), '>i4') + num = np.frombuffer(fid.read(4), ">i4") if len(num) == 0: break - times.append(np.frombuffer(fid.read(4), '>f4')[0]) + times.append(np.frombuffer(fid.read(4), ">f4")[0]) fid.read(4) # end fid.read(12) # r0 - pos.append(np.frombuffer(fid.read(12), '>f4')) - Q = np.frombuffer(fid.read(12), '>f4') + pos.append(np.frombuffer(fid.read(12), ">f4")) + Q = np.frombuffer(fid.read(12), ">f4") amplitude.append(np.linalg.norm(Q)) ori.append(Q / amplitude[-1]) - gof.append(100 * np.frombuffer(fid.read(4), '>f4')[0]) - this_has_errors = bool(np.frombuffer(fid.read(4), '>i4')[0]) + gof.append(100 * np.frombuffer(fid.read(4), ">f4")[0]) + this_has_errors = bool(np.frombuffer(fid.read(4), ">i4")[0]) if has_errors is None: has_errors = this_has_errors for key in _BDIP_ERROR_KEYS: conf[key] = list() assert has_errors == this_has_errors fid.read(4) # Noise level used for error computations - limits = np.frombuffer(fid.read(20), '>f4') # error limits + limits = np.frombuffer(fid.read(20), ">f4") # error limits for key, lim in zip(_BDIP_ERROR_KEYS, limits): conf[key].append(lim) 
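+            # Per-dipole record layout, inferred from the reads in this loop
+            # (all big-endian): index (i4); begin/end times (2 x f4); r0, pos
+            # and Q (3 x 3 x f4); gof (f4, stored as a fraction of 1);
+            # has_errors (i4); noise level (f4); 5 error limits (5 x f4); a
+            # 5x5 confidence matrix (25 x f4, skipped); conf volume, khi^2,
+            # prob and total noise estimate (4 x f4) -- 196 bytes per dipole.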
fid.read(100) # (5, 5) fully describes the conf. ellipsoid - conf['vol'].append(np.frombuffer(fid.read(4), '>f4')[0]) - khi2.append(np.frombuffer(fid.read(4), '>f4')[0]) + conf["vol"].append(np.frombuffer(fid.read(4), ">f4")[0]) + khi2.append(np.frombuffer(fid.read(4), ">f4")[0]) fid.read(4) # prob fid.read(4) # total noise estimate return Dipole(times, pos, amplitude, ori, gof, name, conf, khi2, nfree) def _write_dipole_bdip(fname, dip): - with open(fname, 'wb+') as fid: + with open(fname, "wb+") as fid: for ti, t in enumerate(dip.times): - fid.write(np.zeros(1, '>i4').tobytes()) # int dipole - fid.write(np.array([t, 0]).astype('>f4').tobytes()) - fid.write(np.zeros(3, '>f4').tobytes()) # r0 - fid.write(dip.pos[ti].astype('>f4').tobytes()) # pos + fid.write(np.zeros(1, ">i4").tobytes()) # int dipole + fid.write(np.array([t, 0]).astype(">f4").tobytes()) + fid.write(np.zeros(3, ">f4").tobytes()) # r0 + fid.write(dip.pos[ti].astype(">f4").tobytes()) # pos Q = dip.amplitude[ti] * dip.ori[ti] - fid.write(Q.astype('>f4').tobytes()) - fid.write(np.array(dip.gof[ti] / 100., '>f4').tobytes()) + fid.write(Q.astype(">f4").tobytes()) + fid.write(np.array(dip.gof[ti] / 100.0, ">f4").tobytes()) has_errors = int(bool(len(dip.conf))) - fid.write(np.array(has_errors, '>i4').tobytes()) # has_errors - fid.write(np.zeros(1, '>f4').tobytes()) # noise level + fid.write(np.array(has_errors, ">i4").tobytes()) # has_errors + fid.write(np.zeros(1, ">f4").tobytes()) # noise level for key in _BDIP_ERROR_KEYS: - val = dip.conf[key][ti] if key in dip.conf else 0. + val = dip.conf[key][ti] if key in dip.conf else 0.0 assert val.shape == () - fid.write(np.array(val, '>f4').tobytes()) - fid.write(np.zeros(25, '>f4').tobytes()) - conf = dip.conf['vol'][ti] if 'vol' in dip.conf else 0. 
- fid.write(np.array(conf, '>f4').tobytes()) + fid.write(np.array(val, ">f4").tobytes()) + fid.write(np.zeros(25, ">f4").tobytes()) + conf = dip.conf["vol"][ti] if "vol" in dip.conf else 0.0 + fid.write(np.array(conf, ">f4").tobytes()) khi2 = dip.khi2[ti] if dip.khi2 is not None else 0 - fid.write(np.array(khi2, '>f4').tobytes()) - fid.write(np.zeros(1, '>f4').tobytes()) # prob - fid.write(np.zeros(1, '>f4').tobytes()) # total noise est + fid.write(np.array(khi2, ">f4").tobytes()) + fid.write(np.zeros(1, ">f4").tobytes()) # prob + fid.write(np.zeros(1, ">f4").tobytes()) # total noise est # ############################################################################# # Fitting + def _dipole_forwards(*, sensors, fwd_data, whitener, rr, n_jobs=None): """Compute the forward solution and do other nice stuff.""" B = _compute_forwards_meeg( - rr, sensors=sensors, fwd_data=fwd_data, n_jobs=n_jobs, silent=True) + rr, sensors=sensors, fwd_data=fwd_data, n_jobs=n_jobs, silent=True + ) B = np.concatenate(list(B.values()), axis=1) assert np.isfinite(B).all() B_orig = B.copy() # Apply projection and whiten (cov has projections already) _, _, dgemm = _get_ddot_dgemv_dgemm() - B = dgemm(1., B, whitener.T) + B = dgemm(1.0, B, whitener.T) # column normalization doesn't affect our fitting, so skip for now # S = np.sum(B * B, axis=1) # across channels @@ -751,21 +885,30 @@ def _dipole_forwards(*, sensors, fwd_data, whitener, rr, n_jobs=None): @verbose def _make_guesses(surf, grid, exclude, mindist, n_jobs=None, verbose=None): """Make a guess space inside a sphere or BEM surface.""" - if 'rr' in surf: - logger.info('Guess surface (%s) is in %s coordinates' - % (_bem_surf_name[surf['id']], - _coord_frame_name(surf['coord_frame']))) + if "rr" in surf: + logger.info( + "Guess surface (%s) is in %s coordinates" + % (_bem_surf_name[surf["id"]], _coord_frame_name(surf["coord_frame"])) + ) else: - logger.info('Making a spherical guess space with radius %7.1f mm...' - % (1000 * surf['R'])) - logger.info('Filtering (grid = %6.f mm)...' % (1000 * grid)) - src = _make_volume_source_space(surf, grid, exclude, 1000 * mindist, - do_neighbors=False, n_jobs=n_jobs)[0] - assert 'vertno' in src + logger.info( + "Making a spherical guess space with radius %7.1f mm..." + % (1000 * surf["R"]) + ) + logger.info("Filtering (grid = %6.f mm)..." % (1000 * grid)) + src = _make_volume_source_space( + surf, grid, exclude, 1000 * mindist, do_neighbors=False, n_jobs=n_jobs + )[0] + assert "vertno" in src # simplify the result to make things easier later - src = dict(rr=src['rr'][src['vertno']], nn=src['nn'][src['vertno']], - nuse=src['nuse'], coord_frame=src['coord_frame'], - vertno=np.arange(src['nuse']), type='discrete') + src = dict( + rr=src["rr"][src["vertno"]], + nn=src["nn"][src["vertno"]], + nuse=src["nuse"], + coord_frame=src["coord_frame"], + vertno=np.arange(src["nuse"]), + type="discrete", + ) return SourceSpaces([src]) @@ -774,26 +917,26 @@ def _fit_eval(rd, B, B2, *, sensors, fwd_data, whitener, lwork, fwd_svd): if fwd_svd is None: assert sensors is not None fwd = _dipole_forwards( - sensors=sensors, fwd_data=fwd_data, whitener=whitener, - rr=rd[np.newaxis, :])[0] + sensors=sensors, fwd_data=fwd_data, whitener=whitener, rr=rd[np.newaxis, :] + )[0] uu, sing, vv = _repeated_svd(fwd, lwork, overwrite_a=True) else: uu, sing, vv = fwd_svd gof = _dipole_gof(uu, sing, vv, B, B2)[0] # mne-c uses fitness=B2-Bm2, but ours (1-gof) is just a normalized version - return 1. 
- gof + return 1.0 - gof @functools.lru_cache(None) def _get_ddot_dgemv_dgemm(): - return _get_blas_funcs(np.float64, ('dot', 'gemv', 'gemm')) + return _get_blas_funcs(np.float64, ("dot", "gemv", "gemm")) def _dipole_gof(uu, sing, vv, B, B2): """Calculate the goodness of fit from the forward SVD.""" ddot, dgemv, _ = _get_ddot_dgemv_dgemm() - ncomp = 3 if sing[2] / (sing[0] if sing[0] > 0 else 1.) > 0.2 else 2 - one = dgemv(1., vv[:ncomp], B) # np.dot(vv[:ncomp], B) + ncomp = 3 if sing[2] / (sing[0] if sing[0] > 0 else 1.0) > 0.2 else 2 + one = dgemv(1.0, vv[:ncomp], B) # np.dot(vv[:ncomp], B) Bm2 = ddot(one, one) # np.sum(one * one) gof = Bm2 / B2 return gof, one @@ -802,20 +945,21 @@ def _dipole_gof(uu, sing, vv, B, B2): def _fit_Q(*, sensors, fwd_data, whitener, B, B2, B_orig, rd, ori=None): """Fit the dipole moment once the location is known.""" from scipy import linalg - if 'fwd' in fwd_data: + + if "fwd" in fwd_data: # should be a single precomputed "guess" (i.e., fixed position) assert rd is None - fwd = fwd_data['fwd'] + fwd = fwd_data["fwd"] assert fwd.shape[0] == 3 - fwd_orig = fwd_data['fwd_orig'] + fwd_orig = fwd_data["fwd_orig"] assert fwd_orig.shape[0] == 3 - scales = fwd_data['scales'] + scales = fwd_data["scales"] assert scales.shape == (3,) - fwd_svd = fwd_data['fwd_svd'][0] + fwd_svd = fwd_data["fwd_svd"][0] else: fwd, fwd_orig, scales = _dipole_forwards( - sensors=sensors, fwd_data=fwd_data, whitener=whitener, - rr=rd[np.newaxis, :]) + sensors=sensors, fwd_data=fwd_data, whitener=whitener, rr=rd[np.newaxis, :] + ) fwd_svd = None if ori is None: if fwd_svd is None: @@ -838,19 +982,44 @@ def _fit_Q(*, sensors, fwd_data, whitener, B, B2, B_orig, rd, ori=None): return Q, gof, B_residual_noproj, ncomp -def _fit_dipoles(fun, min_dist_to_inner_skull, data, times, guess_rrs, - guess_data, *, sensors, fwd_data, whitener, ori, n_jobs, - rank, rhoend): +def _fit_dipoles( + fun, + min_dist_to_inner_skull, + data, + times, + guess_rrs, + guess_data, + *, + sensors, + fwd_data, + whitener, + ori, + n_jobs, + rank, + rhoend, +): """Fit a single dipole to the given whitened, projected data.""" from scipy.optimize import fmin_cobyla + parallel, p_fun, n_jobs = parallel_func(fun, n_jobs) # parallel over time points res = parallel( p_fun( - min_dist_to_inner_skull, B, t, guess_rrs, guess_data, - sensors=sensors, fwd_data=fwd_data, whitener=whitener, - fmin_cobyla=fmin_cobyla, ori=ori, rank=rank, rhoend=rhoend) - for B, t in zip(data.T, times)) + min_dist_to_inner_skull, + B, + t, + guess_rrs, + guess_data, + sensors=sensors, + fwd_data=fwd_data, + whitener=whitener, + fmin_cobyla=fmin_cobyla, + ori=ori, + rank=rank, + rhoend=rhoend, + ) + for B, t in zip(data.T, times) + ) pos = np.array([r[0] for r in res]) amp = np.array([r[1] for r in res]) ori = np.array([r[2] for r in res]) @@ -858,7 +1027,7 @@ def _fit_dipoles(fun, min_dist_to_inner_skull, data, times, guess_rrs, conf = None if res[0][4] is not None: conf = np.array([r[4] for r in res]) - keys = ['vol', 'depth', 'long', 'trans', 'qlong', 'qtrans'] + keys = ["vol", "depth", "long", "trans", "qlong", "qtrans"] conf = {key: conf[:, ki] for ki, key in enumerate(keys)} khi2 = np.array([r[5] for r in res]) nfree = np.array([r[6] for r in res]) @@ -971,11 +1140,12 @@ def _fit_confidence(*, rd, Q, ori, whitener, fwd_data, sensors): # And then the confidence interval is the diagonal of C, scaled by 1.96 # (for 95% confidence). 
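+    # For reference: 1.96 is the two-sided 95% quantile of the standard
+    # normal, used for the per-parameter intervals, while the 476.379541
+    # used further below is c**3 with c = 7.81 (~ the 95% quantile of chi^2
+    # with 3 dof); the ellipsoid semi-axes are sqrt(c * lambda_i), so
+    # v = 4/3 * pi * prod(sqrt(c * lambda_i))
+    #   = 4/3 * pi * sqrt(c**3 * l1 * l2 * l3).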
from scipy import linalg + direction = np.empty((3, 3)) # The coordinate system has the x axis aligned with the dipole orientation, direction[0] = ori # the z axis through the origin of the sphere model - rvec = rd - fwd_data['inner_skull']['r0'] + rvec = rd - fwd_data["inner_skull"]["r0"] direction[2] = rvec - ori * np.dot(ori, rvec) # orthogonalize direction[2] /= np.linalg.norm(direction[2]) # and the y axis perpendical with these forming a right-handed system. @@ -989,15 +1159,19 @@ def _fit_confidence(*, rd, Q, ori, whitener, fwd_data, sensors): for delta in deltas: this_r = rd[np.newaxis] + delta * direction[ii] fwds.append( - np.dot(Q, _dipole_forwards( - sensors=sensors, fwd_data=fwd_data, - whitener=whitener, rr=this_r)[0])) + np.dot( + Q, + _dipole_forwards( + sensors=sensors, fwd_data=fwd_data, whitener=whitener, rr=this_r + )[0], + ) + ) J[:, ii] = np.diff(fwds, axis=0)[0] / np.diff(deltas)[0] # Get current (Q) deltas in the dipole directions deltas = np.array([-0.01, 0.01]) * np.linalg.norm(Q) this_fwd = _dipole_forwards( - sensors=sensors, fwd_data=fwd_data, whitener=whitener, - rr=rd[np.newaxis])[0] + sensors=sensors, fwd_data=fwd_data, whitener=whitener, rr=rd[np.newaxis] + )[0] for ii in range(3): fwds = [] for delta in deltas: @@ -1018,8 +1192,12 @@ def _fit_confidence(*, rd, Q, ori, whitener, fwd_data, sensors): # The confidence volume of the dipole location is obtained from by # taking the eigenvalues of the upper left submatrix and computing # v = 4π/3 √(c^3 λ1 λ2 λ3) with c = 7.81, or: - vol_conf = 4 * np.pi / 3. * np.sqrt( - 476.379541 * np.prod(linalg.eigh(C[:3, :3], eigvals_only=True))) + vol_conf = ( + 4 + * np.pi + / 3.0 + * np.sqrt(476.379541 * np.prod(linalg.eigh(C[:3, :3], eigvals_only=True))) + ) conf = np.concatenate([conf, [vol_conf]]) # Now we reorder and subselect the proper columns: # vol, depth, long, trans, Qlong, Qtrans (discard Qdepth, assumed zero) @@ -1029,10 +1207,9 @@ def _fit_confidence(*, rd, Q, ori, whitener, fwd_data, sensors): def _surface_constraint(rd, surf, min_dist_to_inner_skull): """Surface fitting constraint.""" - dist = _compute_nearest(surf['rr'], rd[np.newaxis, :], - return_dists=True)[1][0] + dist = _compute_nearest(surf["rr"], rd[np.newaxis, :], return_dists=True)[1][0] if _points_outside_surface(rd[np.newaxis, :], surf, 1)[0]: - dist *= -1. + dist *= -1.0 # Once we know the dipole is below the inner skull, # let's check if its distance to the inner skull is at least # min_dist_to_inner_skull. 
This can be enforced by adding a @@ -1046,45 +1223,82 @@ def _sphere_constraint(rd, r0, R_adj): return R_adj - np.sqrt(np.sum((rd - r0) ** 2)) -def _fit_dipole(min_dist_to_inner_skull, B_orig, t, guess_rrs, - guess_data, *, sensors, fwd_data, whitener, fmin_cobyla, - ori, rank, rhoend): +def _fit_dipole( + min_dist_to_inner_skull, + B_orig, + t, + guess_rrs, + guess_data, + *, + sensors, + fwd_data, + whitener, + fmin_cobyla, + ori, + rank, + rhoend, +): """Fit a single bit of data.""" B = np.dot(whitener, B_orig) # make constraint function to keep the solver within the inner skull - if 'rr' in fwd_data['inner_skull']: # bem - surf = fwd_data['inner_skull'] - constraint = partial(_surface_constraint, surf=surf, - min_dist_to_inner_skull=min_dist_to_inner_skull) + if "rr" in fwd_data["inner_skull"]: # bem + surf = fwd_data["inner_skull"] + constraint = partial( + _surface_constraint, + surf=surf, + min_dist_to_inner_skull=min_dist_to_inner_skull, + ) else: # sphere surf = None constraint = partial( - _sphere_constraint, r0=fwd_data['inner_skull']['r0'], - R_adj=fwd_data['inner_skull']['R'] - min_dist_to_inner_skull) + _sphere_constraint, + r0=fwd_data["inner_skull"]["r0"], + R_adj=fwd_data["inner_skull"]["R"] - min_dist_to_inner_skull, + ) # Find a good starting point (find_best_guess in C) B2 = np.dot(B, B) if B2 == 0: - warn('Zero field found for time %s' % t) + warn("Zero field found for time %s" % t) return np.zeros(3), 0, np.zeros(3), 0, B - idx = np.argmin([ - _fit_eval(guess_rrs[[fi], :], B, B2, fwd_svd=fwd_svd, - fwd_data=None, sensors=None, whitener=None, lwork=None) - for fi, fwd_svd in enumerate(guess_data['fwd_svd'])]) + idx = np.argmin( + [ + _fit_eval( + guess_rrs[[fi], :], + B, + B2, + fwd_svd=fwd_svd, + fwd_data=None, + sensors=None, + whitener=None, + lwork=None, + ) + for fi, fwd_svd in enumerate(guess_data["fwd_svd"]) + ] + ) x0 = guess_rrs[idx] lwork = _svd_lwork((3, B.shape[0])) - fun = partial(_fit_eval, B=B, B2=B2, fwd_data=fwd_data, whitener=whitener, - lwork=lwork, sensors=sensors, fwd_svd=None) + fun = partial( + _fit_eval, + B=B, + B2=B2, + fwd_data=fwd_data, + whitener=whitener, + lwork=lwork, + sensors=sensors, + fwd_svd=None, + ) # Tested minimizers: # Simplex, BFGS, CG, COBYLA, L-BFGS-B, Powell, SLSQP, TNC # Several were similar, but COBYLA won for having a handy constraint # function we can use to ensure we stay inside the inner skull / # smallest sphere - rd_final = fmin_cobyla(fun, x0, (constraint,), consargs=(), - rhobeg=5e-2, rhoend=rhoend, disp=False) + rd_final = fmin_cobyla( + fun, x0, (constraint,), consargs=(), rhobeg=5e-2, rhoend=rhoend, disp=False + ) # simplex = _make_tetra_simplex() + x0 # _simplex_minimize(simplex, 1e-4, 2e-4, fun) @@ -1092,45 +1306,71 @@ def _fit_dipole(min_dist_to_inner_skull, B_orig, t, guess_rrs, # Compute the dipole moment at the final point Q, gof, residual_noproj, n_comp = _fit_Q( - sensors=sensors, fwd_data=fwd_data, whitener=whitener, B=B, B2=B2, - B_orig=B_orig, rd=rd_final, ori=ori) + sensors=sensors, + fwd_data=fwd_data, + whitener=whitener, + B=B, + B2=B2, + B_orig=B_orig, + rd=rd_final, + ori=ori, + ) khi2 = (1 - gof) * B2 nfree = rank - n_comp amp = np.sqrt(np.dot(Q, Q)) - norm = 1. if amp == 0. else amp + norm = 1.0 if amp == 0.0 else amp ori = Q / norm conf = _fit_confidence( - sensors=sensors, rd=rd_final, Q=Q, ori=ori, whitener=whitener, - fwd_data=fwd_data) + sensors=sensors, rd=rd_final, Q=Q, ori=ori, whitener=whitener, fwd_data=fwd_data + ) - msg = '---- Fitted : %7.1f ms' % (1000. 
* t) + msg = "---- Fitted : %7.1f ms" % (1000.0 * t) if surf is not None: dist_to_inner_skull = _compute_nearest( - surf['rr'], rd_final[np.newaxis, :], return_dists=True)[1][0] - msg += (", distance to inner skull : %2.4f mm" - % (dist_to_inner_skull * 1000.)) + surf["rr"], rd_final[np.newaxis, :], return_dists=True + )[1][0] + msg += ", distance to inner skull : %2.4f mm" % (dist_to_inner_skull * 1000.0) logger.info(msg) return rd_final, amp, ori, gof, conf, khi2, nfree, residual_noproj -def _fit_dipole_fixed(min_dist_to_inner_skull, B_orig, t, guess_rrs, - guess_data, *, sensors, fwd_data, whitener, - fmin_cobyla, ori, rank, rhoend): +def _fit_dipole_fixed( + min_dist_to_inner_skull, + B_orig, + t, + guess_rrs, + guess_data, + *, + sensors, + fwd_data, + whitener, + fmin_cobyla, + ori, + rank, + rhoend, +): """Fit a data using a fixed position.""" B = np.dot(whitener, B_orig) B2 = np.dot(B, B) if B2 == 0: - warn('Zero field found for time %s' % t) + warn("Zero field found for time %s" % t) return np.zeros(3), 0, np.zeros(3), 0, np.zeros(6) # Compute the dipole moment Q, gof, residual_noproj = _fit_Q( - fwd_data=guess_data, whitener=whitener, B=B, B2=B2, B_orig=B_orig, - sensors=sensors, rd=None, ori=ori)[:3] + fwd_data=guess_data, + whitener=whitener, + B=B, + B2=B2, + B_orig=B_orig, + sensors=sensors, + rd=None, + ori=ori, + )[:3] if ori is None: amp = np.sqrt(np.dot(Q, Q)) - norm = 1. if amp == 0. else amp + norm = 1.0 if amp == 0.0 else amp ori = Q / norm else: amp = np.dot(Q, ori) @@ -1143,9 +1383,20 @@ def _fit_dipole_fixed(min_dist_to_inner_skull, B_orig, t, guess_rrs, @verbose -def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None, - pos=None, ori=None, rank=None, accuracy='normal', tol=5e-5, - verbose=None): +def fit_dipole( + evoked, + cov, + bem, + trans=None, + min_dist=5.0, + n_jobs=None, + pos=None, + ori=None, + rank=None, + accuracy="normal", + tol=5e-5, + verbose=None, +): """Fit a dipole. Parameters @@ -1219,76 +1470,84 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None, .. versionadded:: 0.9.0 """ from scipy import linalg + # This could eventually be adapted to work with other inputs, these # are what is needed: evoked = evoked.copy() - _validate_type(accuracy, str, 'accuracy') - _check_option('accuracy', accuracy, ('accurate', 'normal')) + _validate_type(accuracy, str, "accuracy") + _check_option("accuracy", accuracy, ("accurate", "normal")) # Determine if a list of projectors has an average EEG ref if _needs_eeg_average_ref_proj(evoked.info): - raise ValueError('EEG average reference is mandatory for dipole ' - 'fitting.') + raise ValueError("EEG average reference is mandatory for dipole " "fitting.") if min_dist < 0: - raise ValueError('min_dist should be positive. Got %s' % min_dist) + raise ValueError("min_dist should be positive. Got %s" % min_dist) if ori is not None and pos is None: - raise ValueError('pos must be provided if ori is not None') + raise ValueError("pos must be provided if ori is not None") data = evoked.data if not np.isfinite(data).all(): - raise ValueError('Evoked data must be finite') + raise ValueError("Evoked data must be finite") info = evoked.info times = evoked.times.copy() comment = evoked.comment # Convert the min_dist to meters - min_dist_to_inner_skull = min_dist / 1000. 
+ min_dist_to_inner_skull = min_dist / 1000.0 del min_dist # Figure out our inputs - neeg = len(pick_types(info, meg=False, eeg=True, ref_meg=False, - exclude=[])) + neeg = len(pick_types(info, meg=False, eeg=True, ref_meg=False, exclude=[])) if isinstance(bem, str): bem_extra = bem else: bem_extra = repr(bem) - logger.info('BEM : %s' % bem_extra) + logger.info("BEM : %s" % bem_extra) mri_head_t, trans = _get_trans(trans) - logger.info('MRI transform : %s' % trans) + logger.info("MRI transform : %s" % trans) safe_false = _verbose_safe_false() bem = _setup_bem(bem, bem_extra, neeg, mri_head_t, verbose=safe_false) - if not bem['is_sphere']: + if not bem["is_sphere"]: # Find the best-fitting sphere - inner_skull = _bem_find_surface(bem, 'inner_skull') + inner_skull = _bem_find_surface(bem, "inner_skull") inner_skull = inner_skull.copy() - R, r0 = _fit_sphere(inner_skull['rr'], disp=False) + R, r0 = _fit_sphere(inner_skull["rr"], disp=False) # r0 back to head frame for logging - r0 = apply_trans(mri_head_t['trans'], r0[np.newaxis, :])[0] - inner_skull['r0'] = r0 - logger.info('Head origin : ' - '%6.1f %6.1f %6.1f mm rad = %6.1f mm.' - % (1000 * r0[0], 1000 * r0[1], 1000 * r0[2], 1000 * R)) + r0 = apply_trans(mri_head_t["trans"], r0[np.newaxis, :])[0] + inner_skull["r0"] = r0 + logger.info( + "Head origin : " + "%6.1f %6.1f %6.1f mm rad = %6.1f mm." + % (1000 * r0[0], 1000 * r0[1], 1000 * r0[2], 1000 * R) + ) del R, r0 else: - r0 = bem['r0'] - if len(bem.get('layers', [])) > 0: - R = bem['layers'][0]['rad'] - kind = 'rad' + r0 = bem["r0"] + if len(bem.get("layers", [])) > 0: + R = bem["layers"][0]["rad"] + kind = "rad" else: # MEG-only # Use the minimum distance to the MEG sensors as the radius then - R = np.dot(np.linalg.inv(info['dev_head_t']['trans']), - np.hstack([r0, [1.]]))[:3] # r0 -> device - R = R - [info['chs'][pick]['loc'][:3] - for pick in pick_types(info, meg=True, exclude=[])] + R = np.dot( + np.linalg.inv(info["dev_head_t"]["trans"]), np.hstack([r0, [1.0]]) + )[ + :3 + ] # r0 -> device + R = R - [ + info["chs"][pick]["loc"][:3] + for pick in pick_types(info, meg=True, exclude=[]) + ] if len(R) == 0: - raise RuntimeError('No MEG channels found, but MEG-only ' - 'sphere model used') + raise RuntimeError( + "No MEG channels found, but MEG-only " "sphere model used" + ) R = np.min(np.sqrt(np.sum(R * R, axis=1))) # use dist to sensors - kind = 'max_rad' - logger.info('Sphere model : origin at (% 7.2f % 7.2f % 7.2f) mm, ' - '%s = %6.1f mm' - % (1000 * r0[0], 1000 * r0[1], 1000 * r0[2], kind, R)) + kind = "max_rad" + logger.info( + "Sphere model : origin at (% 7.2f % 7.2f % 7.2f) mm, " + "%s = %6.1f mm" % (1000 * r0[0], 1000 * r0[1], 1000 * r0[2], kind, R) + ) inner_skull = dict(R=R, r0=r0) # NB sphere model defined in head frame del R, r0 @@ -1297,23 +1556,22 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None, fixed_position = True pos = np.array(pos, float) if pos.shape != (3,): - raise ValueError('pos must be None or a 3-element array-like,' - ' got %s' % (pos,)) - logger.info('Fixed position : %6.1f %6.1f %6.1f mm' - % tuple(1000 * pos)) + raise ValueError( + "pos must be None or a 3-element array-like," " got %s" % (pos,) + ) + logger.info("Fixed position : %6.1f %6.1f %6.1f mm" % tuple(1000 * pos)) if ori is not None: ori = np.array(ori, float) if ori.shape != (3,): - raise ValueError('oris must be None or a 3-element array-like,' - ' got %s' % (ori,)) + raise ValueError( + "oris must be None or a 3-element array-like," " got %s" % (ori,) + ) norm = 
np.sqrt(np.sum(ori * ori)) if not np.isclose(norm, 1): - raise ValueError('ori must be a unit vector, got length %s' - % (norm,)) - logger.info('Fixed orientation : %6.4f %6.4f %6.4f mm' - % tuple(ori)) + raise ValueError("ori must be a unit vector, got length %s" % (norm,)) + logger.info("Fixed orientation : %6.4f %6.4f %6.4f mm" % tuple(ori)) else: - logger.info('Free orientation : ') + logger.info("Free orientation : ") fit_n_jobs = 1 # only use 1 job to do the guess fitting else: fixed_position = False @@ -1323,39 +1581,37 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None, guess_mindist = max(0.005, min_dist_to_inner_skull) guess_exclude = 0.02 - logger.info('Guess grid : %6.1f mm' % (1000 * guess_grid,)) + logger.info("Guess grid : %6.1f mm" % (1000 * guess_grid,)) if guess_mindist > 0.0: - logger.info('Guess mindist : %6.1f mm' - % (1000 * guess_mindist,)) + logger.info("Guess mindist : %6.1f mm" % (1000 * guess_mindist,)) if guess_exclude > 0: - logger.info('Guess exclude : %6.1f mm' - % (1000 * guess_exclude,)) - logger.info(f'Using {accuracy} MEG coil definitions.') + logger.info("Guess exclude : %6.1f mm" % (1000 * guess_exclude,)) + logger.info(f"Using {accuracy} MEG coil definitions.") fit_n_jobs = n_jobs cov = _ensure_cov(cov) - logger.info('') + logger.info("") _print_coord_trans(mri_head_t) - _print_coord_trans(info['dev_head_t']) - logger.info('%d bad channels total' % len(info['bads'])) + _print_coord_trans(info["dev_head_t"]) + logger.info("%d bad channels total" % len(info["bads"])) # Forward model setup (setup_forward_model from setup.c) ch_types = evoked.get_channel_types() sensors = dict() - if 'grad' in ch_types or 'mag' in ch_types: - sensors['meg'] = _prep_meg_channels( - info, exclude='bads', accuracy=accuracy, verbose=verbose) - if 'eeg' in ch_types: - sensors['eeg'] = _prep_eeg_channels( - info, exclude='bads', verbose=verbose) + if "grad" in ch_types or "mag" in ch_types: + sensors["meg"] = _prep_meg_channels( + info, exclude="bads", accuracy=accuracy, verbose=verbose + ) + if "eeg" in ch_types: + sensors["eeg"] = _prep_eeg_channels(info, exclude="bads", verbose=verbose) # Ensure that MEG and/or EEG channels are present if len(sensors) == 0: - raise RuntimeError('No MEG or EEG channels found.') + raise RuntimeError("No MEG or EEG channels found.") # Whitener for the data - logger.info('Decomposing the sensor noise covariance matrix...') + logger.info("Decomposing the sensor noise covariance matrix...") picks = pick_types(info, meg=True, eeg=True, ref_meg=False) # In case we want to more closely match MNE-C for debugging: @@ -1369,63 +1625,85 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None, # whitener[nzero, nzero] = 1.0 / np.sqrt(cov['eig'][nzero]) # whitener = np.dot(whitener, cov['eigvec']) - whitener, _, rank = compute_whitener(cov, info, picks=picks, - rank=rank, return_rank=True) + whitener, _, rank = compute_whitener( + cov, info, picks=picks, rank=rank, return_rank=True + ) # Proceed to computing the fits (make_guess_data) if fixed_position: guess_src = dict(nuse=1, rr=pos[np.newaxis], inuse=np.array([True])) - logger.info('Compute forward for dipole location...') + logger.info("Compute forward for dipole location...") else: - logger.info('\n---- Computing the forward solution for the guesses...') - guess_src = _make_guesses(inner_skull, guess_grid, guess_exclude, - guess_mindist, n_jobs=n_jobs)[0] + logger.info("\n---- Computing the forward solution for the guesses...") + guess_src = _make_guesses( + 
inner_skull, guess_grid, guess_exclude, guess_mindist, n_jobs=n_jobs + )[0] # grid coordinates go from mri to head frame - transform_surface_to(guess_src, 'head', mri_head_t) - logger.info('Go through all guess source locations...') + transform_surface_to(guess_src, "head", mri_head_t) + logger.info("Go through all guess source locations...") # inner_skull goes from mri to head frame - if 'rr' in inner_skull: - transform_surface_to(inner_skull, 'head', mri_head_t) + if "rr" in inner_skull: + transform_surface_to(inner_skull, "head", mri_head_t) if fixed_position: - if 'rr' in inner_skull: - check = _surface_constraint(pos, inner_skull, - min_dist_to_inner_skull) + if "rr" in inner_skull: + check = _surface_constraint(pos, inner_skull, min_dist_to_inner_skull) else: check = _sphere_constraint( - pos, inner_skull['r0'], - R_adj=inner_skull['R'] - min_dist_to_inner_skull) + pos, inner_skull["r0"], R_adj=inner_skull["R"] - min_dist_to_inner_skull + ) if check <= 0: - raise ValueError('fixed position is %0.1fmm outside the inner ' - 'skull boundary' % (-1000 * check,)) + raise ValueError( + "fixed position is %0.1fmm outside the inner " + "skull boundary" % (-1000 * check,) + ) # C code computes guesses w/sphere model for speed, don't bother here fwd_data = _prep_field_computation( - guess_src['rr'], sensors=sensors, bem=bem, n_jobs=n_jobs, - verbose=safe_false) - fwd_data['inner_skull'] = inner_skull + guess_src["rr"], sensors=sensors, bem=bem, n_jobs=n_jobs, verbose=safe_false + ) + fwd_data["inner_skull"] = inner_skull guess_fwd, guess_fwd_orig, guess_fwd_scales = _dipole_forwards( - sensors=sensors, fwd_data=fwd_data, whitener=whitener, - rr=guess_src['rr'], n_jobs=fit_n_jobs) + sensors=sensors, + fwd_data=fwd_data, + whitener=whitener, + rr=guess_src["rr"], + n_jobs=fit_n_jobs, + ) # decompose ahead of time - guess_fwd_svd = [linalg.svd(fwd, full_matrices=False) - for fwd in np.array_split(guess_fwd, - len(guess_src['rr']))] - guess_data = dict(fwd=guess_fwd, fwd_svd=guess_fwd_svd, - fwd_orig=guess_fwd_orig, scales=guess_fwd_scales) + guess_fwd_svd = [ + linalg.svd(fwd, full_matrices=False) + for fwd in np.array_split(guess_fwd, len(guess_src["rr"])) + ] + guess_data = dict( + fwd=guess_fwd, + fwd_svd=guess_fwd_svd, + fwd_orig=guess_fwd_orig, + scales=guess_fwd_scales, + ) del guess_fwd, guess_fwd_svd, guess_fwd_orig, guess_fwd_scales # destroyed - logger.info('[done %d source%s]' % (guess_src['nuse'], - _pl(guess_src['nuse']))) + logger.info("[done %d source%s]" % (guess_src["nuse"], _pl(guess_src["nuse"]))) # Do actual fits data = data[picks] - ch_names = [info['ch_names'][p] for p in picks] - proj_op = make_projector(info['projs'], ch_names, info['bads'])[0] + ch_names = [info["ch_names"][p] for p in picks] + proj_op = make_projector(info["projs"], ch_names, info["bads"])[0] fun = _fit_dipole_fixed if fixed_position else _fit_dipole out = _fit_dipoles( - fun, min_dist_to_inner_skull, data, times, guess_src['rr'], - guess_data, sensors=sensors, fwd_data=fwd_data, whitener=whitener, - ori=ori, n_jobs=n_jobs, rank=rank, rhoend=tol) + fun, + min_dist_to_inner_skull, + data, + times, + guess_src["rr"], + guess_data, + sensors=sensors, + fwd_data=fwd_data, + whitener=whitener, + ori=ori, + n_jobs=n_jobs, + rank=rank, + rhoend=tol, + ) assert len(out) == 8 if fixed_position and ori is not None: # DipoleFixed @@ -1433,38 +1711,66 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=None, out_info = deepcopy(info) loc = np.concatenate([pos, ori, np.zeros(6)]) out_info._unlocked = 
True - out_info['chs'] = [ - dict(ch_name='dip 01', loc=loc, kind=FIFF.FIFFV_DIPOLE_WAVE, - coord_frame=FIFF.FIFFV_COORD_UNKNOWN, unit=FIFF.FIFF_UNIT_AM, - coil_type=FIFF.FIFFV_COIL_DIPOLE, - unit_mul=0, range=1, cal=1., scanno=1, logno=1), - dict(ch_name='goodness', loc=np.full(12, np.nan), - kind=FIFF.FIFFV_GOODNESS_FIT, unit=FIFF.FIFF_UNIT_AM, - coord_frame=FIFF.FIFFV_COORD_UNKNOWN, - coil_type=FIFF.FIFFV_COIL_NONE, - unit_mul=0, range=1., cal=1., scanno=2, logno=100)] - for key in ['hpi_meas', 'hpi_results', 'projs']: + out_info["chs"] = [ + dict( + ch_name="dip 01", + loc=loc, + kind=FIFF.FIFFV_DIPOLE_WAVE, + coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + unit=FIFF.FIFF_UNIT_AM, + coil_type=FIFF.FIFFV_COIL_DIPOLE, + unit_mul=0, + range=1, + cal=1.0, + scanno=1, + logno=1, + ), + dict( + ch_name="goodness", + loc=np.full(12, np.nan), + kind=FIFF.FIFFV_GOODNESS_FIT, + unit=FIFF.FIFF_UNIT_AM, + coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + coil_type=FIFF.FIFFV_COIL_NONE, + unit_mul=0, + range=1.0, + cal=1.0, + scanno=2, + logno=100, + ), + ] + for key in ["hpi_meas", "hpi_results", "projs"]: out_info[key] = list() - for key in ['acq_pars', 'acq_stim', 'description', 'dig', - 'experimenter', 'hpi_subsystem', 'proj_id', 'proj_name', - 'subject_info']: + for key in [ + "acq_pars", + "acq_stim", + "description", + "dig", + "experimenter", + "hpi_subsystem", + "proj_id", + "proj_name", + "subject_info", + ]: out_info[key] = None out_info._unlocked = False - out_info['bads'] = [] + out_info["bads"] = [] out_info._update_redundant() out_info._check_consistency() - dipoles = DipoleFixed(out_info, data, times, evoked.nave, - evoked._aspect_kind, comment=comment) + dipoles = DipoleFixed( + out_info, data, times, evoked.nave, evoked._aspect_kind, comment=comment + ) else: - dipoles = Dipole(times, out[0], out[1], out[2], out[3], comment, - out[4], out[5], out[6]) + dipoles = Dipole( + times, out[0], out[1], out[2], out[3], comment, out[4], out[5], out[6] + ) residual = evoked.copy().apply_proj() # set the projs active residual.data[picks] = np.dot(proj_op, out[-1]) - logger.info('%d time points fitted' % len(dipoles.times)) + logger.info("%d time points fitted" % len(dipoles.times)) return dipoles, residual -def get_phantom_dipoles(kind='vectorview'): +def get_phantom_dipoles(kind="vectorview"): """Get standard phantom dipole locations and orientations. Parameters @@ -1493,8 +1799,8 @@ def get_phantom_dipoles(kind='vectorview'): The Elekta phantoms have a radius of 79.5mm, and HPI coil locations in the XY-plane at the axis extrema (e.g., (79.5, 0), (0, -79.5), ...). """ - _check_option('kind', kind, ['vectorview', 'otaniemi']) - if kind == 'vectorview': + _check_option("kind", kind, ["vectorview", "otaniemi"]) + if kind == "vectorview": # these values were pulled from a scanned image provided by # Elekta folks a = np.array([59.7, 48.6, 35.8, 24.8, 37.2, 27.5, 15.8, 7.9]) @@ -1505,7 +1811,7 @@ def get_phantom_dipoles(kind='vectorview'): d = [44.4, 34.0, 21.6, 12.7, 62.4, 51.5, 39.1, 27.9] z = np.concatenate((c, c, d, d)) signs = ([1, -1] * 4 + [-1, 1] * 4) * 2 - elif kind == 'otaniemi': + elif kind == "otaniemi": # these values were pulled from an Neuromag manual # (NM20456A, 13.7.1999, p.65) a = np.array([56.3, 47.6, 39.0, 30.3]) @@ -1515,7 +1821,7 @@ def get_phantom_dipoles(kind='vectorview'): y = np.concatenate((c, c, -a, -b, c, c, b, a)) z = np.concatenate((b, a, b, a, b, a, a, b)) signs = [-1] * 8 + [1] * 16 + [-1] * 8 - pos = np.vstack((x, y, z)).T / 1000. 
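# Standalone NumPy sketch (illustrative values, not MNE source) of the
# tangential-orientation trick used just below in get_phantom_dipoles: for a
# position lying in the XZ or YZ plane, reversing the two in-plane
# components, normalizing, and flipping one sign rotates the in-plane vector
# by 90 degrees, giving a unit orientation tangential to the sphere there.
import numpy as np

pos = np.array([0.0, 59.7, 44.4])  # y, z are the in-plane components
idx = np.setdiff1d(np.arange(3), np.where(pos == 0)[0][0])
ori = np.zeros(3)
ori[idx] = (pos[idx][::-1] / np.linalg.norm(pos[idx])) * [1, -1]
assert np.isclose(np.dot(ori, pos), 0.0)  # perpendicular to the radial direction
assert np.isclose(np.linalg.norm(ori), 1.0)  # unit length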
+ pos = np.vstack((x, y, z)).T / 1000.0 # Locs are always in XZ or YZ, and so are the oris. The oris are # also in the same plane and tangential, so it's easy to determine # the orientation. @@ -1525,8 +1831,7 @@ def get_phantom_dipoles(kind='vectorview'): idx = np.where(this_pos == 0)[0] # assert len(idx) == 1 idx = np.setdiff1d(np.arange(3), idx[0]) - this_ori[idx] = (this_pos[idx][::-1] / - np.linalg.norm(this_pos[idx])) * [1, -1] + this_ori[idx] = (this_pos[idx][::-1] / np.linalg.norm(this_pos[idx])) * [1, -1] this_ori *= signs[pi] # Now we have this quality, which we could uncomment to # double-check: @@ -1548,6 +1853,11 @@ def _concatenate_dipoles(dipoles): ori.append(dipole.ori) gof.append(dipole.gof) - return Dipole(np.concatenate(times), np.concatenate(pos), - np.concatenate(amplitude), np.concatenate(ori), - np.concatenate(gof), name=None) + return Dipole( + np.concatenate(times), + np.concatenate(pos), + np.concatenate(amplitude), + np.concatenate(ori), + np.concatenate(gof), + name=None, + ) diff --git a/mne/epochs.py b/mne/epochs.py index 8a9e83d22d9..050d0cfaec4 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -19,66 +19,105 @@ import numpy as np from .io.utils import _construct_bids_filename -from .io.write import (start_and_end_file, start_block, end_block, - write_int, write_float, write_float_matrix, - write_double_matrix, write_complex_float_matrix, - write_complex_double_matrix, write_id, write_string, - _get_split_size, _NEXT_FILE_BUFFER, INT32_MAX) -from .io.meas_info import (read_meas_info, write_meas_info, - _ensure_infos_match, ContainsMixin) +from .io.write import ( + start_and_end_file, + start_block, + end_block, + write_int, + write_float, + write_float_matrix, + write_double_matrix, + write_complex_float_matrix, + write_complex_double_matrix, + write_id, + write_string, + _get_split_size, + _NEXT_FILE_BUFFER, + INT32_MAX, +) +from .io.meas_info import ( + read_meas_info, + write_meas_info, + _ensure_infos_match, + ContainsMixin, +) from .io.open import fiff_open, _get_next_fname from .io.tree import dir_tree_find from .io.tag import read_tag, read_tag_info from .io.constants import FIFF from .io.fiff.raw import _get_fname_rep -from .io.pick import (channel_indices_by_type, channel_type, - pick_channels, pick_info, _pick_data_channels, - _DATA_CH_TYPES_SPLIT, _picks_to_idx) +from .io.pick import ( + channel_indices_by_type, + channel_type, + pick_channels, + pick_info, + _pick_data_channels, + _DATA_CH_TYPES_SPLIT, + _picks_to_idx, +) from .io.proj import setup_proj, ProjMixin from .io.base import BaseRaw, TimeMixin, _get_ch_factors from .bem import _check_origin from .evoked import EvokedArray from .baseline import rescale, _log_rescale, _check_baseline -from .channels.channels import (UpdateChannelsMixin, - SetChannelsMixin, InterpolationMixin) +from .channels.channels import UpdateChannelsMixin, SetChannelsMixin, InterpolationMixin from .filter import detrend, FilterMixin, _check_fun from .parallel import parallel_func -from .event import (_read_events_fif, make_fixed_length_events, - match_event_names) +from .event import _read_events_fif, make_fixed_length_events, match_event_names from .fixes import rng_uniform -from .time_frequency.spectrum import (EpochsSpectrum, SpectrumMixin, - _validate_method) -from .viz import (plot_epochs, plot_epochs_image, - plot_topo_image_epochs, plot_drop_log) -from .utils import (_check_fname, check_fname, logger, verbose, repr_html, - check_random_state, warn, _pl, - sizeof_fmt, SizeMixin, copy_function_doc_to_method_doc, 
- _check_pandas_installed, - _check_preload, GetEpochsMixin, - _prepare_read_metadata, _prepare_write_metadata, - _check_event_id, _gen_events, _check_option, - _check_combine, _build_data_frame, - _check_pandas_index_arguments, _convert_times, - _scale_dataframe_data, _check_time_format, object_size, - _on_missing, _validate_type, _ensure_events, - _path_like) +from .time_frequency.spectrum import EpochsSpectrum, SpectrumMixin, _validate_method +from .viz import plot_epochs, plot_epochs_image, plot_topo_image_epochs, plot_drop_log +from .utils import ( + _check_fname, + check_fname, + logger, + verbose, + repr_html, + check_random_state, + warn, + _pl, + sizeof_fmt, + SizeMixin, + copy_function_doc_to_method_doc, + _check_pandas_installed, + _check_preload, + GetEpochsMixin, + _prepare_read_metadata, + _prepare_write_metadata, + _check_event_id, + _gen_events, + _check_option, + _check_combine, + _build_data_frame, + _check_pandas_index_arguments, + _convert_times, + _scale_dataframe_data, + _check_time_format, + object_size, + _on_missing, + _validate_type, + _ensure_events, + _path_like, +) from .utils.docs import fill_doc -from .annotations import (_write_annotations, _read_annotations_fif, - EpochAnnotationsMixin) +from .annotations import ( + _write_annotations, + _read_annotations_fif, + EpochAnnotationsMixin, +) def _pack_reject_params(epochs): reject_params = dict() - for key in ('reject', 'flat', 'reject_tmin', 'reject_tmax'): + for key in ("reject", "flat", "reject_tmin", "reject_tmax"): val = getattr(epochs, key, None) if val is not None: reject_params[key] = val return reject_params -def _save_split(epochs, fname, part_idx, n_parts, fmt, split_naming, - overwrite): +def _save_split(epochs, fname, part_idx, n_parts, fmt, split_naming, overwrite): """Split epochs. 
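# Minimal sketch of the 'neuromag' split naming that _save_split applies a
# few lines below: part 0 keeps the original name, later parts insert
# '-<idx>' before the extension. The 'bids' branch delegates to
# _construct_bids_filename, whose pattern is not shown in this patch, so it
# is omitted here. The helper name is hypothetical.
import os.path as op

def _neuromag_split_name(fname, part_idx):
    base, ext = op.splitext(fname)
    return fname if part_idx == 0 else "%s-%d%s" % (base, part_idx, ext)

assert _neuromag_split_name("sample-epo.fif", 0) == "sample-epo.fif"
assert _neuromag_split_name("sample-epo.fif", 2) == "sample-epo-2.fif"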
Anything new added to this function also needs to be added to @@ -87,22 +126,22 @@ def _save_split(epochs, fname, part_idx, n_parts, fmt, split_naming, # insert index in filename base, ext = op.splitext(fname) if part_idx > 0: - if split_naming == 'neuromag': - fname = '%s-%d%s' % (base, part_idx, ext) + if split_naming == "neuromag": + fname = "%s-%d%s" % (base, part_idx, ext) else: - assert split_naming == 'bids' - fname = _construct_bids_filename(base, ext, part_idx, - validate=False) + assert split_naming == "bids" + fname = _construct_bids_filename(base, ext, part_idx, validate=False) _check_fname(fname, overwrite=overwrite) next_fname = None if part_idx < n_parts - 1: - if split_naming == 'neuromag': - next_fname = '%s-%d%s' % (base, part_idx + 1, ext) + if split_naming == "neuromag": + next_fname = "%s-%d%s" % (base, part_idx + 1, ext) else: - assert split_naming == 'bids' - next_fname = _construct_bids_filename(base, ext, part_idx + 1, - validate=False) + assert split_naming == "bids" + next_fname = _construct_bids_filename( + base, ext, part_idx + 1, validate=False + ) next_idx = part_idx + 1 else: next_idx = None @@ -113,12 +152,12 @@ def _save_split(epochs, fname, part_idx, n_parts, fmt, split_naming, def _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx): info = epochs.info - meas_id = info['meas_id'] + meas_id = info["meas_id"] start_block(fid, FIFF.FIFFB_MEAS) write_id(fid, FIFF.FIFF_BLOCK_ID) - if info['meas_id'] is not None: - write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info['meas_id']) + if info["meas_id"] is not None: + write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info["meas_id"]) # Write measurement info write_meas_info(fid, info) @@ -130,21 +169,21 @@ def _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx): # write events out after getting data to ensure bad events are dropped data = epochs.get_data() - _check_option('fmt', fmt, ['single', 'double']) + _check_option("fmt", fmt, ["single", "double"]) if np.iscomplexobj(data): - if fmt == 'single': + if fmt == "single": write_function = write_complex_float_matrix - elif fmt == 'double': + elif fmt == "double": write_function = write_complex_double_matrix else: - if fmt == 'single': + if fmt == "single": write_function = write_float_matrix - elif fmt == 'double': + elif fmt == "double": write_function = write_double_matrix # Epoch annotations are written if there are any - annotations = getattr(epochs, 'annotations', []) + annotations = getattr(epochs, "annotations", []) if annotations is not None and len(annotations): _write_annotations(fid, annotations) @@ -162,7 +201,7 @@ def _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx): end_block(fid, FIFF.FIFFB_MNE_METADATA) # First and last sample - first = int(round(epochs.tmin * info['sfreq'])) # round just to be safe + first = int(round(epochs.tmin * info["sfreq"])) # round just to be safe last = first + len(epochs.times) - 1 write_int(fid, FIFF.FIFF_FIRST_SAMPLE, first) write_int(fid, FIFF.FIFF_LAST_SAMPLE, last) @@ -177,10 +216,9 @@ def _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx): write_float(fid, FIFF.FIFF_MNE_BASELINE_MAX, bmax) # The epochs itself - decal = np.empty(info['nchan']) - for k in range(info['nchan']): - decal[k] = 1.0 / (info['chs'][k]['cal'] * - info['chs'][k].get('scale', 1.0)) + decal = np.empty(info["nchan"]) + for k in range(info["nchan"]): + decal[k] = 1.0 / (info["chs"][k]["cal"] * info["chs"][k].get("scale", 1.0)) data *= decal[np.newaxis, :, np.newaxis] @@ -189,16 +227,13 @@ def _save_part(fid, epochs, fmt, 
n_parts, next_fname, next_idx): # undo modifications to data data /= decal[np.newaxis, :, np.newaxis] - write_string(fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG, - json.dumps(epochs.drop_log)) + write_string(fid, FIFF.FIFF_MNE_EPOCHS_DROP_LOG, json.dumps(epochs.drop_log)) reject_params = _pack_reject_params(epochs) if reject_params: - write_string(fid, FIFF.FIFF_MNE_EPOCHS_REJECT_FLAT, - json.dumps(reject_params)) + write_string(fid, FIFF.FIFF_MNE_EPOCHS_REJECT_FLAT, json.dumps(reject_params)) - write_int(fid, FIFF.FIFF_MNE_EPOCHS_SELECTION, - epochs.selection) + write_int(fid, FIFF.FIFF_MNE_EPOCHS_SELECTION, epochs.selection) # And now write the next file info in case epochs are split on disk if next_fname is not None and n_parts > 1: @@ -216,7 +251,7 @@ def _save_part(fid, epochs, fmt, n_parts, next_fname, next_idx): def _event_id_string(event_id): - return ';'.join([k + ':' + str(v) for k, v in event_id.items()]) + return ";".join([k + ":" + str(v) for k, v in event_id.items()]) def _merge_events(events, event_id, selection): @@ -226,7 +261,6 @@ def _merge_events(events, event_id, selection): event_idxs_to_delete = list() unique_events, counts = np.unique(events[:, 0], return_counts=True) for ev in unique_events[counts > 1]: - # indices at which the non-unique events happened idxs = (events[:, 0] == ev).nonzero()[0] @@ -242,18 +276,18 @@ def _merge_events(events, event_id, selection): # Else, make a new event_id for the merged event else: - # Find all event_id keys involved in duplicated events. These # keys will be merged to become a new entry in "event_id" event_id_keys = list(event_id.keys()) event_id_vals = list(event_id.values()) - new_key_comps = [event_id_keys[event_id_vals.index(value)] - for value in ev_vals] + new_key_comps = [ + event_id_keys[event_id_vals.index(value)] for value in ev_vals + ] # Check if we already have an entry for merged keys of duplicate # events ... if yes, reuse it for key in event_id: - if set(key.split('/')) == set(new_key_comps): + if set(key.split("/")) == set(new_key_comps): new_event_val = event_id[key] break @@ -261,9 +295,10 @@ def _merge_events(events, event_id, selection): # the event_id dict else: ev_vals = np.unique( - np.concatenate((list(event_id.values()), - events[:, 1:].flatten()), - axis=0)) + np.concatenate( + (list(event_id.values()), events[:, 1:].flatten()), axis=0 + ) + ) if ev_vals[0] > 1: new_event_val = 1 else: @@ -272,7 +307,7 @@ def _merge_events(events, event_id, selection): idx = -1 if len(idx) == 0 else idx[0] new_event_val = ev_vals[idx] + 1 - new_event_id_key = '/'.join(sorted(new_key_comps)) + new_event_id_key = "/".join(sorted(new_key_comps)) event_id[new_event_id_key] = int(new_event_val) # Replace duplicate event times with merged event and remember which @@ -288,8 +323,7 @@ def _merge_events(events, event_id, selection): return new_events, event_id, new_selection -def _handle_event_repeated(events, event_id, event_repeated, selection, - drop_log): +def _handle_event_repeated(events, event_id, event_repeated, selection, drop_log): """Handle repeated events. Note that drop_log will be modified inplace @@ -304,29 +338,34 @@ def _handle_event_repeated(events, event_id, event_repeated, selection, return events, event_id, selection, drop_log # Else, we have duplicates. Triage ... 
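# Standalone sketch (hypothetical helper, not MNE source) of how
# _merge_events above picks a value for a merged event: reuse 1 if it is
# free, otherwise the first gap in the sorted unique values already in use,
# otherwise max + 1. The merged key itself is "/".join(sorted(names)), so
# tag-based selection such as epochs["visual"] still matches an
# "auditory/visual" merger.
import numpy as np

def _first_free_value(used_vals):
    ev_vals = np.unique(used_vals)
    if ev_vals[0] > 1:
        return 1
    gaps = np.where(np.diff(ev_vals) > 1)[0]
    idx = -1 if len(gaps) == 0 else gaps[0]
    return int(ev_vals[idx] + 1)

assert _first_free_value([2, 3]) == 1     # 1 is unused
assert _first_free_value([1, 2, 5]) == 3  # first gap follows 2
assert _first_free_value([1, 2, 3]) == 4  # no gap: max + 1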
-    _check_option('event_repeated', event_repeated, ['error', 'drop', 'merge'])
+    _check_option("event_repeated", event_repeated, ["error", "drop", "merge"])
     drop_log = list(drop_log)
-    if event_repeated == 'error':
-        raise RuntimeError('Event time samples were not unique. Consider '
-                           'setting the `event_repeated` parameter."')
+    if event_repeated == "error":
+        raise RuntimeError(
+            "Event time samples were not unique. Consider "
+            "setting the `event_repeated` parameter."
+        )
 
-    elif event_repeated == 'drop':
-        logger.info('Multiple event values for single event times found. '
-                    'Keeping the first occurrence and dropping all others.')
+    elif event_repeated == "drop":
+        logger.info(
+            "Multiple event values for single event times found. "
+            "Keeping the first occurrence and dropping all others."
+        )
         new_events = events[u_ev_idxs]
         new_selection = selection[u_ev_idxs]
         drop_ev_idxs = np.setdiff1d(selection, new_selection)
         for idx in drop_ev_idxs:
-            drop_log[idx] = drop_log[idx] + ('DROP DUPLICATE',)
+            drop_log[idx] = drop_log[idx] + ("DROP DUPLICATE",)
         selection = new_selection
-    elif event_repeated == 'merge':
-        logger.info('Multiple event values for single event times found. '
-                    'Creating new event value to reflect simultaneous events.')
-        new_events, event_id, new_selection = \
-            _merge_events(events, event_id, selection)
+    elif event_repeated == "merge":
+        logger.info(
+            "Multiple event values for single event times found. "
+            "Creating new event value to reflect simultaneous events."
+        )
+        new_events, event_id, new_selection = _merge_events(events, event_id, selection)
         drop_ev_idxs = np.setdiff1d(selection, new_selection)
         for idx in drop_ev_idxs:
-            drop_log[idx] = drop_log[idx] + ('MERGE DUPLICATE',)
+            drop_log[idx] = drop_log[idx] + ("MERGE DUPLICATE",)
         selection = new_selection
 
     drop_log = tuple(drop_log)
@@ -338,10 +377,19 @@ def _handle_event_repeated(events, event_id, event_repeated, selection,
 
 
 @fill_doc
-class BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
-                 SetChannelsMixin, InterpolationMixin, FilterMixin,
-                 TimeMixin, SizeMixin, GetEpochsMixin, EpochAnnotationsMixin,
-                 SpectrumMixin):
+class BaseEpochs(
+    ProjMixin,
+    ContainsMixin,
+    UpdateChannelsMixin,
+    SetChannelsMixin,
+    InterpolationMixin,
+    FilterMixin,
+    TimeMixin,
+    SizeMixin,
+    GetEpochsMixin,
+    EpochAnnotationsMixin,
+    SpectrumMixin,
+):
     """Abstract base class for `~mne.Epochs`-type classes.
 
     .. 
note:: @@ -399,22 +447,44 @@ class BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin, """ @verbose - def __init__(self, info, data, events, event_id=None, - tmin=-0.2, tmax=0.5, - baseline=(None, 0), raw=None, picks=None, reject=None, - flat=None, decim=1, reject_tmin=None, reject_tmax=None, - detrend=None, proj=True, on_missing='raise', - preload_at_end=False, selection=None, drop_log=None, - filename=None, metadata=None, event_repeated='error', - *, raw_sfreq=None, - annotations=None, verbose=None): # noqa: D102 + def __init__( + self, + info, + data, + events, + event_id=None, + tmin=-0.2, + tmax=0.5, + baseline=(None, 0), + raw=None, + picks=None, + reject=None, + flat=None, + decim=1, + reject_tmin=None, + reject_tmax=None, + detrend=None, + proj=True, + on_missing="raise", + preload_at_end=False, + selection=None, + drop_log=None, + filename=None, + metadata=None, + event_repeated="error", + *, + raw_sfreq=None, + annotations=None, + verbose=None, + ): # noqa: D102 if events is not None: # RtEpochs can have events=None events = _ensure_events(events) events_max = events.max() if events_max > INT32_MAX: raise ValueError( - f'events array values must not exceed {INT32_MAX}, ' - f'got {events_max}') + f"events array values must not exceed {INT32_MAX}, " + f"got {events_max}" + ) event_id = _check_event_id(event_id, events) self.event_id = event_id del event_id @@ -422,8 +492,10 @@ def __init__(self, info, data, events, event_id=None, if events is not None: # RtEpochs can have events=None for key, val in self.event_id.items(): if val not in events[:, 2]: - msg = ('No matching events found for %s ' - '(event id %i)' % (key, val)) + msg = "No matching events found for %s " "(event id %i)" % ( + key, + val, + ) _on_missing(on_missing, msg) # ensure metadata matches original events size @@ -442,23 +514,33 @@ def __init__(self, info, data, events, event_id=None, else: selection = np.array(selection, int) if selection.shape != (len(selected),): - raise ValueError('selection must be shape %s got shape %s' - % (selected.shape, selection.shape)) + raise ValueError( + "selection must be shape %s got shape %s" + % (selected.shape, selection.shape) + ) self.selection = selection if drop_log is None: self.drop_log = tuple( - () if k in self.selection else ('IGNORED',) - for k in range(max(len(self.events), - max(self.selection) + 1))) + () if k in self.selection else ("IGNORED",) + for k in range(max(len(self.events), max(self.selection) + 1)) + ) else: self.drop_log = drop_log self.events = self.events[selected] - self.events, self.event_id, self.selection, self.drop_log = \ - _handle_event_repeated( - self.events, self.event_id, event_repeated, - self.selection, self.drop_log) + ( + self.events, + self.event_id, + self.selection, + self.drop_log, + ) = _handle_event_repeated( + self.events, + self.event_id, + event_repeated, + self.selection, + self.drop_log, + ) # then subselect sub = np.where(np.in1d(selection, self.selection))[0] @@ -477,13 +559,16 @@ def __init__(self, info, data, events, event_id=None, n_events = len(self.events) if n_events > 1: if np.diff(self.events.astype(np.int64)[:, 0]).min() <= 0: - warn('The events passed to the Epochs constructor are not ' - 'chronologically ordered.', RuntimeWarning) + warn( + "The events passed to the Epochs constructor are not " + "chronologically ordered.", + RuntimeWarning, + ) if n_events > 0: - logger.info('%d matching events found' % n_events) + logger.info("%d matching events found" % n_events) else: - raise ValueError('No desired 
events found.') + raise ValueError("No desired events found.") else: self.drop_log = tuple() self.selection = np.array([], int) @@ -491,13 +576,14 @@ def __init__(self, info, data, events, event_id=None, # do not set self.events here, let subclass do it if (detrend not in [None, 0, 1]) or isinstance(detrend, bool): - raise ValueError('detrend must be None, 0, or 1') + raise ValueError("detrend must be None, 0, or 1") self.detrend = detrend self._raw = raw info._check_consistency() - self.picks = _picks_to_idx(info, picks, none='all', exclude=(), - allow_empty=False) + self.picks = _picks_to_idx( + info, picks, none="all", exclude=(), allow_empty=False + ) self.info = pick_info(info, self.picks) del info self._current = 0 @@ -508,48 +594,54 @@ def __init__(self, info, data, events, event_id=None, self._do_baseline = True else: assert decim == 1 - if data.ndim != 3 or data.shape[2] != \ - round((tmax - tmin) * self.info['sfreq']) + 1: - raise RuntimeError('bad data shape') + if ( + data.ndim != 3 + or data.shape[2] != round((tmax - tmin) * self.info["sfreq"]) + 1 + ): + raise RuntimeError("bad data shape") if data.shape[0] != len(self.events): raise ValueError( - 'The number of epochs and the number of events must match') + "The number of epochs and the number of events must match" + ) self.preload = True self._data = data self._do_baseline = False self._offset = None if tmin > tmax: - raise ValueError('tmin has to be less than or equal to tmax') + raise ValueError("tmin has to be less than or equal to tmax") # Handle times - sfreq = float(self.info['sfreq']) + sfreq = float(self.info["sfreq"]) start_idx = int(round(tmin * sfreq)) - self._raw_times = np.arange(start_idx, - int(round(tmax * sfreq)) + 1) / sfreq + self._raw_times = np.arange(start_idx, int(round(tmax * sfreq)) + 1) / sfreq self._set_times(self._raw_times) # check reject_tmin and reject_tmax if reject_tmin is not None: - if (np.isclose(reject_tmin, tmin)): + if np.isclose(reject_tmin, tmin): # adjust for potential small deviations due to sampling freq reject_tmin = self.tmin elif reject_tmin < tmin: - raise ValueError(f'reject_tmin needs to be None or >= tmin ' - f'(got {reject_tmin})') + raise ValueError( + f"reject_tmin needs to be None or >= tmin " f"(got {reject_tmin})" + ) if reject_tmax is not None: - if (np.isclose(reject_tmax, tmax)): + if np.isclose(reject_tmax, tmax): # adjust for potential small deviations due to sampling freq reject_tmax = self.tmax elif reject_tmax > tmax: - raise ValueError(f'reject_tmax needs to be None or <= tmax ' - f'(got {reject_tmax})') + raise ValueError( + f"reject_tmax needs to be None or <= tmax " f"(got {reject_tmax})" + ) if (reject_tmin is not None) and (reject_tmax is not None): if reject_tmin >= reject_tmax: - raise ValueError(f'reject_tmin ({reject_tmin}) needs to be ' - f' < reject_tmax ({reject_tmax})') + raise ValueError( + f"reject_tmin ({reject_tmin}) needs to be " + f" < reject_tmax ({reject_tmax})" + ) self.reject_tmin = reject_tmin self.reject_tmax = reject_tmax @@ -559,11 +651,14 @@ def __init__(self, info, data, events, event_id=None, self.decimate(decim) # baseline correction: replace `None` tuple elements with actual times - self.baseline = _check_baseline(baseline, times=self.times, - sfreq=self.info['sfreq']) + self.baseline = _check_baseline( + baseline, times=self.times, sfreq=self.info["sfreq"] + ) if self.baseline is not None and self.baseline != baseline: - logger.info(f'Setting baseline interval to ' - f'[{self.baseline[0]}, {self.baseline[1]}] s') + 
logger.info( + f"Setting baseline interval to " + f"[{self.baseline[0]}, {self.baseline[1]}] s" + ) logger.info(_log_rescale(self.baseline)) @@ -573,18 +668,16 @@ def __init__(self, info, data, events, event_id=None, self._reject_setup(reject, flat) # do the rest - valid_proj = [True, 'delayed', False] + valid_proj = [True, "delayed", False] if proj not in valid_proj: - raise ValueError('"proj" must be one of %s, not %s' - % (valid_proj, proj)) - if proj == 'delayed': + raise ValueError('"proj" must be one of %s, not %s' % (valid_proj, proj)) + if proj == "delayed": self._do_delayed_proj = True - logger.info('Entering delayed SSP mode.') + logger.info("Entering delayed SSP mode.") else: self._do_delayed_proj = False activate = False if self._do_delayed_proj else proj - self._projector, self.info = setup_proj(self.info, False, - activate=activate) + self._projector, self.info = setup_proj(self.info, False, activate=activate) if preload_at_end: assert self._data is None assert self.preload is False @@ -598,20 +691,19 @@ def __init__(self, info, data, events, event_id=None, self._data[ii] = np.dot(self._projector, epoch) self._filename = str(filename) if filename is not None else filename if raw_sfreq is None: - raw_sfreq = self.info['sfreq'] + raw_sfreq = self.info["sfreq"] self._raw_sfreq = raw_sfreq self._check_consistency() self.set_annotations(annotations) def _check_consistency(self): """Check invariants of epochs object.""" - if hasattr(self, 'events'): + if hasattr(self, "events"): assert len(self.selection) == len(self.events) assert len(self.drop_log) >= len(self.events) - assert len(self.selection) == sum( - (len(dl) == 0 for dl in self.drop_log)) - assert hasattr(self, '_times_readonly') - assert not self.times.flags['WRITEABLE'] + assert len(self.selection) == sum((len(dl) == 0 for dl in self.drop_log)) + assert hasattr(self, "_times_readonly") + assert not self.times.flags["WRITEABLE"] assert isinstance(self.drop_log, tuple) assert all(isinstance(log, tuple) for log in self.drop_log) assert all(isinstance(s, str) for log in self.drop_log for s in log) @@ -678,14 +770,15 @@ def apply_baseline(self, baseline=(None, 0), *, verbose=None): .. versionadded:: 0.10.0 """ - baseline = _check_baseline(baseline, times=self.times, - sfreq=self.info['sfreq']) + baseline = _check_baseline(baseline, times=self.times, sfreq=self.info["sfreq"]) if self.preload: if self.baseline is not None and baseline is None: - raise RuntimeError('You cannot remove baseline correction ' - 'from preloaded data once it has been ' - 'applied.') + raise RuntimeError( + "You cannot remove baseline correction " + "from preloaded data once it has been " + "applied." 
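# Plain-NumPy sketch (illustrative shapes, not MNE's rescale) of the
# correction that the baseline interval resolved by _check_baseline above
# feeds into: subtract each channel's mean over the baseline samples; a
# (None, 0) tuple resolves to [tmin, 0].
import numpy as np

times = np.linspace(-0.2, 0.5, 71)
data = np.random.RandomState(0).randn(2, 3, times.size)  # epochs, channels, times
bmask = times <= 0  # baseline=(None, 0)
data -= data[..., bmask].mean(axis=-1, keepdims=True)
assert np.allclose(data[..., bmask].mean(-1), 0)  # baseline mean removed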
+ ) self._do_baseline = True picks = self._detrend_picks rescale(self._data, self.times, baseline, copy=False, picks=picks) @@ -704,39 +797,45 @@ def _reject_setup(self, reject, flat): idx = channel_indices_by_type(self.info) reject = deepcopy(reject) if reject is not None else dict() flat = deepcopy(flat) if flat is not None else dict() - for rej, kind in zip((reject, flat), ('reject', 'flat')): + for rej, kind in zip((reject, flat), ("reject", "flat")): if not isinstance(rej, dict): - raise TypeError('reject and flat must be dict or None, not %s' - % type(rej)) + raise TypeError( + "reject and flat must be dict or None, not %s" % type(rej) + ) bads = set(rej.keys()) - set(idx.keys()) if len(bads) > 0: - raise KeyError('Unknown channel types found in %s: %s' - % (kind, bads)) + raise KeyError("Unknown channel types found in %s: %s" % (kind, bads)) for key in idx.keys(): # don't throw an error if rejection/flat would do nothing - if len(idx[key]) == 0 and (np.isfinite(reject.get(key, np.inf)) or - flat.get(key, -1) >= 0): + if len(idx[key]) == 0 and ( + np.isfinite(reject.get(key, np.inf)) or flat.get(key, -1) >= 0 + ): # This is where we could eventually add e.g. # self.allow_missing_reject_keys check to allow users to # provide keys that don't exist in data - raise ValueError("No %s channel found. Cannot reject based on " - "%s." % (key.upper(), key.upper())) + raise ValueError( + "No %s channel found. Cannot reject based on " + "%s." % (key.upper(), key.upper()) + ) # check for invalid values - for rej, kind in zip((reject, flat), ('Rejection', 'Flat')): + for rej, kind in zip((reject, flat), ("Rejection", "Flat")): for key, val in rej.items(): if val is None or val < 0: - raise ValueError('%s value must be a number >= 0, not "%s"' - % (kind, val)) + raise ValueError( + '%s value must be a number >= 0, not "%s"' % (kind, val) + ) # now check to see if our rejection and flat are getting more # restrictive old_reject = self.reject if self.reject is not None else dict() old_flat = self.flat if self.flat is not None else dict() - bad_msg = ('{kind}["{key}"] == {new} {op} {old} (old value), new ' - '{kind} values must be at least as stringent as ' - 'previous ones') + bad_msg = ( + '{kind}["{key}"] == {new} {op} {old} (old value), new ' + "{kind} values must be at least as stringent as " + "previous ones" + ) # copy thresholds for channel types that were used previously, but not # passed this time @@ -746,8 +845,14 @@ def _reject_setup(self, reject, flat): for key in reject: if key in old_reject and reject[key] > old_reject[key]: raise ValueError( - bad_msg.format(kind='reject', key=key, new=reject[key], - old=old_reject[key], op='>')) + bad_msg.format( + kind="reject", + key=key, + new=reject[key], + old=old_reject[key], + op=">", + ) + ) # same for flat thresholds for key in set(old_flat) - set(flat): @@ -755,8 +860,10 @@ def _reject_setup(self, reject, flat): for key in flat: if key in old_flat and flat[key] < old_flat[key]: raise ValueError( - bad_msg.format(kind='flat', key=key, new=flat[key], - old=old_flat[key], op='<')) + bad_msg.format( + kind="flat", key=key, new=flat[key], old=old_flat[key], op="<" + ) + ) # after validation, set parameters self._bad_dropped = False @@ -785,20 +892,26 @@ def _is_good_epoch(self, data, verbose=None): if isinstance(data, str): return False, (data,) if data is None: - return False, ('NO_DATA',) + return False, ("NO_DATA",) n_times = len(self.times) if data.shape[1] < n_times: # epoch is too short ie at the end of the data - return False, 
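# Hypothetical thresholds for the reject/flat dictionaries validated by
# _reject_setup above: an epoch is dropped when any channel of a type exceeds
# the peak-to-peak `reject` threshold, or stays below the `flat` threshold.
# Values follow common MNE tutorial settings, in the channels' own units.
reject = dict(grad=4000e-13, mag=4e-12, eeg=100e-6)  # T/m, T, V
flat = dict(eeg=1e-7)  # V
# epochs.drop_bad(reject=reject, flat=flat)
# Re-running with looser values raises: per the bad_msg check above, new
# thresholds must be at least as stringent as previous ones.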
('TOO_SHORT',) + return False, ("TOO_SHORT",) if self.reject is None and self.flat is None: return True, None else: if self._reject_time is not None: data = data[:, self._reject_time] - return _is_good(data, self.ch_names, self._channel_type_idx, - self.reject, self.flat, full_report=True, - ignore_chs=self.info['bads']) + return _is_good( + data, + self.ch_names, + self._channel_type_idx, + self.reject, + self.flat, + full_report=True, + ignore_chs=self.info["bads"], + ) @verbose def _detrend_offset_decim(self, epoch, picks, verbose=None): @@ -819,8 +932,13 @@ def _detrend_offset_decim(self, epoch, picks, verbose=None): # Baseline correct if self._do_baseline: rescale( - epoch, self._raw_times, self.baseline, picks=picks, copy=False, - verbose=False) + epoch, + self._raw_times, + self.baseline, + picks=picks, + copy=False, + verbose=False, + ) # Decimate if necessary (i.e., epoch not preloaded) epoch = epoch[:, self._decim_slice] @@ -883,14 +1001,13 @@ def subtract_evoked(self, evoked=None): .. [1] David et al. "Mechanisms of evoked and induced responses in MEG/EEG", NeuroImage, vol. 31, no. 4, pp. 1580-1591, July 2006. """ - logger.info('Subtracting Evoked from Epochs') + logger.info("Subtracting Evoked from Epochs") if evoked is None: picks = _pick_data_channels(self.info, exclude=[]) evoked = self.average(picks) # find the indices of the channels to use - picks = pick_channels( - evoked.ch_names, include=self.ch_names, ordered=False) + picks = pick_channels(evoked.ch_names, include=self.ch_names, ordered=False) # make sure the omitted channels are not data channels if len(picks) < len(self.ch_names): @@ -898,24 +1015,32 @@ def subtract_evoked(self, evoked=None): diff_ch = list(set(self.ch_names).difference(sel_ch)) diff_idx = [self.ch_names.index(ch) for ch in diff_ch] diff_types = [channel_type(self.info, idx) for idx in diff_idx] - bad_idx = [diff_types.index(t) for t in diff_types if t in - _DATA_CH_TYPES_SPLIT] + bad_idx = [ + diff_types.index(t) for t in diff_types if t in _DATA_CH_TYPES_SPLIT + ] if len(bad_idx) > 0: - bad_str = ', '.join([diff_ch[ii] for ii in bad_idx]) - raise ValueError('The following data channels are missing ' - 'in the evoked response: %s' % bad_str) - logger.info(' The following channels are not included in the ' - 'subtraction: %s' % ', '.join(diff_ch)) + bad_str = ", ".join([diff_ch[ii] for ii in bad_idx]) + raise ValueError( + "The following data channels are missing " + "in the evoked response: %s" % bad_str + ) + logger.info( + " The following channels are not included in the " + "subtraction: %s" % ", ".join(diff_ch) + ) # make sure the times match - if (len(self.times) != len(evoked.times) or - np.max(np.abs(self.times - evoked.times)) >= 1e-7): - raise ValueError('Epochs and Evoked object do not contain ' - 'the same time points.') + if ( + len(self.times) != len(evoked.times) + or np.max(np.abs(self.times - evoked.times)) >= 1e-7 + ): + raise ValueError( + "Epochs and Evoked object do not contain " "the same time points." 
+        )
 
         # handle SSPs
         if not self.proj and evoked.proj:
-            warn('Evoked has SSP applied while Epochs has not.')
+            warn("Evoked has SSP applied while Epochs has not.")
         if self.proj and not evoked.proj:
             evoked = evoked.copy().apply_proj()
 
@@ -927,10 +1052,11 @@ def subtract_evoked(self, evoked=None):
             self._data[:, ep_picks, :] -= evoked.data[picks][None, :, :]
         else:
             if self._offset is None:
-                self._offset = np.zeros((len(self.ch_names), len(self.times)),
-                                        dtype=np.float64)
+                self._offset = np.zeros(
+                    (len(self.ch_names), len(self.times)), dtype=np.float64
+                )
             self._offset[ep_picks] -= evoked.data[picks]
-        logger.info('[done]')
+        logger.info("[done]")
 
         return self
 
@@ -978,8 +1104,7 @@ def average(self, picks=None, method="mean", by_event_type=False):
         if by_event_type:
             evokeds = list()
             for event_type in self.event_id.keys():
-                ev = self[event_type]._compute_aggregate(picks=picks,
-                                                         mode=method)
+                ev = self[event_type]._compute_aggregate(picks=picks, mode=method)
                 ev.comment = event_type
                 evokeds.append(ev)
         else:
@@ -999,39 +1124,43 @@ def standard_error(self, picks=None, by_event_type=False):
         -------
         %(std_err_by_event_type_returns)s
         """
-        return self.average(picks=picks, method="std",
-                            by_event_type=by_event_type)
+        return self.average(picks=picks, method="std", by_event_type=by_event_type)
 
-    def _compute_aggregate(self, picks, mode='mean'):
+    def _compute_aggregate(self, picks, mode="mean"):
         """Compute the mean, median, or std over epochs and return Evoked."""
         # if instance contains ICA channels they won't be included unless picks
         # is specified
         if picks is None:
-            check_ICA = [x.startswith('ICA') for x in self.ch_names]
+            check_ICA = [x.startswith("ICA") for x in self.ch_names]
             if np.all(check_ICA):
-                raise TypeError('picks must be specified (i.e. not None) for '
-                                'ICA channel data')
+                raise TypeError(
+                    "picks must be specified (i.e. not None) for ICA channel data"
+                )
             elif np.any(check_ICA):
-                warn('ICA channels will not be included unless explicitly '
-                     'selected in picks')
+                warn(
+                    "ICA channels will not be included unless explicitly "
+                    "selected in picks"
+                )
 
         n_channels = len(self.ch_names)
         n_times = len(self.times)
 
         if self.preload:
             n_events = len(self.events)
-            fun = _check_combine(mode, valid=('mean', 'median', 'std'))
+            fun = _check_combine(mode, valid=("mean", "median", "std"))
             data = fun(self._data)
             assert len(self.events) == len(self._data)
             if data.shape != self._data.shape[1:]:
                 raise RuntimeError(
-                    'You passed a function that resulted n data of shape {}, '
-                    'but it should be {}.'.format(
-                        data.shape, self._data.shape[1:]))
+                    "You passed a function that resulted in data of shape {}, "
+                    "but it should be {}.".format(data.shape, self._data.shape[1:])
+                )
         else:
             if mode not in {"mean", "std"}:
-                raise ValueError("If data are not preloaded, can only compute "
-                                 "mean or standard deviation.")
+                raise ValueError(
+                    "If data are not preloaded, can only compute "
+                    "mean or standard deviation."
+                )
             data = np.zeros((n_channels, n_times))
            n_events = 0
            for e in self:
@@ -1049,26 +1178,27 @@ def _compute_aggregate(self, picks, mode='mean'):
         # two (slower) in case there are large numbers
         if mode == "std":
             data_mean = data.copy()
-            data.fill(0.)
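# Standalone equivalence check (illustrative shapes, not MNE source) for the
# two-pass standard deviation used above when epochs are not preloaded: pass
# one accumulates the mean, pass two (continued just below in the diff)
# accumulates squared deviations from it.
import numpy as np

epochs_data = np.random.RandomState(0).randn(10, 3, 5)  # epochs, channels, times
n_events = len(epochs_data)
mean = sum(e for e in epochs_data) / n_events      # first pass
acc = sum((e - mean) ** 2 for e in epochs_data)    # second pass
std = np.sqrt(acc / n_events)
assert np.allclose(std, epochs_data.std(axis=0))
# standard_error() then divides by sqrt(n_events), as the next hunk shows.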
+ data.fill(0.0) for e in self: data += (e - data_mean) ** 2 data = np.sqrt(data / n_events) if mode == "std": - kind = 'standard_error' + kind = "standard_error" data /= np.sqrt(n_events) else: kind = "average" - return self._evoked_from_epoch_data(data, self.info, picks, n_events, - kind, self._name) + return self._evoked_from_epoch_data( + data, self.info, picks, n_events, kind, self._name + ) @property def _name(self): """Give a nice string representation based on event ids.""" return self._get_name() - def _get_name(self, count='frac', ms='×', sep='+'): + def _get_name(self, count="frac", ms="×", sep="+"): """Generate human-readable name for epochs and evokeds from event_id. Parameters @@ -1084,7 +1214,7 @@ def _get_name(self, count='frac', ms='×', sep='+'): How to separate the different events names. Ignored if only one event type is present. """ - _check_option('count', value=count, allowed_values=['frac', 'total']) + _check_option("count", value=count, allowed_values=["frac", "total"]) if len(self.event_id) == 1: comment = next(iter(self.event_id.keys())) @@ -1094,28 +1224,34 @@ def _get_name(self, count='frac', ms='×', sep='+'): # Take care of padding if ms is None: - ms = ' ' + ms = " " else: - ms = f' {ms} ' + ms = f" {ms} " for event_name, event_code in self.event_id.items(): - if count == 'frac': + if count == "frac": frac = float(counter[event_code]) / len(self.events) - comment = f'{frac:.2f}{ms}{event_name}' + comment = f"{frac:.2f}{ms}{event_name}" else: # 'total' - comment = f'{counter[event_code]}{ms}{event_name}' + comment = f"{counter[event_code]}{ms}{event_name}" comments.append(comment) - comment = f' {sep} '.join(comments) + comment = f" {sep} ".join(comments) return comment - def _evoked_from_epoch_data(self, data, info, picks, n_events, kind, - comment): + def _evoked_from_epoch_data(self, data, info, picks, n_events, kind, comment): """Create an evoked object from epoch data.""" info = deepcopy(info) # don't apply baseline correction; we'll set evoked.baseline manually - evoked = EvokedArray(data, info, tmin=self.times[0], comment=comment, - nave=n_events, kind=kind, baseline=None) + evoked = EvokedArray( + data, + info, + tmin=self.times[0], + comment=comment, + nave=n_events, + kind=kind, + baseline=None, + ) evoked.baseline = self.baseline # the above constructor doesn't recreate the times object precisely @@ -1123,58 +1259,116 @@ def _evoked_from_epoch_data(self, data, info, picks, n_events, kind, evoked._set_times(self.times.copy()) # pick channels - picks = _picks_to_idx(self.info, picks, 'data_or_ica', ()) + picks = _picks_to_idx(self.info, picks, "data_or_ica", ()) ch_names = [evoked.ch_names[p] for p in picks] evoked.pick_channels(ch_names) - if len(evoked.info['ch_names']) == 0: - raise ValueError('No data channel found when averaging.') + if len(evoked.info["ch_names"]) == 0: + raise ValueError("No data channel found when averaging.") if evoked.nave < 1: - warn('evoked object is empty (based on less than 1 epoch)') + warn("evoked object is empty (based on less than 1 epoch)") return evoked @property def ch_names(self): """Channel names.""" - return self.info['ch_names'] + return self.info["ch_names"] @copy_function_doc_to_method_doc(plot_epochs) - def plot(self, picks=None, scalings=None, n_epochs=20, n_channels=20, - title=None, events=None, event_color=None, - order=None, show=True, block=False, decim='auto', noise_cov=None, - butterfly=False, show_scrollbars=True, show_scalebars=True, - epoch_colors=None, event_id=None, group_by='type', - 
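# Sketch of the comment string _get_name above builds with its defaults
# (count='frac', ms='×', sep='+'): each event type contributes
# "<fraction> × <name>", joined by ' + '.
from collections import Counter

event_id = {"auditory": 1, "visual": 2}
event_codes = [1, 1, 2, 2]  # events[:, 2]
counter = Counter(event_codes)
parts = [
    f"{counter[code] / len(event_codes):.2f} × {name}"
    for name, code in event_id.items()
]
assert " + ".join(parts) == "0.50 × auditory + 0.50 × visual"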
precompute=None, use_opengl=None, *, theme=None, - overview_mode=None): - return plot_epochs(self, picks=picks, scalings=scalings, - n_epochs=n_epochs, n_channels=n_channels, - title=title, events=events, event_color=event_color, - order=order, show=show, block=block, decim=decim, - noise_cov=noise_cov, butterfly=butterfly, - show_scrollbars=show_scrollbars, - show_scalebars=show_scalebars, - epoch_colors=epoch_colors, event_id=event_id, - group_by=group_by, precompute=precompute, - use_opengl=use_opengl, theme=theme, - overview_mode=overview_mode) + def plot( + self, + picks=None, + scalings=None, + n_epochs=20, + n_channels=20, + title=None, + events=None, + event_color=None, + order=None, + show=True, + block=False, + decim="auto", + noise_cov=None, + butterfly=False, + show_scrollbars=True, + show_scalebars=True, + epoch_colors=None, + event_id=None, + group_by="type", + precompute=None, + use_opengl=None, + *, + theme=None, + overview_mode=None, + ): + return plot_epochs( + self, + picks=picks, + scalings=scalings, + n_epochs=n_epochs, + n_channels=n_channels, + title=title, + events=events, + event_color=event_color, + order=order, + show=show, + block=block, + decim=decim, + noise_cov=noise_cov, + butterfly=butterfly, + show_scrollbars=show_scrollbars, + show_scalebars=show_scalebars, + epoch_colors=epoch_colors, + event_id=event_id, + group_by=group_by, + precompute=precompute, + use_opengl=use_opengl, + theme=theme, + overview_mode=overview_mode, + ) @copy_function_doc_to_method_doc(plot_topo_image_epochs) - def plot_topo_image(self, layout=None, sigma=0., vmin=None, vmax=None, - colorbar=None, order=None, cmap='RdBu_r', - layout_scale=.95, title=None, scalings=None, - border='none', fig_facecolor='k', fig_background=None, - font_color='w', show=True): + def plot_topo_image( + self, + layout=None, + sigma=0.0, + vmin=None, + vmax=None, + colorbar=None, + order=None, + cmap="RdBu_r", + layout_scale=0.95, + title=None, + scalings=None, + border="none", + fig_facecolor="k", + fig_background=None, + font_color="w", + show=True, + ): return plot_topo_image_epochs( - self, layout=layout, sigma=sigma, vmin=vmin, vmax=vmax, - colorbar=colorbar, order=order, cmap=cmap, - layout_scale=layout_scale, title=title, scalings=scalings, - border=border, fig_facecolor=fig_facecolor, - fig_background=fig_background, font_color=font_color, show=show) + self, + layout=layout, + sigma=sigma, + vmin=vmin, + vmax=vmax, + colorbar=colorbar, + order=order, + cmap=cmap, + layout_scale=layout_scale, + title=title, + scalings=scalings, + border=border, + fig_facecolor=fig_facecolor, + fig_background=fig_background, + font_color=font_color, + show=show, + ) @verbose - def drop_bad(self, reject='existing', flat='existing', verbose=None): + def drop_bad(self, reject="existing", flat="existing", verbose=None): """Drop bad epochs without retaining the epochs data. Should be used before slicing operations. @@ -1206,20 +1400,19 @@ def drop_bad(self, reject='existing', flat='existing', verbose=None): subsequently be applied, `epochs.copy ` should be used. 
""" - if reject == 'existing': - if flat == 'existing' and self._bad_dropped: + if reject == "existing": + if flat == "existing" and self._bad_dropped: return reject = self.reject - if flat == 'existing': + if flat == "existing": flat = self.flat - if any(isinstance(rej, str) and rej != 'existing' for - rej in (reject, flat)): + if any(isinstance(rej, str) and rej != "existing" for rej in (reject, flat)): raise ValueError('reject and flat, if strings, must be "existing"') self._reject_setup(reject, flat) self._get_data(out=False, verbose=verbose) return self - def drop_log_stats(self, ignore=('IGNORED',)): + def drop_log_stats(self, ignore=("IGNORED",)): """Compute the channel stats based on a drop_log from Epochs. Parameters @@ -1239,33 +1432,81 @@ def drop_log_stats(self, ignore=('IGNORED',)): return _drop_log_stats(self.drop_log, ignore) @copy_function_doc_to_method_doc(plot_drop_log) - def plot_drop_log(self, threshold=0, n_max_plot=20, subject=None, - color=(0.9, 0.9, 0.9), width=0.8, ignore=('IGNORED',), - show=True): + def plot_drop_log( + self, + threshold=0, + n_max_plot=20, + subject=None, + color=(0.9, 0.9, 0.9), + width=0.8, + ignore=("IGNORED",), + show=True, + ): if not self._bad_dropped: - raise ValueError("You cannot use plot_drop_log since bad " - "epochs have not yet been dropped. " - "Use epochs.drop_bad().") - return plot_drop_log(self.drop_log, threshold, n_max_plot, subject, - color=color, width=width, ignore=ignore, - show=show) + raise ValueError( + "You cannot use plot_drop_log since bad " + "epochs have not yet been dropped. " + "Use epochs.drop_bad()." + ) + return plot_drop_log( + self.drop_log, + threshold, + n_max_plot, + subject, + color=color, + width=width, + ignore=ignore, + show=show, + ) @copy_function_doc_to_method_doc(plot_epochs_image) - def plot_image(self, picks=None, sigma=0., vmin=None, vmax=None, - colorbar=True, order=None, show=True, units=None, - scalings=None, cmap=None, fig=None, axes=None, - overlay_times=None, combine=None, group_by=None, - evoked=True, ts_args=None, title=None, clear=False): - return plot_epochs_image(self, picks=picks, sigma=sigma, vmin=vmin, - vmax=vmax, colorbar=colorbar, order=order, - show=show, units=units, scalings=scalings, - cmap=cmap, fig=fig, axes=axes, - overlay_times=overlay_times, combine=combine, - group_by=group_by, evoked=evoked, - ts_args=ts_args, title=title, clear=clear) + def plot_image( + self, + picks=None, + sigma=0.0, + vmin=None, + vmax=None, + colorbar=True, + order=None, + show=True, + units=None, + scalings=None, + cmap=None, + fig=None, + axes=None, + overlay_times=None, + combine=None, + group_by=None, + evoked=True, + ts_args=None, + title=None, + clear=False, + ): + return plot_epochs_image( + self, + picks=picks, + sigma=sigma, + vmin=vmin, + vmax=vmax, + colorbar=colorbar, + order=order, + show=show, + units=units, + scalings=scalings, + cmap=cmap, + fig=fig, + axes=axes, + overlay_times=overlay_times, + combine=combine, + group_by=group_by, + evoked=evoked, + ts_args=ts_args, + title=title, + clear=clear, + ) @verbose - def drop(self, indices, reason='USER', verbose=None): + def drop(self, indices, reason="USER", verbose=None): """Drop epochs based on indices or boolean mask. .. 
note:: The indices refer to the current set of undropped epochs @@ -1309,8 +1550,10 @@ def drop(self, indices, reason='USER', verbose=None): keep = np.setdiff1d(np.arange(len(self.events)), try_idx) self._getitem(keep, reason, copy=False, drop_event_id=False) count = len(try_idx) - logger.info('Dropped %d epoch%s: %s' % - (count, _pl(count), ', '.join(map(str, np.sort(try_idx))))) + logger.info( + "Dropped %d epoch%s: %s" + % (count, _pl(count), ", ".join(map(str, np.sort(try_idx)))) + ) return self @@ -1330,8 +1573,17 @@ def _project_epoch(self, epoch): return epoch @verbose - def _get_data(self, out=True, picks=None, item=None, *, units=None, - tmin=None, tmax=None, verbose=None): + def _get_data( + self, + out=True, + picks=None, + item=None, + *, + units=None, + tmin=None, + tmax=None, + verbose=None, + ): """Load all data, dropping bad epochs along the way. Parameters @@ -1354,13 +1606,19 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, if not out: # make sure first and last epoch not out of bounds of raw in_bounds = self.preload or ( - self._get_epoch_from_raw(idx=0) is not None and - self._get_epoch_from_raw(idx=-1) is not None) + self._get_epoch_from_raw(idx=0) is not None + and self._get_epoch_from_raw(idx=-1) is not None + ) # might be BaseEpochs or Epochs, only the latter has the attribute - reject_by_annotation = getattr(self, 'reject_by_annotation', False) - if (self.reject is None and self.flat is None and in_bounds and - self._reject_time is None and not reject_by_annotation): - logger.debug('_get_data is a noop, returning') + reject_by_annotation = getattr(self, "reject_by_annotation", False) + if ( + self.reject is None + and self.flat is None + and in_bounds + and self._reject_time is None + and not reject_by_annotation + ): + logger.debug("_get_data is a noop, returning") self._bad_dropped = True return None start, stop = self._handle_tmin_tmax(tmin, tmax) @@ -1369,8 +1627,9 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, item = slice(None) elif not self._bad_dropped: raise ValueError( - 'item must be None in epochs.get_data() unless bads have been ' - 'dropped. Consider using epochs.drop_bad().') + "item must be None in epochs.get_data() unless bads have been " + "dropped. Consider using epochs.drop_bad()." 
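# Sketch (the body of _drop_log_stats is not shown in this patch, so the
# helper below is an assumption) of the statistic drop_log_stats reports:
# among epochs whose recorded reasons are not in `ignore`, the percentage
# that were dropped. An empty tuple means the epoch was kept; drop() above
# appends reasons such as ('USER',).
import numpy as np

def drop_percentage(drop_log, ignore=("IGNORED",)):
    considered = [d for d in drop_log if not any(r in ignore for r in d)]
    return 100 * np.mean([len(d) > 0 for d in considered])

drop_log = ((), ("USER",), ("IGNORED",), ("EEG 001",), ())
assert drop_percentage(drop_log) == 50.0  # 2 of 4 considered epochs dropped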
+ ) select = self._item_to_select(item) # indices or slice use_idx = np.arange(len(self.events))[select] n_events = len(use_idx) @@ -1380,15 +1639,17 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, data = self._data else: # we start out with an empty array, allocate only if necessary - data = np.empty((0, len(self.info['ch_names']), len(self.times))) - msg = (f'for {n_events} events and {len(self._raw_times)} ' - 'original time points') + data = np.empty((0, len(self.info["ch_names"]), len(self.times))) + msg = ( + f"for {n_events} events and {len(self._raw_times)} " + "original time points" + ) if self._decim > 1: - msg += ' (prior to decimation)' + msg += " (prior to decimation)" if getattr(self._raw, "preload", False): - logger.info(f'Using data from preloaded Raw {msg} ...') + logger.info(f"Using data from preloaded Raw {msg} ...") else: - logger.info(f'Loading data {msg} ...') + logger.info(f"Loading data {msg} ...") orig_picks = picks if orig_picks is None: @@ -1418,15 +1679,16 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, for ii, idx in enumerate(use_idx): # faster to pre-allocate memory here epoch_noproj = self._get_epoch_from_raw(idx) - epoch_noproj = self._detrend_offset_decim( - epoch_noproj, detrend_picks) + epoch_noproj = self._detrend_offset_decim(epoch_noproj, detrend_picks) if self._do_delayed_proj: epoch_out = epoch_noproj else: epoch_out = self._project_epoch(epoch_noproj) if ii == 0: - data = np.empty((n_events, len(self.ch_names), - len(self.times)), dtype=epoch_out.dtype) + data = np.empty( + (n_events, len(self.ch_names), len(self.times)), + dtype=epoch_out.dtype, + ) data[ii] = epoch_out else: # bads need to be dropped, this might occur after a preload @@ -1448,12 +1710,12 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, else: # from disk epoch_noproj = self._get_epoch_from_raw(idx) epoch_noproj = self._detrend_offset_decim( - epoch_noproj, detrend_picks) + epoch_noproj, detrend_picks + ) epoch = self._project_epoch(epoch_noproj) epoch_out = epoch_noproj if self._do_delayed_proj else epoch - is_good, bad_tuple = self._is_good_epoch( - epoch, verbose=verbose) + is_good, bad_tuple = self._is_good_epoch(epoch, verbose=verbose) if not is_good: assert isinstance(bad_tuple, tuple) assert all(isinstance(x, str) for x in bad_tuple) @@ -1465,9 +1727,11 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, if out or self.preload: # faster to pre-allocate, then trim as necessary if n_out == 0 and not self.preload: - data = np.empty((n_events, epoch_out.shape[0], - epoch_out.shape[1]), - dtype=epoch_out.dtype, order='C') + data = np.empty( + (n_events, epoch_out.shape[0], epoch_out.shape[1]), + dtype=epoch_out.dtype, + order="C", + ) data[n_out] = epoch_out n_out += 1 self.drop_log = tuple(drop_log) @@ -1478,7 +1742,7 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, # adjust the data size if there is a reason to (output or update) if out or self.preload: - if data.flags['OWNDATA'] and data.flags['C_CONTIGUOUS']: + if data.flags["OWNDATA"] and data.flags["C_CONTIGUOUS"]: data.resize((n_out,) + data.shape[1:], refcheck=False) else: data = data[:n_out] @@ -1486,8 +1750,9 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, self._data = data # Now update our properties (excepd data, which is already fixed) - self._getitem(good_idx, None, copy=False, drop_event_id=False, - select_data=False) + self._getitem( + good_idx, None, copy=False, drop_event_id=False, 
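# Standalone note on the trim step above: ndarray.resize shrinks the
# pre-allocated block in place (no copy) but is only legal when the array
# owns its memory and is C-contiguous; otherwise a slice view is taken
# instead, exactly as the branch above does.
import numpy as np

data = np.empty((10, 4, 6))
n_out = 7
if data.flags["OWNDATA"] and data.flags["C_CONTIGUOUS"]:
    data.resize((n_out,) + data.shape[1:], refcheck=False)
else:
    data = data[:n_out]
assert data.shape == (7, 4, 6)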
select_data=False + ) if out: if orig_picks is not None: @@ -1504,13 +1769,13 @@ def _get_data(self, out=True, picks=None, item=None, *, units=None, def _detrend_picks(self): if self._do_baseline: return _pick_data_channels( - self.info, with_ref_meg=True, with_aux=True, exclude=()) + self.info, with_ref_meg=True, with_aux=True, exclude=() + ) else: return [] @fill_doc - def get_data(self, picks=None, item=None, units=None, tmin=None, - tmax=None): + def get_data(self, picks=None, item=None, units=None, tmin=None, tmax=None): """Get all epochs as a 3D array. Parameters @@ -1541,12 +1806,19 @@ def get_data(self, picks=None, item=None, units=None, tmin=None, data : array of shape (n_epochs, n_channels, n_times) A view on epochs data. """ - return self._get_data(picks=picks, item=item, units=units, tmin=tmin, - tmax=tmax) + return self._get_data(picks=picks, item=item, units=units, tmin=tmin, tmax=tmax) @verbose - def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, - channel_wise=True, verbose=None, **kwargs): + def apply_function( + self, + fun, + picks=None, + dtype=None, + n_jobs=None, + channel_wise=True, + verbose=None, + **kwargs, + ): """Apply a function to a subset of channels. %(applyfun_summary_epochs)s @@ -1567,11 +1839,11 @@ def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, self : instance of Epochs The epochs object with transformed data. """ - _check_preload(self, 'epochs.apply_function') + _check_preload(self, "epochs.apply_function") picks = _picks_to_idx(self.info, picks, exclude=(), with_ref_meg=False) if not callable(fun): - raise ValueError('fun needs to be a function') + raise ValueError("fun needs to be a function") data_in = self._data if dtype is not None and dtype != self._data.dtype: @@ -1584,11 +1856,13 @@ def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, # modify data inplace to save memory for idx in picks: self._data[:, idx, :] = np.apply_along_axis( - _fun, -1, data_in[:, idx, :]) + _fun, -1, data_in[:, idx, :] + ) else: # use parallel function - data_picks_new = parallel(p_fun( - fun, data_in[:, p, :], **kwargs) for p in picks) + data_picks_new = parallel( + p_fun(fun, data_in[:, p, :], **kwargs) for p in picks + ) for pp, p in enumerate(picks): self._data[:, p, :] = data_picks_new[pp] else: @@ -1603,60 +1877,66 @@ def filename(self): def __repr__(self): """Build string representation.""" - s = ' %s events ' % len(self.events) - s += '(all good)' if self._bad_dropped else '(good & bad)' - s += ', %g – %g s' % (self.tmin, self.tmax) - s += ', baseline ' + s = " %s events " % len(self.events) + s += "(all good)" if self._bad_dropped else "(good & bad)" + s += ", %g – %g s" % (self.tmin, self.tmax) + s += ", baseline " if self.baseline is None: - s += 'off' + s += "off" else: - s += f'{self.baseline[0]:g} – {self.baseline[1]:g} s' + s += f"{self.baseline[0]:g} – {self.baseline[1]:g} s" if self.baseline != _check_baseline( - self.baseline, times=self.times, sfreq=self.info['sfreq'], - on_baseline_outside_data='adjust'): - s += ' (baseline period was cropped after baseline correction)' - - s += ', ~%s' % (sizeof_fmt(self._size),) - s += ', data%s loaded' % ('' if self.preload else ' not') - s += ', with metadata' if self.metadata is not None else '' + self.baseline, + times=self.times, + sfreq=self.info["sfreq"], + on_baseline_outside_data="adjust", + ): + s += " (baseline period was cropped after baseline correction)" + + s += ", ~%s" % (sizeof_fmt(self._size),) + s += ", data%s loaded" % ("" if self.preload else " 
not") + s += ", with metadata" if self.metadata is not None else "" max_events = 10 - counts = ['%r: %i' % (k, sum(self.events[:, 2] == v)) - for k, v in list(self.event_id.items())[:max_events]] + counts = [ + "%r: %i" % (k, sum(self.events[:, 2] == v)) + for k, v in list(self.event_id.items())[:max_events] + ] if len(self.event_id) > 0: - s += ',' + '\n '.join([''] + counts) + s += "," + "\n ".join([""] + counts) if len(self.event_id) > max_events: not_shown_events = len(self.event_id) - max_events s += f"\n and {not_shown_events} more events ..." class_name = self.__class__.__name__ - class_name = 'Epochs' if class_name == 'BaseEpochs' else class_name - return '<%s | %s>' % (class_name, s) + class_name = "Epochs" if class_name == "BaseEpochs" else class_name + return "<%s | %s>" % (class_name, s) @repr_html def _repr_html_(self): from .html_templates import repr_templates_env + if self.baseline is None: - baseline = 'off' + baseline = "off" else: - baseline = tuple([f'{b:.3f}' for b in self.baseline]) - baseline = f'{baseline[0]} – {baseline[1]} s' + baseline = tuple([f"{b:.3f}" for b in self.baseline]) + baseline = f"{baseline[0]} – {baseline[1]} s" if isinstance(self.event_id, dict): event_strings = [] for k, v in sorted(self.event_id.items()): n_events = sum(self.events[:, 2] == v) - event_strings.append(f'{k}: {n_events}') + event_strings.append(f"{k}: {n_events}") elif isinstance(self.event_id, list): event_strings = [] for k in self.event_id: n_events = sum(self.events[:, 2] == k) - event_strings.append(f'{k}: {n_events}') + event_strings.append(f"{k}: {n_events}") elif isinstance(self.event_id, int): n_events = len(self.events[:, 2]) - event_strings = [f'{self.event_id}: {n_events}'] + event_strings = [f"{self.event_id}: {n_events}"] else: event_strings = None - t = repr_templates_env.get_template('epochs.html.jinja') + t = repr_templates_env.get_template("epochs.html.jinja") t = t.render(epochs=self, baseline=baseline, events=event_strings) return t @@ -1683,20 +1963,22 @@ def crop(self, tmin=None, tmax=None, include_tmax=True, verbose=None): %(notes_tmax_included_by_default)s """ # XXX this could be made to work on non-preloaded data... - _check_preload(self, 'Modifying data of epochs') + _check_preload(self, "Modifying data of epochs") super().crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax) # Adjust rejection period if self.reject_tmin is not None and self.reject_tmin < self.tmin: logger.info( - f'reject_tmin is not in epochs time interval. ' - f'Setting reject_tmin to epochs.tmin ({self.tmin} s)') + f"reject_tmin is not in epochs time interval. " + f"Setting reject_tmin to epochs.tmin ({self.tmin} s)" + ) self.reject_tmin = self.tmin if self.reject_tmax is not None and self.reject_tmax > self.tmax: logger.info( - f'reject_tmax is not in epochs time interval. ' - f'Setting reject_tmax to epochs.tmax ({self.tmax} s)') + f"reject_tmax is not in epochs time interval. 
" + f"Setting reject_tmax to epochs.tmax ({self.tmax} s)" + ) self.reject_tmax = self.tmax return self @@ -1717,7 +1999,7 @@ def __deepcopy__(self, memodict): for k, v in self.__dict__.items(): # drop_log is immutable and _raw is private (and problematic to # deepcopy) - if k in ('drop_log', '_raw', '_times_readonly'): + if k in ("drop_log", "_raw", "_times_readonly"): memodict[id(v)] = v else: v = deepcopy(v, memodict) @@ -1725,8 +2007,15 @@ def __deepcopy__(self, memodict): return result @verbose - def save(self, fname, split_size='2GB', fmt='single', overwrite=False, - split_naming='neuromag', verbose=None): + def save( + self, + fname, + split_size="2GB", + fmt="single", + overwrite=False, + split_naming="neuromag", + verbose=None, + ): """Save epochs in a fif file. Parameters @@ -1765,15 +2054,16 @@ def save(self, fname, split_size='2GB', fmt='single', overwrite=False, ----- Bad epochs will be dropped before saving the epochs to disk. """ - check_fname(fname, 'epochs', ('-epo.fif', '-epo.fif.gz', - '_epo.fif', '_epo.fif.gz')) + check_fname( + fname, "epochs", ("-epo.fif", "-epo.fif.gz", "_epo.fif", "_epo.fif.gz") + ) # check for file existence and expand `~` if present fname = str(_check_fname(fname=fname, overwrite=overwrite)) split_size_bytes = _get_split_size(split_size) - _check_option('fmt', fmt, ['single', 'double']) + _check_option("fmt", fmt, ["single", "double"]) # to know the length accurately. The get_data() call would drop # bad epochs anyway @@ -1781,12 +2071,12 @@ def save(self, fname, split_size='2GB', fmt='single', overwrite=False, # total_size tracks sizes that get split # over_size tracks overhead (tags, things that get written to each) if len(self) == 0: - warn('Saving epochs with no data') + warn("Saving epochs with no data") total_size = 0 else: d = self[0].get_data() # this should be guaranteed by subclasses - assert d.dtype in ('>f8', 'c16', 'f8", "c16", "= 1, n_parts if n_parts > 1: - logger.info(f'Splitting into {n_parts} parts') + logger.info(f"Splitting into {n_parts} parts") if n_parts > 100: # This must be an error raise ValueError( - f'Split size {split_size} would result in writing ' - f'{n_parts} files') + f"Split size {split_size} would result in writing " + f"{n_parts} files" + ) if len(self.drop_log) > 100000: - warn(f'epochs.drop_log contains {len(self.drop_log)} entries ' - f'which will incur up to a {sizeof_fmt(drop_size)} writing ' - f'overhead (per split file), consider using ' - f'epochs.reset_drop_log_selection() prior to writing') + warn( + f"epochs.drop_log contains {len(self.drop_log)} entries " + f"which will incur up to a {sizeof_fmt(drop_size)} writing " + f"overhead (per split file), consider using " + f"epochs.reset_drop_log_selection() prior to writing" + ) epoch_idxs = np.array_split(np.arange(n_epochs), n_parts) @@ -1857,11 +2150,12 @@ def save(self, fname, split_size='2GB', fmt='single', overwrite=False, this_epochs = self[epoch_idx] if n_parts > 1 else self # avoid missing event_ids in splits this_epochs.event_id = self.event_id - _save_split(this_epochs, fname, part_idx, n_parts, fmt, - split_naming, overwrite) + _save_split( + this_epochs, fname, part_idx, n_parts, fmt, split_naming, overwrite + ) @verbose - def export(self, fname, fmt='auto', *, overwrite=False, verbose=None): + def export(self, fname, fmt="auto", *, overwrite=False, verbose=None): """Export Epochs to external formats. 
%(export_fmt_support_epochs)s @@ -1885,9 +2179,10 @@ def export(self, fname, fmt='auto', *, overwrite=False, verbose=None): %(export_eeglab_note)s """ from .export import export_epochs + export_epochs(fname, self, fmt, overwrite=overwrite, verbose=verbose) - def equalize_event_counts(self, event_ids=None, method='mintime'): + def equalize_event_counts(self, event_ids=None, method="mintime"): """Equalize the number of trials in each condition. It tries to make the remaining epochs occur as close as possible in @@ -1960,16 +2255,23 @@ def equalize_event_counts(self, event_ids=None, method='mintime'): event names were specified explicitly. """ from collections.abc import Iterable - _validate_type(event_ids, types=(Iterable, None), - item_name='event_ids', type_name='list-like or None') + + _validate_type( + event_ids, + types=(Iterable, None), + item_name="event_ids", + type_name="list-like or None", + ) if isinstance(event_ids, str): - raise TypeError(f'event_ids must be list-like or None, but ' - f'received a string: {event_ids}') + raise TypeError( + f"event_ids must be list-like or None, but " + f"received a string: {event_ids}" + ) if event_ids is None: event_ids = list(self.event_id) elif not event_ids: - raise ValueError('event_ids must have at least one element') + raise ValueError("event_ids must have at least one element") if not self._bad_dropped: self.drop_bad() @@ -1982,8 +2284,7 @@ def equalize_event_counts(self, event_ids=None, method='mintime'): tagging = False if "/" in "".join(ids): # make string inputs a list of length 1 - event_ids = [[x] if isinstance(x, str) else x - for x in event_ids] + event_ids = [[x] if isinstance(x, str) else x for x in event_ids] for ids_ in event_ids: # check if tagging is attempted if any([id_ not in ids for id_ in ids_]): tagging = True # 1. 2a. for tags, find all the event_ids matched by the tags # 2b. for non-tag ids, just pass them directly # 3. do this for every input - event_ids = [[k for k in ids - if all((tag in k.split("/") - for tag in id_))] # ids matching all tags - if all(id__ not in ids for id__ in id_) - else id_ # straight pass for non-tag inputs - for id_ in event_ids] + event_ids = [ + [ + k for k in ids if all((tag in k.split("/") for tag in id_)) + ] # ids matching all tags + if all(id__ not in ids for id__ in id_) + else id_ # straight pass for non-tag inputs + for id_ in event_ids + ] for ii, id_ in enumerate(event_ids): if len(id_) == 0: - raise KeyError(f"{orig_ids[ii]} not found in the epoch " - "object's event_id.") + raise KeyError( + f"{orig_ids[ii]} not found in the epoch " "object's event_id." + ) elif len({sub_id in ids for sub_id in id_}) != 1: - err = ("Don't mix hierarchical and regular event_ids" - " like in \'%s\'." % ", ".join(id_)) + err = ( + "Don't mix hierarchical and regular event_ids" + " like in '%s'." % ", ".join(id_) + ) raise ValueError(err) # raise for non-orthogonal tags @@ -2011,9 +2317,11 @@ def equalize_event_counts(self, event_ids=None, method='mintime'): events_ = [set(self[x].events[:, 0]) for x in event_ids] doubles = events_[0].intersection(events_[1]) if len(doubles): - raise ValueError("The two sets of epochs are " - "overlapping. Provide an " - "orthogonal selection.") + raise ValueError( + "The two sets of epochs are " + "overlapping. Provide an " + "orthogonal selection." 
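# Hedged illustration of the tag matching implemented above (not part of the
# diff; condition names are hypothetical). With hierarchical event_ids such as
# "auditory/left" and "visual/left", a bare tag selects every condition that
# contains it:
#     epochs.equalize_event_counts(["auditory/left", "visual/left"])
#     epochs.equalize_event_counts(["left", "right"])  # tag-based selection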
+ ) for eq in event_ids: eq_inds.append(self._keys_to_idx(eq)) @@ -2022,14 +2330,25 @@ def equalize_event_counts(self, event_ids=None, method='mintime'): indices = _get_drop_indices(event_times, method) # need to re-index indices indices = np.concatenate([e[idx] for e, idx in zip(eq_inds, indices)]) - self.drop(indices, reason='EQUALIZED_COUNT') + self.drop(indices, reason="EQUALIZED_COUNT") # actually remove the indices return self, indices @verbose - def compute_psd(self, method='multitaper', fmin=0, fmax=np.inf, tmin=None, - tmax=None, picks=None, proj=False, *, n_jobs=1, - verbose=None, **method_kw): + def compute_psd( + self, + method="multitaper", + fmin=0, + fmax=np.inf, + tmin=None, + tmax=None, + picks=None, + proj=False, + *, + n_jobs=1, + verbose=None, + **method_kw, + ): """Perform spectral analysis on sensor data. Parameters @@ -2061,17 +2380,47 @@ def compute_psd(self, method='multitaper', fmin=0, fmax=np.inf, tmin=None, self._set_legacy_nfft_default(tmin, tmax, method, method_kw) return EpochsSpectrum( - self, method=method, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, - picks=picks, proj=proj, n_jobs=n_jobs, verbose=verbose, - **method_kw) + self, + method=method, + fmin=fmin, + fmax=fmax, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) @verbose - def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None, - proj=False, *, method='auto', average=False, dB=True, - estimate='auto', xscale='linear', area_mode='std', - area_alpha=0.33, color='black', line_alpha=None, - spatial_colors=True, sphere=None, exclude='bads', ax=None, - show=True, n_jobs=1, verbose=None, **method_kw): + def plot_psd( + self, + fmin=0, + fmax=np.inf, + tmin=None, + tmax=None, + picks=None, + proj=False, + *, + method="auto", + average=False, + dB=True, + estimate="auto", + xscale="linear", + area_mode="std", + area_alpha=0.33, + color="black", + line_alpha=None, + spatial_colors=True, + sphere=None, + exclude="bads", + ax=None, + show=True, + n_jobs=1, + verbose=None, + **method_kw, + ): """%(plot_psd_doc)s. Parameters @@ -2115,17 +2464,44 @@ def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None, %(notes_plot_psd_meth)s """ return super().plot_psd( - fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, picks=picks, proj=proj, - reject_by_annotation=False, method=method, average=average, dB=dB, - estimate=estimate, xscale=xscale, area_mode=area_mode, - area_alpha=area_alpha, color=color, line_alpha=line_alpha, - spatial_colors=spatial_colors, sphere=sphere, exclude=exclude, - ax=ax, show=show, n_jobs=n_jobs, verbose=verbose, **method_kw) + fmin=fmin, + fmax=fmax, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + reject_by_annotation=False, + method=method, + average=average, + dB=dB, + estimate=estimate, + xscale=xscale, + area_mode=area_mode, + area_alpha=area_alpha, + color=color, + line_alpha=line_alpha, + spatial_colors=spatial_colors, + sphere=sphere, + exclude=exclude, + ax=ax, + show=show, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) @verbose - def to_data_frame(self, picks=None, index=None, - scalings=None, copy=True, long_format=False, - time_format=None, *, verbose=None): + def to_data_frame( + self, + picks=None, + index=None, + scalings=None, + copy=True, + long_format=False, + time_format=None, + *, + verbose=None, + ): """Export data in tabular structure as a pandas DataFrame. Channels are converted to columns in the DataFrame. 
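A short sketch (``epochs`` is a hypothetical instance; pandas is required):

>>> df = epochs.to_data_frame(index=["condition", "epoch", "time"])  # doctest: +SKIP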
By default, @@ -2155,12 +2531,12 @@ def to_data_frame(self, picks=None, index=None, # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa # arg checking - valid_index_args = ['time', 'epoch', 'condition'] - valid_time_formats = ['ms', 'timedelta'] + valid_index_args = ["time", "epoch", "condition"] + valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) time_format = _check_time_format(time_format, valid_time_formats) # get data - picks = _picks_to_idx(self.info, picks, 'all', exclude=()) + picks = _picks_to_idx(self.info, picks, "all", exclude=()) data = self.get_data()[:, picks, :] times = self.times n_epochs, n_picks, n_times = data.shape @@ -2172,18 +2548,25 @@ def to_data_frame(self, picks=None, index=None, mindex = list() times = np.tile(times, n_epochs) times = _convert_times(self, times, time_format) - mindex.append(('time', times)) + mindex.append(("time", times)) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append(('condition', np.repeat(conditions, n_times))) - mindex.append(('epoch', np.repeat(self.selection, n_times))) + mindex.append(("condition", np.repeat(conditions, n_times))) + mindex.append(("epoch", np.repeat(self.selection, n_times))) assert all(len(mdx) == len(mindex[0]) for mdx in mindex) # build DataFrame - df = _build_data_frame(self, data, picks, long_format, mindex, index, - default_index=['condition', 'epoch', 'time']) + df = _build_data_frame( + self, + data, + picks, + long_format, + mindex, + index, + default_index=["condition", "epoch", "time"], + ) return df - def as_type(self, ch_type='grad', mode='fast'): + def as_type(self, ch_type="grad", mode="fast"): """Compute virtual epochs using interpolated fields. .. Warning:: Using virtual epochs to compute inverse can yield @@ -2213,10 +2596,11 @@ def as_type(self, ch_type='grad', mode='fast'): .. versionadded:: 0.20.0 """ from .forward import _as_meg_type_inst + return _as_meg_type_inst(self, ch_type=ch_type, mode=mode) -def _drop_log_stats(drop_log, ignore=('IGNORED',)): +def _drop_log_stats(drop_log, ignore=("IGNORED",)): """Compute drop log stats. Parameters @@ -2231,17 +2615,28 @@ def _drop_log_stats(drop_log, ignore=('IGNORED',)): perc : float Total percentage of epochs dropped. """ - if not isinstance(drop_log, tuple) or \ - not all(isinstance(d, tuple) for d in drop_log) or \ - not all(isinstance(s, str) for d in drop_log for s in d): - raise TypeError('drop_log must be a tuple of tuple of str') - perc = 100 * np.mean([len(d) > 0 for d in drop_log - if not any(r in ignore for r in d)]) + if ( + not isinstance(drop_log, tuple) + or not all(isinstance(d, tuple) for d in drop_log) + or not all(isinstance(s, str) for d in drop_log for s in d) + ): + raise TypeError("drop_log must be a tuple of tuple of str") + perc = 100 * np.mean( + [len(d) > 0 for d in drop_log if not any(r in ignore for r in d)] + ) return perc -def make_metadata(events, event_id, tmin, tmax, sfreq, - row_events=None, keep_first=None, keep_last=None): +def make_metadata( + events, + event_id, + tmin, + tmax, + sfreq, + row_events=None, + keep_first=None, + keep_last=None, +): """Generate metadata from events for use with `mne.Epochs`. 
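A hedged sketch of a typical call, ahead of the details below (``events``,
``event_id``, and ``raw`` are assumed to exist; the window should match the
epoching parameters):

>>> metadata, new_events, new_event_id = make_metadata(
...     events=events, event_id=event_id, tmin=-0.2, tmax=0.5,
...     sfreq=raw.info["sfreq"])  # doctest: +SKIP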
This function mimics the epoching process (it constructs time windows @@ -2369,16 +2764,13 @@ def make_metadata(events, event_id, tmin, tmax, sfreq, """ pd = _check_pandas_installed() - _validate_type(event_id, types=(dict,), item_name='event_id') - _validate_type(row_events, types=(None, str, list, tuple), - item_name='row_events') - _validate_type(keep_first, types=(None, str, list, tuple), - item_name='keep_first') - _validate_type(keep_last, types=(None, str, list, tuple), - item_name='keep_last') + _validate_type(event_id, types=(dict,), item_name="event_id") + _validate_type(row_events, types=(None, str, list, tuple), item_name="row_events") + _validate_type(keep_first, types=(None, str, list, tuple), item_name="keep_first") + _validate_type(keep_last, types=(None, str, list, tuple), item_name="keep_last") if not event_id: - raise ValueError('event_id dictionary must contain at least one entry') + raise ValueError("event_id dictionary must contain at least one entry") def _ensure_list(x): if x is None: @@ -2394,26 +2786,29 @@ def _ensure_list(x): keep_first_and_last = set(keep_first) & set(keep_last) if keep_first_and_last: - raise ValueError(f'The event names in keep_first and keep_last must ' - f'be mutually exclusive. Specified in both: ' - f'{", ".join(sorted(keep_first_and_last))}') + raise ValueError( + f"The event names in keep_first and keep_last must " + f"be mutually exclusive. Specified in both: " + f'{", ".join(sorted(keep_first_and_last))}' + ) del keep_first_and_last - for param_name, values in dict(keep_first=keep_first, - keep_last=keep_last).items(): + for param_name, values in dict(keep_first=keep_first, keep_last=keep_last).items(): for first_last_event_name in values: try: match_event_names(event_id, [first_last_event_name]) except KeyError: raise ValueError( f'Event "{first_last_event_name}", specified in ' - f'{param_name}, cannot be found in event_id dictionary') + f"{param_name}, cannot be found in event_id dictionary" + ) event_name_diff = sorted(set(row_events) - set(event_id.keys())) if event_name_diff: raise ValueError( - f'Present in row_events, but missing from event_id: ' - f'{", ".join(event_name_diff)}') + f"Present in row_events, but missing from event_id: " + f'{", ".join(event_name_diff)}' + ) del event_name_diff # First and last sample of each epoch, relative to the time-locked event @@ -2425,12 +2820,12 @@ def _ensure_list(x): # We create the DataFrame before subsetting the events so we end up with # indices corresponding to the original event indices. 
Not used for now, # but might come in handy sometime later - events_df = pd.DataFrame(events, columns=('sample', 'prev_id', 'id')) + events_df = pd.DataFrame(events, columns=("sample", "prev_id", "id")) id_to_name_map = {v: k for k, v in event_id.items()} # Only keep events that are of interest events = events[np.in1d(events[:, 2], list(event_id.values()))] - events_df = events_df.loc[events_df['id'].isin(event_id.values()), :] + events_df = events_df.loc[events_df["id"].isin(event_id.values()), :] # Prepare & condition the metadata DataFrame @@ -2438,26 +2833,27 @@ def _ensure_list(x): # event_id.keys() and keep_first / keep_last simultaneously keep_first_cols = [col for col in keep_first if col not in event_id] keep_last_cols = [col for col in keep_last if col not in event_id] - first_cols = [f'first_{col}' for col in keep_first_cols] - last_cols = [f'last_{col}' for col in keep_last_cols] - - columns = ['event_name', - *event_id.keys(), - *keep_first_cols, - *keep_last_cols, - *first_cols, - *last_cols] + first_cols = [f"first_{col}" for col in keep_first_cols] + last_cols = [f"last_{col}" for col in keep_last_cols] + + columns = [ + "event_name", + *event_id.keys(), + *keep_first_cols, + *keep_last_cols, + *first_cols, + *last_cols, + ] data = np.empty((len(events_df), len(columns))) metadata = pd.DataFrame(data=data, columns=columns, index=events_df.index) # Event names - metadata.iloc[:, 0] = '' + metadata.iloc[:, 0] = "" # Event times start_idx = 1 - stop_idx = (start_idx + len(event_id.keys()) + - len(keep_first_cols + keep_last_cols)) + stop_idx = start_idx + len(event_id.keys()) + len(keep_first_cols + keep_last_cols) metadata.iloc[:, start_idx:stop_idx] = np.nan # keep_first and keep_last names @@ -2467,22 +2863,23 @@ def _ensure_list(x): # We're all set, let's iterate over all events and fill in the # respective cells in the metadata. 
We will subset this to include only # `row_events` later - for row_event in events_df.itertuples(name='RowEvent'): + for row_event in events_df.itertuples(name="RowEvent"): row_idx = row_event.Index - metadata.loc[row_idx, 'event_name'] = \ - id_to_name_map[row_event.id] + metadata.loc[row_idx, "event_name"] = id_to_name_map[row_event.id] # Determine which events fall into the current epoch window_start_sample = row_event.sample + start_sample window_stop_sample = row_event.sample + stop_sample events_in_window = events_df.loc[ - (events_df['sample'] >= window_start_sample) & - (events_df['sample'] <= window_stop_sample), :] + (events_df["sample"] >= window_start_sample) + & (events_df["sample"] <= window_stop_sample), + :, + ] assert not events_in_window.empty # Store the metadata - for event in events_in_window.itertuples(name='Event'): + for event in events_in_window.itertuples(name="Event"): event_sample = event.sample - row_event.sample event_time = event_sample / sfreq event_time = 0 if np.isclose(event_time, 0) else event_time @@ -2499,31 +2896,29 @@ def _ensure_list(x): # Handle keep_first and keep_last event aggregation for event_group_name in keep_first + keep_last: - if event_name not in match_event_names( - event_id, [event_group_name] - ): + if event_name not in match_event_names(event_id, [event_group_name]): continue if event_group_name in keep_first: - first_last_col = f'first_{event_group_name}' + first_last_col = f"first_{event_group_name}" else: - first_last_col = f'last_{event_group_name}' + first_last_col = f"last_{event_group_name}" old_time = metadata.loc[row_idx, event_group_name] if not np.isnan(old_time): - if ((event_group_name in keep_first and - old_time <= event_time) or - (event_group_name in keep_last and - old_time >= event_time)): + if (event_group_name in keep_first and old_time <= event_time) or ( + event_group_name in keep_last and old_time >= event_time + ): continue if event_group_name not in event_id: # This is an HED. 
Strip redundant information from the # event name - name = (event_name - .replace(event_group_name, '') - .replace('//', '/') - .strip('/')) + name = ( + event_name.replace(event_group_name, "") + .replace("//", "/") + .strip("/") + ) metadata.loc[row_idx, first_last_col] = name del name @@ -2531,12 +2926,11 @@ def _ensure_list(x): # Only keep rows of interest if row_events: - event_id_timelocked = {name: val for name, val in event_id.items() - if name in row_events} - events = events[np.in1d(events[:, 2], - list(event_id_timelocked.values()))] - metadata = metadata.loc[ - metadata['event_name'].isin(event_id_timelocked)] + event_id_timelocked = { + name: val for name, val in event_id.items() if name in row_events + } + events = events[np.in1d(events[:, 2], list(event_id_timelocked.values()))] + metadata = metadata.loc[metadata["event_name"].isin(event_id_timelocked)] assert len(events) == len(metadata) event_id = event_id_timelocked @@ -2648,15 +3042,34 @@ class Epochs(BaseEpochs): """ @verbose - def __init__(self, raw, events, event_id=None, tmin=-0.2, tmax=0.5, - baseline=(None, 0), picks=None, preload=False, reject=None, - flat=None, proj=True, decim=1, reject_tmin=None, - reject_tmax=None, detrend=None, on_missing='raise', - reject_by_annotation=True, metadata=None, - event_repeated='error', verbose=None): # noqa: D102 + def __init__( + self, + raw, + events, + event_id=None, + tmin=-0.2, + tmax=0.5, + baseline=(None, 0), + picks=None, + preload=False, + reject=None, + flat=None, + proj=True, + decim=1, + reject_tmin=None, + reject_tmax=None, + detrend=None, + on_missing="raise", + reject_by_annotation=True, + metadata=None, + event_repeated="error", + verbose=None, + ): # noqa: D102 if not isinstance(raw, BaseRaw): - raise ValueError('The first argument to `Epochs` must be an ' - 'instance of mne.io.BaseRaw') + raise ValueError( + "The first argument to `Epochs` must be an " + "instance of mne.io.BaseRaw" + ) info = deepcopy(raw.info) # proj is on when applied in Raw @@ -2665,17 +3078,34 @@ def __init__(self, raw, events, event_id=None, tmin=-0.2, tmax=0.5, self.reject_by_annotation = reject_by_annotation # keep track of original sfreq (needed for annotations) - raw_sfreq = raw.info['sfreq'] + raw_sfreq = raw.info["sfreq"] # call BaseEpochs constructor super(Epochs, self).__init__( - info, None, events, event_id, tmin, tmax, - metadata=metadata, baseline=baseline, raw=raw, picks=picks, - reject=reject, flat=flat, decim=decim, reject_tmin=reject_tmin, - reject_tmax=reject_tmax, detrend=detrend, - proj=proj, on_missing=on_missing, preload_at_end=preload, - event_repeated=event_repeated, verbose=verbose, - raw_sfreq=raw_sfreq, annotations=raw.annotations) + info, + None, + events, + event_id, + tmin, + tmax, + metadata=metadata, + baseline=baseline, + raw=raw, + picks=picks, + reject=reject, + flat=flat, + decim=decim, + reject_tmin=reject_tmin, + reject_tmax=reject_tmax, + detrend=detrend, + proj=proj, + on_missing=on_missing, + preload_at_end=preload, + event_repeated=event_repeated, + verbose=verbose, + raw_sfreq=raw_sfreq, + annotations=raw.annotations, + ) @verbose def _get_epoch_from_raw(self, idx, verbose=None): @@ -2690,10 +3120,12 @@ def _get_epoch_from_raw(self, idx, verbose=None): """ if self._raw is None: # This should never happen, as raw=None only if preload=True - raise ValueError('An error has occurred, no valid raw file found. 
' - 'Please report this to the mne-python ' - 'developers.') - sfreq = self._raw.info['sfreq'] + raise ValueError( + "An error has occurred, no valid raw file found. " + "Please report this to the mne-python " + "developers." + ) + sfreq = self._raw.info["sfreq"] event_samp = self.events[idx, 0] # Read a data segment from "start" to "stop" in samples first_samp = self._raw.first_samp @@ -2715,10 +3147,15 @@ def _get_epoch_from_raw(self, idx, verbose=None): diff = int(round((self._raw_times[-1] - reject_tmax) * sfreq)) reject_stop = stop - diff - logger.debug(' Getting epoch for %d-%d' % (start, stop)) - data = self._raw._check_bad_segment(start, stop, self.picks, - reject_start, reject_stop, - self.reject_by_annotation) + logger.debug(" Getting epoch for %d-%d" % (start, stop)) + data = self._raw._check_bad_segment( + start, + stop, + self.picks, + reject_start, + reject_stop, + self.reject_by_annotation, + ) return data @@ -2800,38 +3237,72 @@ class EpochsArray(BaseEpochs): """ @verbose - def __init__(self, data, info, events=None, tmin=0, event_id=None, - reject=None, flat=None, reject_tmin=None, - reject_tmax=None, baseline=None, proj=True, - on_missing='raise', metadata=None, selection=None, - *, drop_log=None, raw_sfreq=None, verbose=None): # noqa: D102 + def __init__( + self, + data, + info, + events=None, + tmin=0, + event_id=None, + reject=None, + flat=None, + reject_tmin=None, + reject_tmax=None, + baseline=None, + proj=True, + on_missing="raise", + metadata=None, + selection=None, + *, + drop_log=None, + raw_sfreq=None, + verbose=None, + ): # noqa: D102 dtype = np.complex128 if np.any(np.iscomplex(data)) else np.float64 data = np.asanyarray(data, dtype=dtype) if data.ndim != 3: - raise ValueError('Data must be a 3D array of shape (n_epochs, ' - 'n_channels, n_samples)') + raise ValueError( + "Data must be a 3D array of shape (n_epochs, " "n_channels, n_samples)" + ) - if len(info['ch_names']) != data.shape[1]: - raise ValueError('Info and data must have same number of ' - 'channels.') + if len(info["ch_names"]) != data.shape[1]: + raise ValueError("Info and data must have same number of " "channels.") if events is None: n_epochs = len(data) events = _gen_events(n_epochs) info = info.copy() # do not modify original info - tmax = (data.shape[2] - 1) / info['sfreq'] + tmin + tmax = (data.shape[2] - 1) / info["sfreq"] + tmin super(EpochsArray, self).__init__( - info, data, events, event_id, tmin, tmax, baseline, - reject=reject, flat=flat, reject_tmin=reject_tmin, - reject_tmax=reject_tmax, decim=1, metadata=metadata, - selection=selection, proj=proj, on_missing=on_missing, - drop_log=drop_log, raw_sfreq=raw_sfreq, verbose=verbose) + info, + data, + events, + event_id, + tmin, + tmax, + baseline, + reject=reject, + flat=flat, + reject_tmin=reject_tmin, + reject_tmax=reject_tmax, + decim=1, + metadata=metadata, + selection=selection, + proj=proj, + on_missing=on_missing, + drop_log=drop_log, + raw_sfreq=raw_sfreq, + verbose=verbose, + ) if self.baseline is not None: self._do_baseline = True - if len(events) != np.in1d(self.events[:, 2], - list(self.event_id.values())).sum(): - raise ValueError('The events must only contain event numbers from ' - 'event_id') + if ( + len(events) + != np.in1d(self.events[:, 2], list(self.event_id.values())).sum() + ): + raise ValueError( + "The events must only contain event numbers from " "event_id" + ) detrend_picks = self._detrend_picks for e in self._data: # This is safe without assignment b/c there is no decim @@ -2875,19 +3346,20 @@ def 
combine_event_ids(epochs, old_event_ids, new_event_id, copy=True): new_event_id = {str(new_event_id): new_event_id} else: if not isinstance(new_event_id, dict): - raise ValueError('new_event_id must be a dict or int') + raise ValueError("new_event_id must be a dict or int") if not len(list(new_event_id.keys())) == 1: - raise ValueError('new_event_id dict must have one entry') + raise ValueError("new_event_id dict must have one entry") new_event_num = list(new_event_id.values())[0] new_event_num = operator.index(new_event_num) if new_event_num in epochs.event_id.values(): - raise ValueError('new_event_id value must not already exist') + raise ValueError("new_event_id value must not already exist") # could use .pop() here, but if a later one doesn't exist, we're # in trouble, so run them all here and pop() later old_event_nums = np.array([epochs.event_id[key] for key in old_event_ids]) # find the ones to replace - inds = np.any(epochs.events[:, 2][:, np.newaxis] == - old_event_nums[np.newaxis, :], axis=1) + inds = np.any( + epochs.events[:, 2][:, np.newaxis] == old_event_nums[np.newaxis, :], axis=1 + ) # replace the event numbers in the events list epochs.events[inds, 2] = new_event_num # delete old entries @@ -2898,7 +3370,7 @@ def combine_event_ids(epochs, old_event_ids, new_event_id, copy=True): return epochs -def equalize_epoch_counts(epochs_list, method='mintime'): +def equalize_epoch_counts(epochs_list, method="mintime"): """Equalize the number of trials in multiple Epoch instances. Parameters ---------- @@ -2927,7 +3399,7 @@ def equalize_epoch_counts(epochs_list, method='mintime'): >>> equalize_epoch_counts([epochs1, epochs2]) # doctest: +SKIP """ if not all(isinstance(e, BaseEpochs) for e in epochs_list): - raise ValueError('All inputs must be Epochs instances') + raise ValueError("All inputs must be Epochs instances") # make sure bad epochs are dropped for e in epochs_list: @@ -2936,21 +3408,21 @@ def equalize_epoch_counts(epochs_list, method='mintime'): event_times = [e.events[:, 0] for e in epochs_list] indices = _get_drop_indices(event_times, method) for e, inds in zip(epochs_list, indices): - e.drop(inds, reason='EQUALIZED_COUNT') + e.drop(inds, reason="EQUALIZED_COUNT") def _get_drop_indices(event_times, method): """Get indices to drop from multiple event timing lists.""" small_idx = np.argmin([e.shape[0] for e in event_times]) small_e_times = event_times[small_idx] - _check_option('method', method, ['mintime', 'truncate']) + _check_option("method", method, ["mintime", "truncate"]) indices = list() for e in event_times: - if method == 'mintime': + if method == "mintime": mask = _minimize_time_diff(small_e_times, e) else: mask = np.ones(e.shape[0], dtype=bool) - mask[small_e_times.shape[0]:] = False + mask[small_e_times.shape[0] :] = False indices.append(np.where(np.logical_not(mask))[0]) return indices @@ -2959,6 +3431,7 @@ def _get_drop_indices(event_times, method): def _minimize_time_diff(t_shorter, t_longer): """Find a boolean mask to minimize timing differences.""" from scipy.interpolate import interp1d + keep = np.ones((len(t_longer)), dtype=bool) # special case: length zero or one if len(t_shorter) < 2: # interp1d won't work @@ -2971,8 +3444,7 @@ def _minimize_time_diff(t_shorter, t_longer): x1 = np.arange(len(t_shorter)) # The first set of keep masks to test kwargs = dict(copy=False, bounds_error=False, assume_sorted=True) - shorter_interp = interp1d(x1, t_shorter, fill_value=t_shorter[-1], - **kwargs) + shorter_interp = interp1d(x1, t_shorter, fill_value=t_shorter[-1], **kwargs) 
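# Sketch of the greedy search below: each pass tentatively drops every
# remaining event in t_longer, scores each candidate mask by the summed
# absolute timing mismatch between the two interpolated sequences (note that
# np.abs(d1, d1) takes the absolute value in place), and keeps the single
# drop with the lowest score.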
for ii in range(len(t_longer) - len(t_shorter)): scores.fill(np.inf) # set up the keep masks to test, eliminating any rows that are already @@ -2982,9 +3454,9 @@ def _minimize_time_diff(t_shorter, t_longer): # Check every possible removal to see if it minimizes x2 = np.arange(len(t_longer) - ii - 1) t_keeps = np.array([t_longer[km] for km in keep_mask]) - longer_interp = interp1d(x2, t_keeps, axis=1, - fill_value=t_keeps[:, -1], - **kwargs) + longer_interp = interp1d( + x2, t_keeps, axis=1, fill_value=t_keeps[:, -1], **kwargs + ) d1 = longer_interp(x1) - t_shorter d2 = shorter_interp(x2) - t_keeps scores[keep] = np.abs(d1, d1).sum(axis=1) + np.abs(d2, d2).sum(axis=1) @@ -2993,8 +3465,16 @@ def _minimize_time_diff(t_shorter, t_longer): @verbose -def _is_good(e, ch_names, channel_type_idx, reject, flat, full_report=False, - ignore_chs=[], verbose=None): +def _is_good( + e, + ch_names, + channel_type_idx, + reject, + flat, + full_report=False, + ignore_chs=[], + verbose=None, +): """Test if data segment e is good according to reject and flat. If full_report=True, it will give True/False as well as a list of all @@ -3003,9 +3483,8 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat, full_report=False, bad_tuple = tuple() has_printed = False checkable = np.ones(len(ch_names), dtype=bool) - checkable[np.array([c in ignore_chs - for c in ch_names], dtype=bool)] = False - for refl, f, t in zip([reject, flat], [np.greater, np.less], ['', 'flat']): + checkable[np.array([c in ignore_chs for c in ch_names], dtype=bool)] = False + for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]): if refl is not None: for key, thresh in refl.items(): idx = channel_type_idx[key] @@ -3014,14 +3493,17 @@ def _is_good(e, ch_names, channel_type_idx, reject, flat, full_report=False, e_idx = e[idx] deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1) checkable_idx = checkable[idx] - idx_deltas = np.where(np.logical_and(f(deltas, thresh), - checkable_idx))[0] + idx_deltas = np.where( + np.logical_and(f(deltas, thresh), checkable_idx) + )[0] if len(idx_deltas) > 0: bad_names = [ch_names[idx[i]] for i in idx_deltas] - if (not has_printed): - logger.info(' Rejecting %s epoch based on %s : ' - '%s' % (t, name, bad_names)) + if not has_printed: + logger.info( + " Rejecting %s epoch based on %s : " + "%s" % (t, name, bad_names) + ) has_printed = True if not full_report: return False @@ -3051,7 +3533,7 @@ def _read_one_epoch_file(f, tree, preload): metadata = None metadata_tree = dir_tree_find(tree, FIFF.FIFFB_MNE_METADATA) if len(metadata_tree) > 0: - for dd in metadata_tree[0]['directory']: + for dd in metadata_tree[0]["directory"]: kind = dd.kind pos = dd.pos if kind == FIFF.FIFF_DESCRIPTION: @@ -3063,7 +3545,7 @@ def _read_one_epoch_file(f, tree, preload): processed = dir_tree_find(meas, FIFF.FIFFB_PROCESSED_DATA) del meas if len(processed) == 0: - raise ValueError('Could not find processed data') + raise ValueError("Could not find processed data") epochs_node = dir_tree_find(tree, FIFF.FIFFB_MNE_EPOCHS) if len(epochs_node) == 0: @@ -3073,7 +3555,7 @@ def _read_one_epoch_file(f, tree, preload): if len(epochs_node) == 0: epochs_node = dir_tree_find(tree, 122) # 122 used before v0.11 if len(epochs_node) == 0: - raise ValueError('Could not find epochs data') + raise ValueError("Could not find epochs data") my_epochs = epochs_node[0] @@ -3086,9 +3568,9 @@ def _read_one_epoch_file(f, tree, preload): drop_log = None raw_sfreq = None reject_params = {} - for k in range(my_epochs['nent']): - kind = 
my_epochs['directory'][k].kind - pos = my_epochs['directory'][k].pos + for k in range(my_epochs["nent"]): + kind = my_epochs["directory"][k].kind + pos = my_epochs["directory"][k].pos if kind == FIFF.FIFF_FIRST_SAMPLE: tag = read_tag(fid, pos) first = int(tag.data.item()) @@ -3128,44 +3610,52 @@ def _read_one_epoch_file(f, tree, preload): baseline = (bmin, bmax) n_samp = last - first + 1 - logger.info(' Found the data of interest:') - logger.info(' t = %10.2f ... %10.2f ms' - % (1000 * first / info['sfreq'], - 1000 * last / info['sfreq'])) - if info['comps'] is not None: - logger.info(' %d CTF compensation matrices available' - % len(info['comps'])) + logger.info(" Found the data of interest:") + logger.info( + " t = %10.2f ... %10.2f ms" + % (1000 * first / info["sfreq"], 1000 * last / info["sfreq"]) + ) + if info["comps"] is not None: + logger.info( + " %d CTF compensation matrices available" % len(info["comps"]) + ) # Inspect the data if data_tag is None: - raise ValueError('Epochs data not found') - epoch_shape = (len(info['ch_names']), n_samp) + raise ValueError("Epochs data not found") + epoch_shape = (len(info["ch_names"]), n_samp) size_expected = len(events) * np.prod(epoch_shape) # on read double-precision is always used if data_tag.type == FIFF.FIFFT_FLOAT: datatype = np.float64 - fmt = '>f4' + fmt = ">f4" elif data_tag.type == FIFF.FIFFT_DOUBLE: datatype = np.float64 - fmt = '>f8' + fmt = ">f8" elif data_tag.type == FIFF.FIFFT_COMPLEX_FLOAT: datatype = np.complex128 - fmt = '>c8' + fmt = ">c8" elif data_tag.type == FIFF.FIFFT_COMPLEX_DOUBLE: datatype = np.complex128 - fmt = '>c16' + fmt = ">c16" fmt_itemsize = np.dtype(fmt).itemsize assert fmt_itemsize in (4, 8, 16) size_actual = data_tag.size // fmt_itemsize - 16 // fmt_itemsize if not size_actual == size_expected: - raise ValueError('Incorrect number of samples (%d instead of %d)' - % (size_actual, size_expected)) + raise ValueError( + "Incorrect number of samples (%d instead of %d)" + % (size_actual, size_expected) + ) # Calibration factors - cals = np.array([[info['chs'][k]['cal'] * - info['chs'][k].get('scale', 1.0)] - for k in range(info['nchan'])], np.float64) + cals = np.array( + [ + [info["chs"][k]["cal"] * info["chs"][k].get("scale", 1.0)] + for k in range(info["nchan"]) + ], + np.float64, + ) # Read the data if preload: @@ -3173,10 +3663,13 @@ def _read_one_epoch_file(f, tree, preload): data *= cals # Put it all together - tmin = first / info['sfreq'] - tmax = last / info['sfreq'] - event_id = ({str(e): e for e in np.unique(events[:, 2])} - if mappings is None else mappings) + tmin = first / info["sfreq"] + tmax = last / info["sfreq"] + event_id = ( + {str(e): e for e in np.unique(events[:, 2])} + if mappings is None + else mappings + ) # In case epochs didn't have a FIFF.FIFF_MNE_EPOCHS_SELECTION tag # (version < 0.8): if selection is None: @@ -3184,9 +3677,25 @@ def _read_one_epoch_file(f, tree, preload): if drop_log is None: drop_log = ((),) * len(events) - return (info, data, data_tag, events, event_id, metadata, tmin, tmax, - baseline, selection, drop_log, epoch_shape, cals, reject_params, - fmt, annotations, raw_sfreq) + return ( + info, + data, + data_tag, + events, + event_id, + metadata, + tmin, + tmax, + baseline, + selection, + drop_log, + epoch_shape, + cals, + reject_params, + fmt, + annotations, + raw_sfreq, + ) @verbose @@ -3213,8 +3722,9 @@ def read_epochs(fname, proj=True, preload=True, verbose=None): class _RawContainer: """Helper for a raw data container.""" - def __init__(self, fid, data_tag, 
event_samps, epoch_shape, - cals, fmt): # noqa: D102 + def __init__( + self, fid, data_tag, event_samps, epoch_shape, cals, fmt + ): # noqa: D102 self.fid = fid self.data_tag = data_tag self.event_samps = event_samps @@ -3248,36 +3758,51 @@ class EpochsFIF(BaseEpochs): """ @verbose - def __init__(self, fname, proj=True, preload=True, - verbose=None): # noqa: D102 + def __init__(self, fname, proj=True, preload=True, verbose=None): # noqa: D102 if _path_like(fname): check_fname( - fname=fname, filetype='epochs', - endings=('-epo.fif', '-epo.fif.gz', '_epo.fif', '_epo.fif.gz') - ) - fname = str( - _check_fname(fname=fname, must_exist=True, overwrite="read") + fname=fname, + filetype="epochs", + endings=("-epo.fif", "-epo.fif.gz", "_epo.fif", "_epo.fif.gz"), ) + fname = str(_check_fname(fname=fname, must_exist=True, overwrite="read")) elif not preload: - raise ValueError('preload must be used with file-like objects') + raise ValueError("preload must be used with file-like objects") fnames = [fname] ep_list = list() raw = list() for fname in fnames: fname_rep = _get_fname_rep(fname) - logger.info('Reading %s ...' % fname_rep) + logger.info("Reading %s ..." % fname_rep) fid, tree, _ = fiff_open(fname, preload=preload) next_fname = _get_next_fname(fid, fname, tree) - (info, data, data_tag, events, event_id, metadata, tmin, tmax, - baseline, selection, drop_log, epoch_shape, cals, - reject_params, fmt, annotations, raw_sfreq) = \ - _read_one_epoch_file(fid, tree, preload) + ( + info, + data, + data_tag, + events, + event_id, + metadata, + tmin, + tmax, + baseline, + selection, + drop_log, + epoch_shape, + cals, + reject_params, + fmt, + annotations, + raw_sfreq, + ) = _read_one_epoch_file(fid, tree, preload) if (events[:, 0] < 0).any(): events = events.copy() - warn('Incorrect events detected on disk, setting event ' - 'numbers to consecutive increasing integers') + warn( + "Incorrect events detected on disk, setting event " + "numbers to consecutive increasing integers" + ) events[:, 0] = np.arange(1, len(events) + 1) # here we ignore missing events, since users should already be # aware of missing events if they have saved data that way @@ -3285,35 +3810,63 @@ def __init__(self, fname, proj=True, preload=True, # correction (data is being baseline-corrected when written to # disk) epoch = BaseEpochs( - info, data, events, event_id, tmin, tmax, + info, + data, + events, + event_id, + tmin, + tmax, baseline=None, - metadata=metadata, on_missing='ignore', - selection=selection, drop_log=drop_log, - proj=False, verbose=False, raw_sfreq=raw_sfreq) + metadata=metadata, + on_missing="ignore", + selection=selection, + drop_log=drop_log, + proj=False, + verbose=False, + raw_sfreq=raw_sfreq, + ) epoch.baseline = baseline epoch._do_baseline = False # might be superfluous but won't hurt ep_list.append(epoch) if not preload: # store everything we need to index back to the original data - raw.append(_RawContainer(fiff_open(fname)[0], data_tag, - events[:, 0].copy(), epoch_shape, - cals, fmt)) + raw.append( + _RawContainer( + fiff_open(fname)[0], + data_tag, + events[:, 0].copy(), + epoch_shape, + cals, + fmt, + ) + ) if next_fname is not None: fnames.append(next_fname) unsafe_annot_add = raw_sfreq is None - (info, data, raw_sfreq, events, event_id, tmin, tmax, metadata, - baseline, selection, drop_log) = _concatenate_epochs( + ( + info, + data, + raw_sfreq, + events, + event_id, + tmin, + tmax, + metadata, + baseline, + selection, + drop_log, + ) = _concatenate_epochs( ep_list, with_data=preload, 
add_offset=False, - on_mismatch='raise', + on_mismatch="raise", ) # we need this uniqueness for non-preloaded data to work properly if len(np.unique(events[:, 0])) != len(events): - raise RuntimeError('Event time samples were not unique') + raise RuntimeError("Event time samples were not unique") # correct the drop log assert len(drop_log) % len(fnames) == 0 @@ -3323,7 +3876,7 @@ def __init__(self, fname, proj=True, preload=True, for i1, i2 in zip(offsets[:-1], offsets[1:]): other_log = drop_log[i1:i2] for k, (a, b) in enumerate(zip(drop_log, other_log)): - if a == ('IGNORED',) and b != ('IGNORED',): + if a == ("IGNORED",) and b != ("IGNORED",): drop_log[k] = b drop_log = tuple(drop_log[:step]) @@ -3331,12 +3884,26 @@ def __init__(self, fname, proj=True, preload=True, # again, ensure we're retaining the baseline period originally loaded # from disk without trying to re-apply baseline correction super(EpochsFIF, self).__init__( - info, data, events, event_id, tmin, tmax, - baseline=None, raw=raw, - proj=proj, preload_at_end=False, on_missing='ignore', - selection=selection, drop_log=drop_log, filename=fname_rep, - metadata=metadata, verbose=verbose, raw_sfreq=raw_sfreq, - annotations=annotations, **reject_params) + info, + data, + events, + event_id, + tmin, + tmax, + baseline=None, + raw=raw, + proj=proj, + preload_at_end=False, + on_missing="ignore", + selection=selection, + drop_log=drop_log, + filename=fname_rep, + metadata=metadata, + verbose=verbose, + raw_sfreq=raw_sfreq, + annotations=annotations, + **reject_params, + ) self.baseline = baseline self._do_baseline = False # use the private property instead of drop_bad so that epochs @@ -3361,8 +3928,10 @@ def _get_epoch_from_raw(self, idx, verbose=None): break else: # read the correct subset of the data - raise RuntimeError('Correct epoch could not be found, please ' - 'contact mne-python developers') + raise RuntimeError( + "Correct epoch could not be found, please " + "contact mne-python developers" + ) # the following is equivalent to this, but faster: # # >>> data = read_tag(raw.fid, raw.data_tag.pos).data.astype(float) @@ -3372,10 +3941,10 @@ def _get_epoch_from_raw(self, idx, verbose=None): # Eventually this could be refactored in io/tag.py if other functions # could make use of it raw.fid.seek(raw.data_tag.pos + offset, 0) - if fmt == '>c8': - read_fmt = '>f4' - elif fmt == '>c16': - read_fmt = '>f8' + if fmt == ">c8": + read_fmt = ">f4" + elif fmt == ">c16": + read_fmt = ">f8" else: read_fmt = fmt data = np.frombuffer(raw.fid.read(size), read_fmt) @@ -3406,9 +3975,11 @@ def bootstrap(epochs, random_state=None): The bootstrap samples """ if not epochs.preload: - raise RuntimeError('Modifying data of epochs is only supported ' - 'when preloading is used. Use preload=True ' - 'in the constructor.') + raise RuntimeError( + "Modifying data of epochs is only supported " + "when preloading is used. Use preload=True " + "in the constructor." 
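# A hedged usage sketch (not part of the diff; `epochs` is a hypothetical
# preloaded instance): draw the same number of trials with replacement, e.g.
#     epochs_boot = bootstrap(epochs, random_state=42)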
+ ) rng = check_random_state(random_state) epochs_bootstrap = epochs.copy() @@ -3430,27 +4001,35 @@ def _check_merge_epochs(epochs_list): raise NotImplementedError("Epochs with unequal values for baseline") -def _concatenate_epochs(epochs_list, *, with_data=True, add_offset=True, - on_mismatch='raise'): +def _concatenate_epochs( + epochs_list, *, with_data=True, add_offset=True, on_mismatch="raise" +): """Auxiliary function for concatenating epochs.""" if not isinstance(epochs_list, (list, tuple)): - raise TypeError('epochs_list must be a list or tuple, got %s' - % (type(epochs_list),)) + raise TypeError( + "epochs_list must be a list or tuple, got %s" % (type(epochs_list),) + ) # to make warning messages only occur once during concatenation warned = False for ei, epochs in enumerate(epochs_list): if not isinstance(epochs, BaseEpochs): - raise TypeError('epochs_list[%d] must be an instance of Epochs, ' - 'got %s' % (ei, type(epochs))) + raise TypeError( + "epochs_list[%d] must be an instance of Epochs, " + "got %s" % (ei, type(epochs)) + ) - if (getattr(epochs, 'annotations', None) is not None and - len(epochs.annotations) > 0 and - not warned): + if ( + getattr(epochs, "annotations", None) is not None + and len(epochs.annotations) > 0 + and not warned + ): warned = True - warn('Concatenation of Annotations within Epochs is not supported ' - 'yet. All annotations will be dropped.') + warn( + "Concatenation of Annotations within Epochs is not supported " + "yet. All annotations will be dropped." + ) # create a copy, so that the Annotations are not modified in place # from the original object @@ -3470,40 +4049,42 @@ def _concatenate_epochs(epochs_list, *, with_data=True, add_offset=True, event_id = deepcopy(out.event_id) selection = out.selection # offset is the last epoch + tmax + 10 second - shift = int((10 + tmax) * out.info['sfreq']) + shift = int((10 + tmax) * out.info["sfreq"]) events_offset = int(np.max(events[0][:, 0])) + shift events_overflow = False warned = False for ii, epochs in enumerate(epochs_list[1:], 1): - _ensure_infos_match(epochs.info, info, f'epochs[{ii}]', - on_mismatch=on_mismatch) + _ensure_infos_match(epochs.info, info, f"epochs[{ii}]", on_mismatch=on_mismatch) if not np.allclose(epochs.times, epochs_list[0].times): - raise ValueError('Epochs must have same times') + raise ValueError("Epochs must have same times") if epochs.baseline != baseline: - raise ValueError('Baseline must be same for all epochs') + raise ValueError("Baseline must be same for all epochs") if epochs._raw_sfreq != raw_sfreq and not warned: warned = True - warn('The original raw sampling rate of the Epochs does not ' - 'match for all Epochs. Please proceed cautiously.') + warn( + "The original raw sampling rate of the Epochs does not " + "match for all Epochs. Please proceed cautiously." + ) # compare event_id common_keys = list(set(event_id).intersection(set(epochs.event_id))) for key in common_keys: if not event_id[key] == epochs.event_id[key]: - msg = ('event_id values must be the same for identical keys ' - 'for all concatenated epochs. Key "{}" maps to {} in ' - 'some epochs and to {} in others.') - raise ValueError(msg.format(key, event_id[key], - epochs.event_id[key])) + msg = ( + "event_id values must be the same for identical keys " + 'for all concatenated epochs. Key "{}" maps to {} in ' + "some epochs and to {} in others." 
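# Hedged illustration (not part of the diff; epochs_a and epochs_b are
# hypothetical): the check above enforces that identical event_id keys map to
# identical integer codes in every input of
#     combined = mne.concatenate_epochs([epochs_a, epochs_b])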
+ ) + raise ValueError(msg.format(key, event_id[key], epochs.event_id[key])) if with_data: epochs.drop_bad() offsets.append(len(epochs)) evs = epochs.events.copy() if len(epochs.events) == 0: - warn('One of the Epochs objects to concatenate was empty.') + warn("One of the Epochs objects to concatenate was empty.") elif add_offset: # We need to cast to a native Python int here to detect an # overflow of a numpy int32 (which is the default on windows) @@ -3511,9 +4092,11 @@ def _concatenate_epochs(epochs_list, *, with_data=True, add_offset=True, evs[:, 0] += events_offset events_offset += max_timestamp + shift if events_offset > INT32_MAX: - warn(f'Event number greater than {INT32_MAX} created, ' - 'events[:, 0] will be assigned consecutive increasing ' - 'integer values') + warn( + f"Event number greater than {INT32_MAX} created, " + "events[:, 0] will be assigned consecutive increasing " + "integer values" + ) events_overflow = True add_offset = False # we no longer need to add offset events.append(evs) @@ -3531,9 +4114,10 @@ def _concatenate_epochs(epochs_list, *, with_data=True, add_offset=True, if n_have == 0: metadata = None elif n_have != len(metadata): - raise ValueError('%d of %d epochs instances have metadata, either ' - 'all or none must have metadata' - % (n_have, len(metadata))) + raise ValueError( + "%d of %d epochs instances have metadata, either " + "all or none must have metadata" % (n_have, len(metadata)) + ) else: pd = _check_pandas_installed(strict=False) if pd is not False: @@ -3549,15 +4133,28 @@ def _concatenate_epochs(epochs_list, *, with_data=True, add_offset=True, if data is None: data = np.empty( (offsets[-1], len(out.ch_names), len(out.times)), - dtype=this_data.dtype) + dtype=this_data.dtype, + ) data[start:stop] = this_data - return (info, data, raw_sfreq, events, event_id, tmin, tmax, metadata, - baseline, selection, drop_log) + return ( + info, + data, + raw_sfreq, + events, + event_id, + tmin, + tmax, + metadata, + baseline, + selection, + drop_log, + ) @verbose -def concatenate_epochs(epochs_list, add_offset=True, *, on_mismatch='raise', - verbose=None): +def concatenate_epochs( + epochs_list, add_offset=True, *, on_mismatch="raise", verbose=None +): """Concatenate a list of `~mne.Epochs` into one `~mne.Epochs` object. .. note:: Unlike `~mne.concatenate_raws`, this function does **not** @@ -3586,8 +4183,19 @@ def concatenate_epochs(epochs_list, add_offset=True, *, on_mismatch='raise', ----- .. 
versionadded:: 0.9.0 """ - (info, data, raw_sfreq, events, event_id, tmin, tmax, metadata, - baseline, selection, drop_log) = _concatenate_epochs( + ( + info, + data, + raw_sfreq, + events, + event_id, + tmin, + tmax, + metadata, + baseline, + selection, + drop_log, + ) = _concatenate_epochs( epochs_list, with_data=True, add_offset=add_offset, @@ -3595,19 +4203,39 @@ def concatenate_epochs(epochs_list, add_offset=True, *, on_mismatch='raise', ) selection = np.where([len(d) == 0 for d in drop_log])[0] out = EpochsArray( - data=data, info=info, events=events, event_id=event_id, - tmin=tmin, baseline=baseline, selection=selection, drop_log=drop_log, - proj=False, on_missing='ignore', metadata=metadata, - raw_sfreq=raw_sfreq) + data=data, + info=info, + events=events, + event_id=event_id, + tmin=tmin, + baseline=baseline, + selection=selection, + drop_log=drop_log, + proj=False, + on_missing="ignore", + metadata=metadata, + raw_sfreq=raw_sfreq, + ) out.drop_bad() return out @verbose -def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, - origin='auto', weight_all=True, int_order=8, ext_order=3, - destination=None, ignore_ref=False, return_mapping=False, - mag_scale=100., verbose=None): +def average_movements( + epochs, + head_pos=None, + orig_sfreq=None, + picks=None, + origin="auto", + weight_all=True, + int_order=8, + ext_order=3, + destination=None, + ignore_ref=False, + return_mapping=False, + mag_scale=100.0, + verbose=None, +): """Average data using Maxwell filtering, transforming using head positions. Parameters @@ -3668,37 +4296,48 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, of children in MEG: Quantification, effects on source estimation, and compensation. NeuroImage 40:541–550, 2008. """ # noqa: E501 - from .preprocessing.maxwell import (_trans_sss_basis, _reset_meg_bads, - _check_usable, _col_norm_pinv, - _get_n_moments, _get_mf_picks_fix_mags, - _prep_mf_coils, _check_destination, - _remove_meg_projs_comps, - _get_coil_scale, _get_sensor_operator) + from .preprocessing.maxwell import ( + _trans_sss_basis, + _reset_meg_bads, + _check_usable, + _col_norm_pinv, + _get_n_moments, + _get_mf_picks_fix_mags, + _prep_mf_coils, + _check_destination, + _remove_meg_projs_comps, + _get_coil_scale, + _get_sensor_operator, + ) + if head_pos is None: - raise TypeError('head_pos must be provided and cannot be None') + raise TypeError("head_pos must be provided and cannot be None") from .chpi import head_pos_to_trans_rot_t + if not isinstance(epochs, BaseEpochs): - raise TypeError('epochs must be an instance of Epochs, not %s' - % (type(epochs),)) - orig_sfreq = epochs.info['sfreq'] if orig_sfreq is None else orig_sfreq + raise TypeError( + "epochs must be an instance of Epochs, not %s" % (type(epochs),) + ) + orig_sfreq = epochs.info["sfreq"] if orig_sfreq is None else orig_sfreq orig_sfreq = float(orig_sfreq) if isinstance(head_pos, np.ndarray): head_pos = head_pos_to_trans_rot_t(head_pos) trn, rot, t = head_pos del head_pos _check_usable(epochs, ignore_ref) - origin = _check_origin(origin, epochs.info, 'head') + origin = _check_origin(origin, epochs.info, "head") recon_trans = _check_destination(destination, epochs.info, True) - logger.info('Aligning and averaging up to %s epochs' - % (len(epochs.events))) + logger.info("Aligning and averaging up to %s epochs" % (len(epochs.events))) if not np.array_equal(epochs.events[:, 0], np.unique(epochs.events[:, 0])): - raise RuntimeError('Epochs must have monotonically increasing events') + raise 
RuntimeError("Epochs must have monotonically increasing events") info_to = epochs.info.copy() - meg_picks, mag_picks, grad_picks, good_mask, _ = \ - _get_mf_picks_fix_mags(info_to, int_order, ext_order, ignore_ref) + meg_picks, mag_picks, grad_picks, good_mask, _ = _get_mf_picks_fix_mags( + info_to, int_order, ext_order, ignore_ref + ) coil_scale, mag_scale = _get_coil_scale( - meg_picks, mag_picks, grad_picks, mag_scale, info_to) + meg_picks, mag_picks, grad_picks, mag_scale, info_to + ) mult = _get_sensor_operator(epochs, meg_picks) n_channels, n_times = len(epochs.ch_names), len(epochs.times) other_picks = np.setdiff1d(np.arange(n_channels), meg_picks) @@ -3711,37 +4350,36 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, # remove MEG bads in "to" info _reset_meg_bads(info_to) # set up variables - w_sum = 0. + w_sum = 0.0 n_in, n_out = _get_n_moments([int_order, ext_order]) - S_decomp = 0. # this will end up being a weighted average + S_decomp = 0.0 # this will end up being a weighted average last_trans = None decomp_coil_scale = coil_scale[good_mask] - exp = dict(int_order=int_order, ext_order=ext_order, head_frame=True, - origin=origin) + exp = dict(int_order=int_order, ext_order=ext_order, head_frame=True, origin=origin) n_in = _get_n_moments(int_order) for ei, epoch in enumerate(epochs): event_time = epochs.events[epochs._current - 1, 0] / orig_sfreq use_idx = np.where(t <= event_time)[0] if len(use_idx) == 0: - trans = info_to['dev_head_t']['trans'] + trans = info_to["dev_head_t"]["trans"] else: use_idx = use_idx[-1] - trans = np.vstack([np.hstack([rot[use_idx], trn[[use_idx]].T]), - [[0., 0., 0., 1.]]]) - loc_str = ', '.join('%0.1f' % tr for tr in (trans[:3, 3] * 1000)) + trans = np.vstack( + [np.hstack([rot[use_idx], trn[[use_idx]].T]), [[0.0, 0.0, 0.0, 1.0]]] + ) + loc_str = ", ".join("%0.1f" % tr for tr in (trans[:3, 3] * 1000)) if last_trans is None or not np.allclose(last_trans, trans): - logger.info(' Processing epoch %s (device location: %s mm)' - % (ei + 1, loc_str)) + logger.info( + " Processing epoch %s (device location: %s mm)" % (ei + 1, loc_str) + ) reuse = False last_trans = trans else: - logger.info(' Processing epoch %s (device location: same)' - % (ei + 1,)) + logger.info(" Processing epoch %s (device location: same)" % (ei + 1,)) reuse = True epoch = epoch.copy() # because we operate inplace if not reuse: - S = _trans_sss_basis(exp, all_coils, trans, - coil_scale=decomp_coil_scale) + S = _trans_sss_basis(exp, all_coils, trans, coil_scale=decomp_coil_scale) # Get the weight from the un-regularized version (eq. 44) weight = np.linalg.norm(S[:, :n_in]) # XXX Eventually we could do cross-talk and fine-cal here @@ -3762,12 +4400,12 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, S_decomp /= w_sum # Get recon matrix # (We would need to include external here for regularization to work) - exp['ext_order'] = 0 + exp["ext_order"] = 0 S_recon = _trans_sss_basis(exp, all_coils_recon, recon_trans) if mult is not None: S_decomp = mult @ S_decomp S_recon = mult @ S_recon - exp['ext_order'] = ext_order + exp["ext_order"] = ext_order # We could determine regularization on basis of destination basis # matrix, restricted to good channels, as regularizing individual # matrices within the loop above does not seem to work. 
But in @@ -3781,19 +4419,26 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None, mapping = np.dot(S_recon, pS_ave) # Apply mapping data[meg_picks] = np.dot(mapping, data[meg_picks[good_mask]]) - info_to['dev_head_t'] = recon_trans # set the reconstruction transform - evoked = epochs._evoked_from_epoch_data(data, info_to, picks, - n_events=count, kind='average', - comment=epochs._name) + info_to["dev_head_t"] = recon_trans # set the reconstruction transform + evoked = epochs._evoked_from_epoch_data( + data, info_to, picks, n_events=count, kind="average", comment=epochs._name + ) _remove_meg_projs_comps(evoked, ignore_ref) - logger.info('Created Evoked dataset from %s epochs' % (count,)) + logger.info("Created Evoked dataset from %s epochs" % (count,)) return (evoked, mapping) if return_mapping else evoked @verbose -def make_fixed_length_epochs(raw, duration=1., preload=False, - reject_by_annotation=True, proj=True, overlap=0., - id=1, verbose=None): +def make_fixed_length_epochs( + raw, + duration=1.0, + preload=False, + reject_by_annotation=True, + proj=True, + overlap=0.0, + id=1, + verbose=None, +): """Divide continuous raw data into equal-sized consecutive epochs. Parameters @@ -3829,10 +4474,17 @@ def make_fixed_length_epochs(raw, duration=1., preload=False, ----- .. versionadded:: 0.20 """ - events = make_fixed_length_events(raw, id=id, duration=duration, - overlap=overlap) - delta = 1. / raw.info['sfreq'] - return Epochs(raw, events, event_id=[id], tmin=0, tmax=duration - delta, - baseline=None, preload=preload, - reject_by_annotation=reject_by_annotation, proj=proj, - verbose=verbose) + events = make_fixed_length_events(raw, id=id, duration=duration, overlap=overlap) + delta = 1.0 / raw.info["sfreq"] + return Epochs( + raw, + events, + event_id=[id], + tmin=0, + tmax=duration - delta, + baseline=None, + preload=preload, + reject_by_annotation=reject_by_annotation, + proj=proj, + verbose=verbose, + ) diff --git a/mne/event.py b/mne/event.py index 68f943c3b49..63cb994db8a 100644 --- a/mne/event.py +++ b/mne/event.py @@ -12,9 +12,19 @@ import numpy as np -from .utils import (check_fname, logger, verbose, _get_stim_channel, warn, - _validate_type, _check_option, fill_doc, _check_fname, - _on_missing, _check_on_missing) +from .utils import ( + check_fname, + logger, + verbose, + _get_stim_channel, + warn, + _validate_type, + _check_option, + fill_doc, + _check_fname, + _on_missing, + _check_on_missing, +) from .io.constants import FIFF from .io.tree import dir_tree_find from .io.tag import read_tag @@ -75,8 +85,9 @@ def pick_events(events, include=None, exclude=None, step=False): return events -def define_target_events(events, reference_id, target_id, sfreq, tmin, tmax, - new_id=None, fill_na=None): +def define_target_events( + events, reference_id, target_id, sfreq, tmin, tmax, new_id=None, fill_na=None +): """Define new events by co-occurrence of existing events. 
This function can be used to evaluate events depending on the @@ -125,8 +136,11 @@ def define_target_events(events, reference_id, target_id, sfreq, tmin, tmax, if event[2] == reference_id: lower = event[0] + imin upper = event[0] + imax - res = events[(events[:, 0] > lower) & - (events[:, 0] < upper) & (events[:, 2] == target_id)] + res = events[ + (events[:, 0] > lower) + & (events[:, 0] < upper) + & (events[:, 2] == target_id) + ] if res.any(): lag += [event[0] - res[0][0]] event[2] = new_id @@ -138,8 +152,8 @@ def define_target_events(events, reference_id, target_id, sfreq, tmin, tmax, new_events = np.array(new_events) - with np.errstate(invalid='ignore'): # casting nans - lag = np.abs(lag, dtype='f8') + with np.errstate(invalid="ignore"): # casting nans + lag = np.abs(lag, dtype="f8") if lag.any(): lag *= tsample else: @@ -155,12 +169,12 @@ def _read_events_fif(fid, tree): if len(events) == 0: fid.close() - raise ValueError('Could not find event data') + raise ValueError("Could not find event data") events = events[0] event_list = None event_id = None - for d in events['directory']: + for d in events["directory"]: kind = d.kind pos = d.pos if kind == FIFF.FIFF_MNE_EVENT_LIST: @@ -169,21 +183,20 @@ def _read_events_fif(fid, tree): event_list.shape = (-1, 3) break if event_list is None: - raise ValueError('Could not find any events') - for d in events['directory']: + raise ValueError("Could not find any events") + for d in events["directory"]: kind = d.kind pos = d.pos if kind == FIFF.FIFF_DESCRIPTION: tag = read_tag(fid, pos) event_id = tag.data - m_ = [[s[::-1] for s in m[::-1].split(':', 1)] - for m in event_id.split(';')] + m_ = [[s[::-1] for s in m[::-1].split(":", 1)] for m in event_id.split(";")] event_id = {k: int(v) for v, k in m_} break elif kind == FIFF.FIFF_MNE_EVENT_COMMENTS: tag = read_tag(fid, pos) event_id = tag.data - event_id = event_id.tobytes().decode('latin-1').split('\x00')[:-1] + event_id = event_id.tobytes().decode("latin-1").split("\x00")[:-1] assert len(event_id) == len(event_list) event_id = {k: v[2] for k, v in zip(event_id, event_list)} break @@ -191,8 +204,15 @@ def _read_events_fif(fid, tree): @verbose -def read_events(filename, include=None, exclude=None, mask=None, - mask_type='and', return_event_id=False, verbose=None): +def read_events( + filename, + include=None, + exclude=None, + mask=None, + mask_type="and", + return_event_id=False, + verbose=None, +): """Read :term:`events` from fif or text file. See :ref:`tut-events-vs-annotations` and :ref:`tut-event-arrays` @@ -247,11 +267,22 @@ def read_events(filename, include=None, exclude=None, mask=None, For more information on ``mask`` and ``mask_type``, see :func:`mne.find_events`. 
""" - check_fname(filename, 'events', ('.eve', '-eve.fif', '-eve.fif.gz', - '-eve.lst', '-eve.txt', '_eve.fif', - '_eve.fif.gz', '_eve.lst', '_eve.txt', - '-annot.fif', # MNE-C annot - )) + check_fname( + filename, + "events", + ( + ".eve", + "-eve.fif", + "-eve.fif.gz", + "-eve.lst", + "-eve.txt", + "_eve.fif", + "_eve.fif.gz", + "_eve.lst", + "_eve.txt", + "-annot.fif", # MNE-C annot + ), + ) filename = Path(filename) if filename.suffix in (".fif", ".gz"): fid, tree, _ = fiff_open(filename) @@ -264,7 +295,7 @@ def read_events(filename, include=None, exclude=None, mask=None, # eve/lst files had a second float column that will raise errors lines = np.loadtxt(filename, dtype=np.float64).astype(int) if len(lines) == 0: - raise ValueError('No text lines found') + raise ValueError("No text lines found") if lines.ndim == 1: # Special case for only one event lines = lines[np.newaxis, :] @@ -274,13 +305,12 @@ def read_events(filename, include=None, exclude=None, mask=None, elif len(lines[0]) == 3: goods = [0, 1, 2] else: - raise ValueError('Unknown number of columns in event text file') + raise ValueError("Unknown number of columns in event text file") event_list = lines[:, goods] - if (mask is not None and event_list.shape[0] > 0 and - event_list[0, 2] == 0): + if mask is not None and event_list.shape[0] > 0 and event_list[0, 2] == 0: event_list = event_list[1:] - warn('first row of event file discarded (zero-valued)') + warn("first row of event file discarded (zero-valued)") event_id = None event_list = pick_events(event_list, include, exclude) @@ -289,12 +319,13 @@ def read_events(filename, include=None, exclude=None, mask=None, event_list = _mask_trigs(event_list, mask, mask_type) masked_len = event_list.shape[0] if masked_len < unmasked_len: - warn('{} of {} events masked'.format(unmasked_len - masked_len, - unmasked_len)) + warn( + "{} of {} events masked".format(unmasked_len - masked_len, unmasked_len) + ) out = event_list if return_event_id: if event_id is None: - raise RuntimeError('No event_id found in the file') + raise RuntimeError("No event_id found in the file") out = (out, event_id) return out @@ -321,26 +352,38 @@ def write_events(filename, events, *, overwrite=False, verbose=None): read_events """ filename = _check_fname(filename, overwrite=overwrite) - check_fname(filename, 'events', ('.eve', '-eve.fif', '-eve.fif.gz', - '-eve.lst', '-eve.txt', '_eve.fif', - '_eve.fif.gz', '_eve.lst', '_eve.txt')) - if filename.suffix in ('.fif', '.gz'): + check_fname( + filename, + "events", + ( + ".eve", + "-eve.fif", + "-eve.fif.gz", + "-eve.lst", + "-eve.txt", + "_eve.fif", + "_eve.fif.gz", + "_eve.lst", + "_eve.txt", + ), + ) + if filename.suffix in (".fif", ".gz"): # Start writing... 
with start_and_end_file(filename) as fid: start_block(fid, FIFF.FIFFB_MNE_EVENTS) write_int(fid, FIFF.FIFF_MNE_EVENT_LIST, events.T) end_block(fid, FIFF.FIFFB_MNE_EVENTS) else: - with open(filename, 'w') as f: + with open(filename, "w") as f: for e in events: - f.write('%6d %6d %3d\n' % tuple(e)) + f.write("%6d %6d %3d\n" % tuple(e)) def _find_stim_steps(data, first_samp, pad_start=None, pad_stop=None, merge=0): changed = np.diff(data, axis=1) != 0 idx = np.where(np.all(changed, axis=0))[0] if len(idx) == 0: - return np.empty((0, 3), dtype='int32') + return np.empty((0, 3), dtype="int32") pre_step = data[0, idx] idx += 1 @@ -361,7 +404,7 @@ def _find_stim_steps(data, first_samp, pad_start=None, pad_stop=None, merge=0): if merge != 0: diff = np.diff(steps[:, 0]) - idx = (diff <= abs(merge)) + idx = diff <= abs(merge) if np.any(idx): where = np.where(idx)[0] keep = np.logical_not(idx) @@ -374,15 +417,14 @@ def _find_stim_steps(data, first_samp, pad_start=None, pad_stop=None, merge=0): steps[where, 2] = steps[where + 1, 2] keep = np.insert(keep, 0, True) - is_step = (steps[:, 1] != steps[:, 2]) + is_step = steps[:, 1] != steps[:, 2] keep = np.logical_and(keep, is_step) steps = steps[keep] return steps -def find_stim_steps(raw, pad_start=None, pad_stop=None, merge=0, - stim_channel=None): +def find_stim_steps(raw, pad_start=None, pad_stop=None, merge=0, stim_channel=None): """Find all steps in data from a stim channel. Parameters @@ -422,24 +464,33 @@ def find_stim_steps(raw, pad_start=None, pad_stop=None, merge=0, # pull stim channel from config if necessary stim_channel = _get_stim_channel(stim_channel, raw.info) - picks = pick_channels( - raw.info['ch_names'], include=stim_channel, ordered=False) + picks = pick_channels(raw.info["ch_names"], include=stim_channel, ordered=False) if len(picks) == 0: - raise ValueError('No stim channel found to extract event triggers.') + raise ValueError("No stim channel found to extract event triggers.") data, _ = raw[picks, :] if np.any(data < 0): - warn('Trigger channel contains negative values, using absolute value.') + warn("Trigger channel contains negative values, using absolute value.") data = np.abs(data) # make sure trig channel is positive data = data.astype(np.int64) - return _find_stim_steps(data, raw.first_samp, pad_start=pad_start, - pad_stop=pad_stop, merge=merge) + return _find_stim_steps( + data, raw.first_samp, pad_start=pad_start, pad_stop=pad_stop, merge=merge + ) @verbose -def _find_events(data, first_samp, verbose=None, output='onset', - consecutive='increasing', min_samples=0, mask=None, - uint_cast=False, mask_type='and', initial_event=False): +def _find_events( + data, + first_samp, + verbose=None, + output="onset", + consecutive="increasing", + min_samples=0, + mask=None, + uint_cast=False, + mask_type="and", + initial_event=False, +): """Help find events.""" assert data.shape[0] == 1 # data should be only a row vector @@ -454,42 +505,46 @@ def _find_events(data, first_samp, verbose=None, output='onset', if uint_cast: data = data.astype(np.uint16).astype(np.int64) if data.min() < 0: - warn('Trigger channel contains negative values, using absolute ' - 'value. If data were acquired on a Neuromag system with ' - 'STI016 active, consider using uint_cast=True to work around ' - 'an acquisition bug') + warn( + "Trigger channel contains negative values, using absolute " + "value. 
If data were acquired on a Neuromag system with " + "STI016 active, consider using uint_cast=True to work around " + "an acquisition bug" + ) data = np.abs(data) # make sure trig channel is positive events = _find_stim_steps(data, first_samp, pad_stop=0, merge=merge) initial_value = data[0, 0] if initial_value != 0: if initial_event: - events = np.insert( - events, 0, [first_samp, 0, initial_value], axis=0) + events = np.insert(events, 0, [first_samp, 0, initial_value], axis=0) else: - logger.info('Trigger channel has a non-zero initial value of {} ' - '(consider using initial_event=True to detect this ' - 'event)'.format(initial_value)) + logger.info( + "Trigger channel has a non-zero initial value of {} " + "(consider using initial_event=True to detect this " + "event)".format(initial_value) + ) events = _mask_trigs(events, mask, mask_type) # Determine event onsets and offsets - if consecutive == 'increasing': - onsets = (events[:, 2] > events[:, 1]) - offsets = np.logical_and(np.logical_or(onsets, (events[:, 2] == 0)), - (events[:, 1] > 0)) + if consecutive == "increasing": + onsets = events[:, 2] > events[:, 1] + offsets = np.logical_and( + np.logical_or(onsets, (events[:, 2] == 0)), (events[:, 1] > 0) + ) elif consecutive: - onsets = (events[:, 2] > 0) - offsets = (events[:, 1] > 0) + onsets = events[:, 2] > 0 + offsets = events[:, 1] > 0 else: - onsets = (events[:, 1] == 0) - offsets = (events[:, 2] == 0) + onsets = events[:, 1] == 0 + offsets = events[:, 2] == 0 onset_idx = np.where(onsets)[0] offset_idx = np.where(offsets)[0] if len(onset_idx) == 0 or len(offset_idx) == 0: - return np.empty((0, 3), dtype='int32') + return np.empty((0, 3), dtype="int32") # delete orphaned onsets/offsets if onset_idx[0] > offset_idx[0]: @@ -500,12 +555,12 @@ def _find_events(data, first_samp, verbose=None, output='onset', logger.info("Removing orphaned onset at the end of the file.") onset_idx = np.delete(onset_idx, -1) - if output == 'onset': + if output == "onset": events = events[onset_idx] - elif output == 'step': + elif output == "step": idx = np.union1d(onset_idx, offset_idx) events = events[idx] - elif output == 'offset': + elif output == "offset": event_id = events[onset_idx, 2] events = events[offset_idx] events[:, 1] = events[:, 2] @@ -523,20 +578,32 @@ def _find_events(data, first_samp, verbose=None, output='onset', def _find_unique_events(events): """Uniquify events (ie remove duplicated rows.""" e = np.ascontiguousarray(events).view( - np.dtype((np.void, events.dtype.itemsize * events.shape[1]))) + np.dtype((np.void, events.dtype.itemsize * events.shape[1])) + ) _, idx = np.unique(e, return_index=True) n_dupes = len(events) - len(idx) if n_dupes > 0: - warn("Some events are duplicated in your different stim channels." - " %d events were ignored during deduplication." % n_dupes) + warn( + "Some events are duplicated in your different stim channels." + " %d events were ignored during deduplication." % n_dupes + ) return events[idx] @verbose -def find_events(raw, stim_channel=None, output='onset', - consecutive='increasing', min_duration=0, - shortest_event=2, mask=None, uint_cast=False, - mask_type='and', initial_event=False, verbose=None): +def find_events( + raw, + stim_channel=None, + output="onset", + consecutive="increasing", + min_duration=0, + shortest_event=2, + mask=None, + uint_cast=False, + mask_type="and", + initial_event=False, + verbose=None, +): """Find :term:`events` from raw file. 
See :ref:`tut-events-vs-annotations` and :ref:`tut-event-arrays` @@ -683,42 +750,53 @@ def find_events(raw, stim_channel=None, output='onset', ---------------- 2 '0000010' """ - min_samples = min_duration * raw.info['sfreq'] + min_samples = min_duration * raw.info["sfreq"] # pull stim channel from config if necessary try: stim_channel = _get_stim_channel(stim_channel, raw.info) except ValueError: if len(raw.annotations) > 0: - raise ValueError("No stim channels found, but the raw object has " - "annotations. Consider using " - "mne.events_from_annotations to convert these to " - "events.") + raise ValueError( + "No stim channels found, but the raw object has " + "annotations. Consider using " + "mne.events_from_annotations to convert these to " + "events." + ) else: raise - picks = pick_channels(raw.info['ch_names'], include=stim_channel) + picks = pick_channels(raw.info["ch_names"], include=stim_channel) if len(picks) == 0: - raise ValueError('No stim channel found to extract event triggers.') + raise ValueError("No stim channel found to extract event triggers.") data, _ = raw[picks, :] events_list = [] for d in data: - events = _find_events(d[np.newaxis, :], raw.first_samp, - verbose=verbose, output=output, - consecutive=consecutive, min_samples=min_samples, - mask=mask, uint_cast=uint_cast, - mask_type=mask_type, initial_event=initial_event) + events = _find_events( + d[np.newaxis, :], + raw.first_samp, + verbose=verbose, + output=output, + consecutive=consecutive, + min_samples=min_samples, + mask=mask, + uint_cast=uint_cast, + mask_type=mask_type, + initial_event=initial_event, + ) # add safety check for spurious events (for ex. from neuromag syst.) by # checking the number of low sample events n_short_events = np.sum(np.diff(events[:, 0]) < shortest_event) if n_short_events > 0: - raise ValueError("You have %i events shorter than the " - "shortest_event. These are very unusual and you " - "may want to set min_duration to a larger value " - "e.g. x / raw.info['sfreq']. Where x = 1 sample " - "shorter than the shortest event " - "length." % (n_short_events)) + raise ValueError( + "You have %i events shorter than the " + "shortest_event. These are very unusual and you " + "may want to set min_duration to a larger value " + "e.g. x / raw.info['sfreq']. Where x = 1 sample " + "shorter than the shortest event " + "length." 
% (n_short_events)
+            )

         events_list.append(events)

@@ -730,7 +808,7 @@ def find_events(raw, stim_channel=None, output='onset',

 def _mask_trigs(events, mask, mask_type):
     """Mask digital trigger values."""
-    _check_option('mask_type', mask_type, ['not_and', 'and'])
+    _check_option("mask_type", mask_type, ["not_and", "and"])
     if mask is not None:
         _validate_type(mask, "int", "mask", "int or None")
     n_events = len(events)
@@ -738,11 +816,13 @@
         return events.copy()

     if mask is not None:
-        if mask_type == 'not_and':
+        if mask_type == "not_and":
             mask = np.bitwise_not(mask)
-        elif mask_type != 'and':
-            raise ValueError("'mask_type' should be either 'and'"
-                             " or 'not_and', instead of '%s'" % mask_type)
+        elif mask_type != "and":
+            raise ValueError(
+                "'mask_type' should be either 'and'"
+                " or 'not_and', instead of '%s'" % mask_type
+            )
         events[:, 1:] = np.bitwise_and(events[:, 1:], mask)
         events = events[events[:, 1] != events[:, 2]]

@@ -841,8 +921,9 @@ def shift_time_events(events, ids, tshift, sfreq):

 @fill_doc
-def make_fixed_length_events(raw, id=1, start=0, stop=None, duration=1.,
-                             first_samp=True, overlap=0.):
+def make_fixed_length_events(
+    raw, id=1, start=0, stop=None, duration=1.0, first_samp=True, overlap=0.0
+):
     """Make a set of :term:`events` separated by a fixed duration.

     Parameters
@@ -875,14 +956,16 @@ def make_fixed_length_events(raw, id=1, start=0, stop=None, duration=1.,
     %(events)s
     """
     from .io.base import BaseRaw
+
     _validate_type(raw, BaseRaw, "raw")
     _validate_type(id, int, "id")
     _validate_type(duration, "numeric", "duration")
     _validate_type(overlap, "numeric", "overlap")
     duration, overlap = float(duration), float(overlap)
     if not 0 <= overlap < duration:
-        raise ValueError('overlap must be >=0 but < duration (%s), got %s'
-                         % (duration, overlap))
+        raise ValueError(
+            "overlap must be >=0 but < duration (%s), got %s" % (duration, overlap)
+        )

     start = raw.time_as_index(start, use_rounding=True)[0]
     if stop is not None:
@@ -895,16 +978,17 @@ def make_fixed_length_events(raw, id=1, start=0, stop=None, duration=1.,
     else:
         stop = min([stop, len(raw.times)])
     # Make sure we don't go out the end of the file:
-    stop -= int(np.round(raw.info['sfreq'] * duration))
+    stop -= int(np.round(raw.info["sfreq"] * duration))
     # This should be inclusive due to how we generally use start and stop...
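# Usage sketch (illustration only, not part of the patch): the start/stop
# arithmetic above, together with the arange call that follows, spaces events
# sfreq * (duration - overlap) samples apart. Synthetic placeholder data; the
# calls themselves are the public MNE API reformatted in this hunk.
import numpy as np
import mne

info = mne.create_info(ch_names=["eeg1"], sfreq=100.0, ch_types="eeg")
raw = mne.io.RawArray(np.zeros((1, 1000)), info)  # 10 s of flat data
events = mne.make_fixed_length_events(raw, id=1, duration=1.0, overlap=0.5)
assert np.all(np.diff(events[:, 0]) == 50)  # 100 Hz * (1.0 - 0.5) s = 50 samples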
- ts = np.arange(start, stop + 1, - raw.info['sfreq'] * (duration - overlap)).astype(int) + ts = np.arange(start, stop + 1, raw.info["sfreq"] * (duration - overlap)).astype( + int + ) n_events = len(ts) if n_events == 0: - raise ValueError('No events produced, check the values of start, ' - 'stop, and duration') - events = np.c_[ts, np.zeros(n_events, dtype=int), - id * np.ones(n_events, dtype=int)] + raise ValueError( + "No events produced, check the values of start, " "stop, and duration" + ) + events = np.c_[ts, np.zeros(n_events, dtype=int), id * np.ones(n_events, dtype=int)] return events @@ -935,10 +1019,10 @@ def concatenate_events(events, first_samps, last_samps): mne.concatenate_raws """ _validate_type(events, list, "events") - if not (len(events) == len(last_samps) and - len(events) == len(first_samps)): - raise ValueError('events, first_samps, and last_samps must all have ' - 'the same lengths') + if not (len(events) == len(last_samps) and len(events) == len(first_samps)): + raise ValueError( + "events, first_samps, and last_samps must all have " "the same lengths" + ) first_samps = np.array(first_samps) last_samps = np.array(last_samps) n_samps = np.cumsum(last_samps - first_samps + 1) @@ -994,85 +1078,125 @@ class AcqParserFIF: """ # DACQ variables always start with one of these - _acq_var_magic = ['ERF', 'DEF', 'ACQ', 'TCP'] + _acq_var_magic = ["ERF", "DEF", "ACQ", "TCP"] # averager related DACQ variable names (without preceding 'ERF') # old versions (DACQ < 3.4) - _dacq_vars_compat = ('megMax', 'megMin', 'megNoise', 'megSlope', - 'megSpike', 'eegMax', 'eegMin', 'eegNoise', - 'eegSlope', 'eegSpike', 'eogMax', 'ecgMax', 'ncateg', - 'nevent', 'stimSource', 'triggerMap', 'update', - 'artefIgnore', 'averUpdate') - - _event_vars_compat = ('Comment', 'Delay') - - _cat_vars = ('Comment', 'Display', 'Start', 'State', 'End', 'Event', - 'Nave', 'ReqEvent', 'ReqWhen', 'ReqWithin', 'SubAve') + _dacq_vars_compat = ( + "megMax", + "megMin", + "megNoise", + "megSlope", + "megSpike", + "eegMax", + "eegMin", + "eegNoise", + "eegSlope", + "eegSpike", + "eogMax", + "ecgMax", + "ncateg", + "nevent", + "stimSource", + "triggerMap", + "update", + "artefIgnore", + "averUpdate", + ) + + _event_vars_compat = ("Comment", "Delay") + + _cat_vars = ( + "Comment", + "Display", + "Start", + "State", + "End", + "Event", + "Nave", + "ReqEvent", + "ReqWhen", + "ReqWithin", + "SubAve", + ) # new versions only (DACQ >= 3.4) - _dacq_vars = _dacq_vars_compat + ('magMax', 'magMin', 'magNoise', - 'magSlope', 'magSpike', 'version') - - _event_vars = _event_vars_compat + ('Name', 'Channel', 'NewBits', - 'OldBits', 'NewMask', 'OldMask') + _dacq_vars = _dacq_vars_compat + ( + "magMax", + "magMin", + "magNoise", + "magSlope", + "magSpike", + "version", + ) + + _event_vars = _event_vars_compat + ( + "Name", + "Channel", + "NewBits", + "OldBits", + "NewMask", + "OldMask", + ) def __init__(self, info): # noqa: D102 - acq_pars = info['acq_pars'] + acq_pars = info["acq_pars"] if not acq_pars: - raise ValueError('No acquisition parameters') + raise ValueError("No acquisition parameters") self.acq_dict = dict(self._acqpars_gen(acq_pars)) - if 'ERFversion' in self.acq_dict: + if "ERFversion" in self.acq_dict: self.compat = False # DACQ ver >= 3.4 - elif 'ERFncateg' in self.acq_dict: # probably DACQ < 3.4 + elif "ERFncateg" in self.acq_dict: # probably DACQ < 3.4 self.compat = True else: - raise ValueError('Cannot parse acquisition parameters') + raise ValueError("Cannot parse acquisition parameters") dacq_vars = 
self._dacq_vars_compat if self.compat else self._dacq_vars # set instance variables for var in dacq_vars: - val = self.acq_dict['ERF' + var] - if var[:3] in ['mag', 'meg', 'eeg', 'eog', 'ecg']: + val = self.acq_dict["ERF" + var] + if var[:3] in ["mag", "meg", "eeg", "eog", "ecg"]: val = float(val) - elif var in ['ncateg', 'nevent']: + elif var in ["ncateg", "nevent"]: val = int(val) setattr(self, var.lower(), val) - self.stimsource = ( - 'Internal' if self.stimsource == '1' else 'External') + self.stimsource = "Internal" if self.stimsource == "1" else "External" # collect all events and categories self._events = self._events_from_acq_pars() self._categories = self._categories_from_acq_pars() # mark events that are used by a category for cat in self._categories.values(): - if cat['event']: - self._events[cat['event']]['in_use'] = True - if cat['reqevent']: - self._events[cat['reqevent']]['in_use'] = True + if cat["event"]: + self._events[cat["event"]]["in_use"] = True + if cat["reqevent"]: + self._events[cat["reqevent"]]["in_use"] = True # make mne rejection dicts based on the averager parameters - self.reject = {'grad': self.megmax, 'eeg': self.eegmax, - 'eog': self.eogmax, 'ecg': self.ecgmax} + self.reject = { + "grad": self.megmax, + "eeg": self.eegmax, + "eog": self.eogmax, + "ecg": self.ecgmax, + } if not self.compat: - self.reject['mag'] = self.magmax - self.reject = {k: float(v) for k, v in self.reject.items() - if float(v) > 0} - self.flat = {'grad': self.megmin, 'eeg': self.eegmin} + self.reject["mag"] = self.magmax + self.reject = {k: float(v) for k, v in self.reject.items() if float(v) > 0} + self.flat = {"grad": self.megmin, "eeg": self.eegmin} if not self.compat: - self.flat['mag'] = self.magmin - self.flat = {k: float(v) for k, v in self.flat.items() - if float(v) > 0} + self.flat["mag"] = self.magmin + self.flat = {k: float(v) for k, v in self.flat.items() if float(v) > 0} def __repr__(self): # noqa: D105 - s = ' bits for old DACQ versions - _compat_event_lookup = {1: 1, 2: 2, 3: 4, 4: 8, 5: 16, 6: 32, 7: 3, - 8: 5, 9: 6, 10: 7, 11: 9, 12: 10, 13: 11, - 14: 12, 15: 13, 16: 14, 17: 15} + _compat_event_lookup = { + 1: 1, + 2: 2, + 3: 4, + 4: 8, + 5: 16, + 6: 32, + 7: 3, + 8: 5, + 9: 6, + 10: 7, + 11: 9, + 12: 10, + 13: 11, + 14: 12, + 15: 13, + 16: 14, + 17: 15, + } events = dict() for evnum in range(1, self.nevent + 1): evnum_s = str(evnum).zfill(2) # '01', '02' etc. evdi = dict() - event_vars = (self._event_vars_compat if self.compat - else self._event_vars) + event_vars = self._event_vars_compat if self.compat else self._event_vars for var in event_vars: # name of DACQ variable, e.g. 'ERFeventNewBits01' - acq_key = 'ERFevent' + var + evnum_s + acq_key = "ERFevent" + var + evnum_s # corresponding dict key, e.g. 
'newbits' dict_key = var.lower() val = self.acq_dict[acq_key] # type convert numeric values - if dict_key in ['newbits', 'oldbits', 'newmask', 'oldmask']: + if dict_key in ["newbits", "oldbits", "newmask", "oldmask"]: val = int(val) - elif dict_key in ['delay']: + elif dict_key in ["delay"]: val = float(val) evdi[dict_key] = val - evdi['in_use'] = False # __init__() will set this - evdi['index'] = evnum + evdi["in_use"] = False # __init__() will set this + evdi["index"] = evnum if self.compat: - evdi['name'] = str(evnum) - evdi['oldmask'] = 63 - evdi['newmask'] = 63 - evdi['oldbits'] = 0 - evdi['newbits'] = _compat_event_lookup[evnum] + evdi["name"] = str(evnum) + evdi["oldmask"] = 63 + evdi["newmask"] = 63 + evdi["oldbits"] = 0 + evdi["newbits"] = _compat_event_lookup[evnum] events[evnum] = evdi return events def _acqpars_gen(self, acq_pars): """Yield key/value pairs from ``info['acq_pars'])``.""" - key, val = '', '' + key, val = "", "" for line in acq_pars.split(): if any([line.startswith(x) for x in self._acq_var_magic]): key = line - val = '' + val = "" else: if not key: - raise ValueError('Cannot parse acquisition parameters') + raise ValueError("Cannot parse acquisition parameters") # DACQ splits items with spaces into multiple lines - val += ' ' + line if val else line + val += " " + line if val else line yield key, val def _categories_from_acq_pars(self): @@ -1210,20 +1349,20 @@ def _categories_from_acq_pars(self): catdi = dict() # read all category variables for var in self._cat_vars: - acq_key = 'ERFcat' + var + catnum + acq_key = "ERFcat" + var + catnum class_key = var.lower() val = self.acq_dict[acq_key] catdi[class_key] = val # some type conversions - catdi['display'] = (catdi['display'] == '1') - catdi['state'] = (catdi['state'] == '1') - for key in ['start', 'end', 'reqwithin']: + catdi["display"] = catdi["display"] == "1" + catdi["state"] = catdi["state"] == "1" + for key in ["start", "end", "reqwithin"]: catdi[key] = float(catdi[key]) - for key in ['nave', 'event', 'reqevent', 'reqwhen', 'subave']: + for key in ["nave", "event", "reqevent", "reqwhen", "subave"]: catdi[key] = int(catdi[key]) # some convenient extra (non-DACQ) vars - catdi['index'] = int(catnum) # index of category in DACQ list - cats[catdi['comment']] = catdi + catdi["index"] = int(catnum) # index of category in DACQ list + cats[catdi["comment"]] = catdi return cats def _events_mne_to_dacq(self, mne_events): @@ -1239,13 +1378,13 @@ def _events_mne_to_dacq(self, mne_events): events_ = mne_events.copy() events_[:, 1:3] = 0 for n, ev in self._events.items(): - if ev['in_use']: + if ev["in_use"]: pre_ok = ( - np.bitwise_and(ev['oldmask'], - mne_events[:, 1]) == ev['oldbits']) + np.bitwise_and(ev["oldmask"], mne_events[:, 1]) == ev["oldbits"] + ) post_ok = ( - np.bitwise_and(ev['newmask'], - mne_events[:, 2]) == ev['newbits']) + np.bitwise_and(ev["newmask"], mne_events[:, 2]) == ev["newbits"] + ) ok_ind = np.where(pre_ok & post_ok) events_[ok_ind, 2] |= 1 << (n - 1) return events_ @@ -1257,8 +1396,8 @@ def _mne_events_to_category_t0(self, cat, mne_events, sfreq): Then the zero times for the epochs are obtained by considering the reference and conditional (required) events and the delay to stimulus. """ - cat_ev = cat['event'] - cat_reqev = cat['reqevent'] + cat_ev = cat["event"] + cat_reqev = cat["reqevent"] # first convert mne events to dacq event list events = self._events_mne_to_dacq(mne_events) # next, take req. 
events and delays into account @@ -1268,25 +1407,25 @@ def _mne_events_to_category_t0(self, cat, mne_events, sfreq): refEvents_t = times[refEvents_inds] if cat_reqev: # indices of times where req. event occurs - reqEvents_inds = np.where(events[:, 2] & ( - 1 << cat_reqev - 1))[0] + reqEvents_inds = np.where(events[:, 2] & (1 << cat_reqev - 1))[0] reqEvents_t = times[reqEvents_inds] # relative (to refevent) time window where req. event # must occur (e.g. [0 .2]) - twin = [0, (-1)**(cat['reqwhen']) * cat['reqwithin']] + twin = [0, (-1) ** (cat["reqwhen"]) * cat["reqwithin"]] win = np.round(np.array(sorted(twin)) * sfreq) # to samples refEvents_wins = refEvents_t[:, None] + win req_acc = np.zeros(refEvents_inds.shape, dtype=bool) for t in reqEvents_t: # mark time windows where req. condition is satisfied reqEvent_in_win = np.logical_and( - t >= refEvents_wins[:, 0], t <= refEvents_wins[:, 1]) + t >= refEvents_wins[:, 0], t <= refEvents_wins[:, 1] + ) req_acc |= reqEvent_in_win # drop ref. events where req. event condition is not satisfied refEvents_inds = refEvents_inds[np.where(req_acc)] refEvents_t = times[refEvents_inds] # adjust for trigger-stimulus delay by delaying the ref. event - refEvents_t += int(np.round(self._events[cat_ev]['delay'] * sfreq)) + refEvents_t += int(np.round(self._events[cat_ev]["delay"] * sfreq)) return refEvents_t @property @@ -1295,8 +1434,7 @@ def categories(self): Only returns categories marked active in DACQ. """ - cats = sorted(self._categories_in_use.values(), - key=lambda cat: cat['index']) + cats = sorted(self._categories_in_use.values(), key=lambda cat: cat["index"]) return cats @property @@ -1305,19 +1443,27 @@ def events(self): Only returns events that are in use (referred to by a category). """ - evs = sorted(self._events_in_use.values(), key=lambda ev: ev['index']) + evs = sorted(self._events_in_use.values(), key=lambda ev: ev["index"]) return evs @property def _categories_in_use(self): - return {k: v for k, v in self._categories.items() if v['state']} + return {k: v for k, v in self._categories.items() if v["state"]} @property def _events_in_use(self): - return {k: v for k, v in self._events.items() if v['in_use']} - - def get_condition(self, raw, condition=None, stim_channel=None, mask=None, - uint_cast=None, mask_type='and', delayed_lookup=True): + return {k: v for k, v in self._events.items() if v["in_use"]} + + def get_condition( + self, + raw, + condition=None, + stim_channel=None, + mask=None, + uint_cast=None, + mask_type="and", + delayed_lookup=True, + ): """Get averaging parameters for a condition (averaging category). Output is designed to be used with the Epochs class to extract the @@ -1389,35 +1535,45 @@ def get_condition(self, raw, condition=None, stim_channel=None, mask=None, for cat in condition: if isinstance(cat, str): cat = self[cat] - mne_events = find_events(raw, stim_channel=stim_channel, mask=mask, - mask_type=mask_type, output='step', - uint_cast=uint_cast, consecutive=True, - verbose=False, shortest_event=1) + mne_events = find_events( + raw, + stim_channel=stim_channel, + mask=mask, + mask_type=mask_type, + output="step", + uint_cast=uint_cast, + consecutive=True, + verbose=False, + shortest_event=1, + ) if delayed_lookup: ind = np.where(np.diff(mne_events[:, 0]) == 1)[0] if 1 in np.diff(ind): - raise ValueError('There are several subsequent ' - 'transitions on the trigger channel. ' - 'This will not work well with ' - 'delayed_lookup=True. 
You may want to ' - 'check your trigger data and ' - 'set delayed_lookup=False.') + raise ValueError( + "There are several subsequent " + "transitions on the trigger channel. " + "This will not work well with " + "delayed_lookup=True. You may want to " + "check your trigger data and " + "set delayed_lookup=False." + ) mne_events[ind, 2] = mne_events[ind + 1, 2] mne_events = np.delete(mne_events, ind + 1, axis=0) - sfreq = raw.info['sfreq'] + sfreq = raw.info["sfreq"] cat_t0_ = self._mne_events_to_category_t0(cat, mne_events, sfreq) # make it compatible with the usual events array - cat_t0 = np.c_[cat_t0_, np.zeros(cat_t0_.shape), - cat['index'] * np.ones(cat_t0_.shape) - ].astype(np.uint32) - cat_id = {cat['comment']: cat['index']} - tmin, tmax = cat['start'], cat['end'] - conds_data.append(dict(events=cat_t0, event_id=cat_id, - tmin=tmin, tmax=tmax)) + cat_t0 = np.c_[ + cat_t0_, np.zeros(cat_t0_.shape), cat["index"] * np.ones(cat_t0_.shape) + ].astype(np.uint32) + cat_id = {cat["comment"]: cat["index"]} + tmin, tmax = cat["start"], cat["end"] + conds_data.append( + dict(events=cat_t0, event_id=cat_id, tmin=tmin, tmax=tmax) + ) return conds_data[0] if len(conds_data) == 1 else conds_data -def match_event_names(event_names, keys, *, on_missing='raise'): +def match_event_names(event_names, keys, *, on_missing="raise"): """Search a collection of event names for matching (sub-)groups of events. This function is particularly helpful when using grouped event names @@ -1468,10 +1624,7 @@ def match_event_names(event_names, keys, *, on_missing='raise'): event_names = list(event_names) # ensure we have a list of `keys` - if ( - isinstance(keys, (Sequence, np.ndarray)) and - not isinstance(keys, str) - ): + if isinstance(keys, (Sequence, np.ndarray)) and not isinstance(keys, str): keys = list(keys) else: keys = [keys] @@ -1481,19 +1634,20 @@ def match_event_names(event_names, keys, *, on_missing='raise'): # form the hierarchical event name mapping for key in keys: if not isinstance(key, str): - raise ValueError(f'keys must be strings, got {type(key)} ({key})') + raise ValueError(f"keys must be strings, got {type(key)} ({key})") matches.extend( - name for name in event_names - if set(key.split('/')).issubset(name.split('/')) + name + for name in event_names + if set(key.split("/")).issubset(name.split("/")) ) if not matches: _on_missing( on_missing=on_missing, msg=f'Event name "{key}" could not be found. 
The following events ' - f'are present in the data: {", ".join(event_names)}', - error_klass=KeyError + f'are present in the data: {", ".join(event_names)}', + error_klass=KeyError, ) matches = sorted(set(matches)) # deduplicate if necessary diff --git a/mne/evoked.py b/mne/evoked.py index 29c6e5ca1cb..e41c8c10cbb 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -12,20 +12,39 @@ import numpy as np from .baseline import rescale, _log_rescale, _check_baseline -from .channels.channels import (UpdateChannelsMixin, - SetChannelsMixin, InterpolationMixin) +from .channels.channels import UpdateChannelsMixin, SetChannelsMixin, InterpolationMixin from .channels.layout import _merge_ch_data, _pair_grad_sensors -from .defaults import (_INTERPOLATION_DEFAULT, _EXTRAPOLATE_DEFAULT, - _BORDER_DEFAULT) +from .defaults import _INTERPOLATION_DEFAULT, _EXTRAPOLATE_DEFAULT, _BORDER_DEFAULT from .filter import detrend, FilterMixin, _check_fun -from .utils import (check_fname, logger, verbose, warn, sizeof_fmt, repr_html, - SizeMixin, copy_function_doc_to_method_doc, _validate_type, - fill_doc, _check_option, _build_data_frame, - _check_pandas_installed, _check_pandas_index_arguments, - _convert_times, _scale_dataframe_data, _check_time_format, - _check_preload, _check_fname, TimeMixin) -from .viz import (plot_evoked, plot_evoked_topomap, plot_evoked_field, - plot_evoked_image, plot_evoked_topo) +from .utils import ( + check_fname, + logger, + verbose, + warn, + sizeof_fmt, + repr_html, + SizeMixin, + copy_function_doc_to_method_doc, + _validate_type, + fill_doc, + _check_option, + _build_data_frame, + _check_pandas_installed, + _check_pandas_index_arguments, + _convert_times, + _scale_dataframe_data, + _check_time_format, + _check_preload, + _check_fname, + TimeMixin, +) +from .viz import ( + plot_evoked, + plot_evoked_topomap, + plot_evoked_field, + plot_evoked_image, + plot_evoked_topo, +) from .viz.evoked import plot_evoked_white, plot_evoked_joint from .viz.topomap import _topomap_animation @@ -34,37 +53,58 @@ from .io.tag import read_tag from .io.tree import dir_tree_find from .io.pick import pick_types, _picks_to_idx, _FNIRS_CH_TYPES_SPLIT -from .io.meas_info import (ContainsMixin, read_meas_info, write_meas_info, - _read_extended_ch_info, _rename_list, - _ensure_infos_match) +from .io.meas_info import ( + ContainsMixin, + read_meas_info, + write_meas_info, + _read_extended_ch_info, + _rename_list, + _ensure_infos_match, +) from .io.proj import ProjMixin -from .io.write import (start_and_end_file, start_block, end_block, - write_int, write_string, write_float_matrix, - write_id, write_float, write_complex_float_matrix) +from .io.write import ( + start_and_end_file, + start_block, + end_block, + write_int, + write_string, + write_float_matrix, + write_id, + write_float, + write_complex_float_matrix, +) from .io.base import _check_maxshield, _get_ch_factors from .parallel import parallel_func from .time_frequency.spectrum import Spectrum, SpectrumMixin, _validate_method _aspect_dict = { - 'average': FIFF.FIFFV_ASPECT_AVERAGE, - 'standard_error': FIFF.FIFFV_ASPECT_STD_ERR, - 'single_epoch': FIFF.FIFFV_ASPECT_SINGLE, - 'partial_average': FIFF.FIFFV_ASPECT_SUBAVERAGE, - 'alternating_subaverage': FIFF.FIFFV_ASPECT_ALTAVERAGE, - 'sample_cut_out_by_graph': FIFF.FIFFV_ASPECT_SAMPLE, - 'power_density_spectrum': FIFF.FIFFV_ASPECT_POWER_DENSITY, - 'dipole_amplitude_cuvre': FIFF.FIFFV_ASPECT_DIPOLE_WAVE, - 'squid_modulation_lower_bound': FIFF.FIFFV_ASPECT_IFII_LOW, - 'squid_modulation_upper_bound': 
FIFF.FIFFV_ASPECT_IFII_HIGH, - 'squid_gate_setting': FIFF.FIFFV_ASPECT_GATE, + "average": FIFF.FIFFV_ASPECT_AVERAGE, + "standard_error": FIFF.FIFFV_ASPECT_STD_ERR, + "single_epoch": FIFF.FIFFV_ASPECT_SINGLE, + "partial_average": FIFF.FIFFV_ASPECT_SUBAVERAGE, + "alternating_subaverage": FIFF.FIFFV_ASPECT_ALTAVERAGE, + "sample_cut_out_by_graph": FIFF.FIFFV_ASPECT_SAMPLE, + "power_density_spectrum": FIFF.FIFFV_ASPECT_POWER_DENSITY, + "dipole_amplitude_cuvre": FIFF.FIFFV_ASPECT_DIPOLE_WAVE, + "squid_modulation_lower_bound": FIFF.FIFFV_ASPECT_IFII_LOW, + "squid_modulation_upper_bound": FIFF.FIFFV_ASPECT_IFII_HIGH, + "squid_gate_setting": FIFF.FIFFV_ASPECT_GATE, } _aspect_rev = {val: key for key, val in _aspect_dict.items()} @fill_doc -class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin, SetChannelsMixin, - InterpolationMixin, FilterMixin, TimeMixin, SizeMixin, - SpectrumMixin): +class Evoked( + ProjMixin, + ContainsMixin, + UpdateChannelsMixin, + SetChannelsMixin, + InterpolationMixin, + FilterMixin, + TimeMixin, + SizeMixin, + SpectrumMixin, +): """Evoked data. Parameters @@ -123,17 +163,28 @@ class Evoked(ProjMixin, ContainsMixin, UpdateChannelsMixin, SetChannelsMixin, """ @verbose - def __init__(self, fname, condition=None, proj=True, - kind='average', allow_maxshield=False, *, - verbose=None): # noqa: D102 + def __init__( + self, + fname, + condition=None, + proj=True, + kind="average", + allow_maxshield=False, + *, + verbose=None, + ): # noqa: D102 _validate_type(proj, bool, "'proj'") # Read the requested data - fname = str( - _check_fname(fname=fname, must_exist=True, overwrite="read") - ) - self.info, self.nave, self._aspect_kind, self.comment, times, \ - self.data, self.baseline = _read_evoked(fname, condition, kind, - allow_maxshield) + fname = str(_check_fname(fname=fname, must_exist=True, overwrite="read")) + ( + self.info, + self.nave, + self._aspect_kind, + self.comment, + times, + self.data, + self.baseline, + ) = _read_evoked(fname, condition, kind, allow_maxshield) self._set_times(times) self._raw_times = self.times.copy() self._decim = 1 @@ -152,7 +203,7 @@ def kind(self): @kind.setter def kind(self, kind): - _check_option('kind', kind, list(_aspect_dict.keys())) + _check_option("kind", kind, list(_aspect_dict.keys())) self._aspect_kind = _aspect_dict[kind] @property @@ -200,8 +251,9 @@ def get_data(self, picks=None, units=None, tmin=None, tmax=None): return data @verbose - def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, - verbose=None, **kwargs): + def apply_function( + self, fun, picks=None, dtype=None, n_jobs=None, verbose=None, **kwargs + ): """Apply a function to a subset of channels. %(applyfun_summary_evoked)s @@ -221,18 +273,18 @@ def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, self : instance of Evoked The evoked object with transformed data. 
""" - _check_preload(self, 'evoked.apply_function') + _check_preload(self, "evoked.apply_function") picks = _picks_to_idx(self.info, picks, exclude=(), with_ref_meg=False) if not callable(fun): - raise ValueError('fun needs to be a function') + raise ValueError("fun needs to be a function") data_in = self._data if dtype is not None and dtype != self._data.dtype: self._data = self._data.astype(dtype) # check the dimension of the incoming evoked data - _check_option('evoked.ndim', self._data.ndim, [2]) + _check_option("evoked.ndim", self._data.ndim, [2]) parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs) if n_jobs == 1: @@ -241,8 +293,9 @@ def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, self._data[idx, :] = _check_fun(fun, data_in[idx, :], **kwargs) else: # use parallel function - data_picks_new = parallel(p_fun( - fun, data_in[p, :], **kwargs) for p in picks) + data_picks_new = parallel( + p_fun(fun, data_in[p, :], **kwargs) for p in picks + ) for pp, p in enumerate(picks): self._data[p, :] = data_picks_new[pp] @@ -270,11 +323,12 @@ def apply_baseline(self, baseline=(None, 0), *, verbose=None): .. versionadded:: 0.13.0 """ - baseline = _check_baseline(baseline, times=self.times, - sfreq=self.info['sfreq']) + baseline = _check_baseline(baseline, times=self.times, sfreq=self.info["sfreq"]) if self.baseline is not None and baseline is None: - raise ValueError('The data has already been baseline-corrected. ' - 'Cannot remove existing baseline correction.') + raise ValueError( + "The data has already been baseline-corrected. " + "Cannot remove existing baseline correction." + ) elif baseline is None: # Do not rescale logger.info(_log_rescale(None)) @@ -309,7 +363,7 @@ def save(self, fname, *, overwrite=False, verbose=None): write_evokeds(fname, self, overwrite=overwrite) @verbose - def export(self, fname, fmt='auto', *, overwrite=False, verbose=None): + def export(self, fname, fmt="auto", *, overwrite=False, verbose=None): """Export Evoked to external formats. 
%(export_fmt_support_evoked)s @@ -330,6 +384,7 @@ def export(self, fname, fmt='auto', *, overwrite=False, verbose=None): %(export_warning_note_evoked)s """ from .export import export_evokeds + export_evokeds(fname, self, fmt, overwrite=overwrite, verbose=verbose) def __repr__(self): # noqa: D105 @@ -341,15 +396,18 @@ def __repr__(self): # noqa: D105 comment = self.comment s = "'%s' (%s, N=%s)" % (comment, self.kind, self.nave) s += ", %0.5g – %0.5g s" % (self.times[0], self.times[-1]) - s += ', baseline ' + s += ", baseline " if self.baseline is None: - s += 'off' + s += "off" else: - s += f'{self.baseline[0]:g} – {self.baseline[1]:g} s' + s += f"{self.baseline[0]:g} – {self.baseline[1]:g} s" if self.baseline != _check_baseline( - self.baseline, times=self.times, sfreq=self.info['sfreq'], - on_baseline_outside_data='adjust'): - s += ' (baseline period was cropped after baseline correction)' + self.baseline, + times=self.times, + sfreq=self.info["sfreq"], + on_baseline_outside_data="adjust", + ): + s += " (baseline period was cropped after baseline correction)" s += ", %s ch" % self.data.shape[0] s += ", ~%s" % (sizeof_fmt(self._size),) return "" % s @@ -357,122 +415,328 @@ def __repr__(self): # noqa: D105 @repr_html def _repr_html_(self): from .html_templates import repr_templates_env + if self.baseline is None: - baseline = 'off' + baseline = "off" else: - baseline = tuple([f'{b:.3f}' for b in self.baseline]) - baseline = f'{baseline[0]} – {baseline[1]} s' + baseline = tuple([f"{b:.3f}" for b in self.baseline]) + baseline = f"{baseline[0]} – {baseline[1]} s" - t = repr_templates_env.get_template('evoked.html.jinja') + t = repr_templates_env.get_template("evoked.html.jinja") t = t.render(evoked=self, baseline=baseline) return t @property def ch_names(self): """Channel names.""" - return self.info['ch_names'] + return self.info["ch_names"] @copy_function_doc_to_method_doc(plot_evoked) - def plot(self, picks=None, exclude='bads', unit=True, show=True, ylim=None, - xlim='tight', proj=False, hline=None, units=None, scalings=None, - titles=None, axes=None, gfp=False, window_title=None, - spatial_colors='auto', zorder='unsorted', selectable=True, - noise_cov=None, time_unit='s', sphere=None, *, highlight=None, - verbose=None): + def plot( + self, + picks=None, + exclude="bads", + unit=True, + show=True, + ylim=None, + xlim="tight", + proj=False, + hline=None, + units=None, + scalings=None, + titles=None, + axes=None, + gfp=False, + window_title=None, + spatial_colors="auto", + zorder="unsorted", + selectable=True, + noise_cov=None, + time_unit="s", + sphere=None, + *, + highlight=None, + verbose=None, + ): return plot_evoked( - self, picks=picks, exclude=exclude, unit=unit, show=show, - ylim=ylim, proj=proj, xlim=xlim, hline=hline, units=units, - scalings=scalings, titles=titles, axes=axes, gfp=gfp, - window_title=window_title, spatial_colors=spatial_colors, - zorder=zorder, selectable=selectable, noise_cov=noise_cov, - time_unit=time_unit, sphere=sphere, highlight=highlight, - verbose=verbose) + self, + picks=picks, + exclude=exclude, + unit=unit, + show=show, + ylim=ylim, + proj=proj, + xlim=xlim, + hline=hline, + units=units, + scalings=scalings, + titles=titles, + axes=axes, + gfp=gfp, + window_title=window_title, + spatial_colors=spatial_colors, + zorder=zorder, + selectable=selectable, + noise_cov=noise_cov, + time_unit=time_unit, + sphere=sphere, + highlight=highlight, + verbose=verbose, + ) @copy_function_doc_to_method_doc(plot_evoked_image) - def plot_image(self, picks=None, 
exclude='bads', unit=True, show=True, - clim=None, xlim='tight', proj=False, units=None, - scalings=None, titles=None, axes=None, cmap='RdBu_r', - colorbar=True, mask=None, mask_style=None, - mask_cmap='Greys', mask_alpha=.25, time_unit='s', - show_names=None, group_by=None, sphere=None): + def plot_image( + self, + picks=None, + exclude="bads", + unit=True, + show=True, + clim=None, + xlim="tight", + proj=False, + units=None, + scalings=None, + titles=None, + axes=None, + cmap="RdBu_r", + colorbar=True, + mask=None, + mask_style=None, + mask_cmap="Greys", + mask_alpha=0.25, + time_unit="s", + show_names=None, + group_by=None, + sphere=None, + ): return plot_evoked_image( - self, picks=picks, exclude=exclude, unit=unit, show=show, - clim=clim, xlim=xlim, proj=proj, units=units, scalings=scalings, - titles=titles, axes=axes, cmap=cmap, colorbar=colorbar, mask=mask, - mask_style=mask_style, mask_cmap=mask_cmap, mask_alpha=mask_alpha, - time_unit=time_unit, show_names=show_names, group_by=group_by, - sphere=sphere) + self, + picks=picks, + exclude=exclude, + unit=unit, + show=show, + clim=clim, + xlim=xlim, + proj=proj, + units=units, + scalings=scalings, + titles=titles, + axes=axes, + cmap=cmap, + colorbar=colorbar, + mask=mask, + mask_style=mask_style, + mask_cmap=mask_cmap, + mask_alpha=mask_alpha, + time_unit=time_unit, + show_names=show_names, + group_by=group_by, + sphere=sphere, + ) @copy_function_doc_to_method_doc(plot_evoked_topo) - def plot_topo(self, layout=None, layout_scale=0.945, color=None, - border='none', ylim=None, scalings=None, title=None, - proj=False, vline=[0.0], fig_background=None, - merge_grads=False, legend=True, axes=None, - background_color='w', noise_cov=None, exclude='bads', - show=True): + def plot_topo( + self, + layout=None, + layout_scale=0.945, + color=None, + border="none", + ylim=None, + scalings=None, + title=None, + proj=False, + vline=[0.0], + fig_background=None, + merge_grads=False, + legend=True, + axes=None, + background_color="w", + noise_cov=None, + exclude="bads", + show=True, + ): """ Notes ----- .. 
versionadded:: 0.10.0 """ return plot_evoked_topo( - self, layout=layout, layout_scale=layout_scale, - color=color, border=border, ylim=ylim, scalings=scalings, - title=title, proj=proj, vline=vline, fig_background=fig_background, - merge_grads=merge_grads, legend=legend, axes=axes, - background_color=background_color, noise_cov=noise_cov, - exclude=exclude, show=show) + self, + layout=layout, + layout_scale=layout_scale, + color=color, + border=border, + ylim=ylim, + scalings=scalings, + title=title, + proj=proj, + vline=vline, + fig_background=fig_background, + merge_grads=merge_grads, + legend=legend, + axes=axes, + background_color=background_color, + noise_cov=noise_cov, + exclude=exclude, + show=show, + ) @copy_function_doc_to_method_doc(plot_evoked_topomap) def plot_topomap( - self, times="auto", *, average=None, ch_type=None, scalings=None, - proj=False, sensors=True, show_names=False, mask=None, - mask_params=None, contours=6, outlines='head', sphere=None, - image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, - size=1, cmap=None, vlim=(None, None), cnorm=None, colorbar=True, - cbar_fmt='%3.1f', units=None, axes=None, time_unit='s', - time_format=None, nrows=1, ncols='auto', show=True): + self, + times="auto", + *, + average=None, + ch_type=None, + scalings=None, + proj=False, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap=None, + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + time_unit="s", + time_format=None, + nrows=1, + ncols="auto", + show=True, + ): return plot_evoked_topomap( - self, times=times, ch_type=ch_type, vlim=vlim, cmap=cmap, - cnorm=cnorm, sensors=sensors, colorbar=colorbar, scalings=scalings, - units=units, res=res, size=size, cbar_fmt=cbar_fmt, - time_unit=time_unit, time_format=time_format, proj=proj, show=show, - show_names=show_names, mask=mask, mask_params=mask_params, - outlines=outlines, contours=contours, image_interp=image_interp, - average=average, axes=axes, extrapolate=extrapolate, sphere=sphere, - border=border, nrows=nrows, ncols=ncols) + self, + times=times, + ch_type=ch_type, + vlim=vlim, + cmap=cmap, + cnorm=cnorm, + sensors=sensors, + colorbar=colorbar, + scalings=scalings, + units=units, + res=res, + size=size, + cbar_fmt=cbar_fmt, + time_unit=time_unit, + time_format=time_format, + proj=proj, + show=show, + show_names=show_names, + mask=mask, + mask_params=mask_params, + outlines=outlines, + contours=contours, + image_interp=image_interp, + average=average, + axes=axes, + extrapolate=extrapolate, + sphere=sphere, + border=border, + nrows=nrows, + ncols=ncols, + ) @copy_function_doc_to_method_doc(plot_evoked_field) - def plot_field(self, surf_maps, time=None, time_label='t = %0.0f ms', - n_jobs=None, fig=None, vmax=None, n_contours=21, - *, interaction='terrain', verbose=None): - return plot_evoked_field(self, surf_maps, time=time, - time_label=time_label, n_jobs=n_jobs, - fig=fig, vmax=vmax, n_contours=n_contours, - interaction=interaction, verbose=verbose) + def plot_field( + self, + surf_maps, + time=None, + time_label="t = %0.0f ms", + n_jobs=None, + fig=None, + vmax=None, + n_contours=21, + *, + interaction="terrain", + verbose=None, + ): + return plot_evoked_field( + self, + surf_maps, + time=time, + time_label=time_label, + n_jobs=n_jobs, + 
fig=fig, + vmax=vmax, + n_contours=n_contours, + interaction=interaction, + verbose=verbose, + ) @copy_function_doc_to_method_doc(plot_evoked_white) - def plot_white(self, noise_cov, show=True, rank=None, time_unit='s', - sphere=None, axes=None, verbose=None): + def plot_white( + self, + noise_cov, + show=True, + rank=None, + time_unit="s", + sphere=None, + axes=None, + verbose=None, + ): return plot_evoked_white( - self, noise_cov=noise_cov, rank=rank, show=show, - time_unit=time_unit, sphere=sphere, axes=axes, verbose=verbose) + self, + noise_cov=noise_cov, + rank=rank, + show=show, + time_unit=time_unit, + sphere=sphere, + axes=axes, + verbose=verbose, + ) @copy_function_doc_to_method_doc(plot_evoked_joint) - def plot_joint(self, times="peaks", title='', picks=None, - exclude='bads', show=True, ts_args=None, - topomap_args=None): - return plot_evoked_joint(self, times=times, title=title, picks=picks, - exclude=exclude, show=show, ts_args=ts_args, - topomap_args=topomap_args) + def plot_joint( + self, + times="peaks", + title="", + picks=None, + exclude="bads", + show=True, + ts_args=None, + topomap_args=None, + ): + return plot_evoked_joint( + self, + times=times, + title=title, + picks=picks, + exclude=exclude, + show=show, + ts_args=ts_args, + topomap_args=topomap_args, + ) @fill_doc - def animate_topomap(self, ch_type=None, times=None, frame_rate=None, - butterfly=False, blit=True, show=True, time_unit='s', - sphere=None, *, image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, vmin=None, vmax=None, - verbose=None): + def animate_topomap( + self, + ch_type=None, + times=None, + frame_rate=None, + butterfly=False, + blit=True, + show=True, + time_unit="s", + sphere=None, + *, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + vmin=None, + vmax=None, + verbose=None, + ): """Make animation of evoked data as topomap timeseries. The animation can be paused/resumed with left mouse button. @@ -530,12 +794,23 @@ def animate_topomap(self, ch_type=None, times=None, frame_rate=None, .. versionadded:: 0.12.0 """ return _topomap_animation( - self, ch_type=ch_type, times=times, frame_rate=frame_rate, - butterfly=butterfly, blit=blit, show=show, time_unit=time_unit, - sphere=sphere, image_interp=image_interp, - extrapolate=extrapolate, vmin=vmin, vmax=vmax, verbose=verbose) + self, + ch_type=ch_type, + times=times, + frame_rate=frame_rate, + butterfly=butterfly, + blit=blit, + show=show, + time_unit=time_unit, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + vmin=vmin, + vmax=vmax, + verbose=verbose, + ) - def as_type(self, ch_type='grad', mode='fast'): + def as_type(self, ch_type="grad", mode="fast"): """Compute virtual evoked using interpolated fields. .. Warning:: Using virtual evoked to compute inverse can yield @@ -565,6 +840,7 @@ def as_type(self, ch_type='grad', mode='fast'): .. 
versionadded:: 0.9.0 """ from .forward import _as_meg_type_inst + return _as_meg_type_inst(self, ch_type=ch_type, mode=mode) @fill_doc @@ -612,14 +888,21 @@ def __neg__(self): out = self.copy() out.data *= -1 - if out.comment is not None and ' + ' in out.comment: - out.comment = f'({out.comment})' # multiple conditions in evoked + if out.comment is not None and " + " in out.comment: + out.comment = f"({out.comment})" # multiple conditions in evoked out.comment = f'- {out.comment or "unknown"}' return out - def get_peak(self, ch_type=None, tmin=None, tmax=None, - mode='abs', time_as_index=False, merge_grads=False, - return_amplitude=False): + def get_peak( + self, + ch_type=None, + tmin=None, + tmax=None, + mode="abs", + time_as_index=False, + merge_grads=False, + return_amplitude=False, + ): """Get location and latency of peak amplitude. Parameters @@ -660,11 +943,19 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None, .. versionadded:: 0.16 """ # noqa: E501 - supported = ('mag', 'grad', 'eeg', 'seeg', 'dbs', 'ecog', 'misc', - 'None') + _FNIRS_CH_TYPES_SPLIT + supported = ( + "mag", + "grad", + "eeg", + "seeg", + "dbs", + "ecog", + "misc", + "None", + ) + _FNIRS_CH_TYPES_SPLIT types_used = self.get_channel_types(unique=True, only_data_chs=True) - _check_option('ch_type', str(ch_type), supported) + _check_option("ch_type", str(ch_type), supported) if ch_type is not None and ch_type not in types_used: raise ValueError( @@ -674,29 +965,31 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None, elif len(types_used) > 1 and ch_type is None: raise RuntimeError( 'Multiple data channel types found. Please pass the "ch_type" ' - 'parameter.' + "parameter." ) if merge_grads: - if ch_type != 'grad': + if ch_type != "grad": raise ValueError('Channel type must be "grad" for merge_grads') - elif mode == 'neg': - raise ValueError('Negative mode (mode=neg) does not make ' - 'sense with merge_grads=True') + elif mode == "neg": + raise ValueError( + "Negative mode (mode=neg) does not make " + "sense with merge_grads=True" + ) meg = eeg = misc = seeg = dbs = ecog = fnirs = False picks = None - if ch_type in ('mag', 'grad'): + if ch_type in ("mag", "grad"): meg = ch_type - elif ch_type == 'eeg': + elif ch_type == "eeg": eeg = True - elif ch_type == 'misc': + elif ch_type == "misc": misc = True - elif ch_type == 'seeg': + elif ch_type == "seeg": seeg = True - elif ch_type == 'dbs': + elif ch_type == "dbs": dbs = True - elif ch_type == 'ecog': + elif ch_type == "ecog": ecog = True elif ch_type in _FNIRS_CH_TYPES_SPLIT: fnirs = ch_type @@ -705,9 +998,17 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None, if merge_grads: picks = _pair_grad_sensors(self.info, topomap_coords=False) else: - picks = pick_types(self.info, meg=meg, eeg=eeg, misc=misc, - seeg=seeg, ecog=ecog, ref_meg=False, - fnirs=fnirs, dbs=dbs) + picks = pick_types( + self.info, + meg=meg, + eeg=eeg, + misc=misc, + seeg=seeg, + ecog=ecog, + ref_meg=False, + fnirs=fnirs, + dbs=dbs, + ) data = self.data ch_names = self.ch_names @@ -717,13 +1018,11 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None, if merge_grads: data, _ = _merge_ch_data(data, ch_type, []) - ch_names = [ch_name[:-1] + 'X' for ch_name in ch_names[::2]] + ch_names = [ch_name[:-1] + "X" for ch_name in ch_names[::2]] - ch_idx, time_idx, max_amp = _get_peak(data, self.times, tmin, - tmax, mode) + ch_idx, time_idx, max_amp = _get_peak(data, self.times, tmin, tmax, mode) - out = (ch_names[ch_idx], time_idx if time_as_index else - self.times[time_idx]) + out = 
(ch_names[ch_idx], time_idx if time_as_index else self.times[time_idx]) if return_amplitude: out += (max_amp,) @@ -731,9 +1030,20 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None, return out @verbose - def compute_psd(self, method='multitaper', fmin=0, fmax=np.inf, tmin=None, - tmax=None, picks=None, proj=False, *, n_jobs=1, - verbose=None, **method_kw): + def compute_psd( + self, + method="multitaper", + fmin=0, + fmax=np.inf, + tmin=None, + tmax=None, + picks=None, + proj=False, + *, + n_jobs=1, + verbose=None, + **method_kw, + ): """Perform spectral analysis on sensor data. Parameters @@ -765,17 +1075,48 @@ def compute_psd(self, method='multitaper', fmin=0, fmax=np.inf, tmin=None, self._set_legacy_nfft_default(tmin, tmax, method, method_kw) return Spectrum( - self, method=method, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, - picks=picks, proj=proj, reject_by_annotation=False, n_jobs=n_jobs, - verbose=verbose, **method_kw) + self, + method=method, + fmin=fmin, + fmax=fmax, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + reject_by_annotation=False, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) @verbose - def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None, - proj=False, *, method='auto', average=False, dB=True, - estimate='auto', xscale='linear', area_mode='std', - area_alpha=0.33, color='black', line_alpha=None, - spatial_colors=True, sphere=None, exclude='bads', ax=None, - show=True, n_jobs=1, verbose=None, **method_kw): + def plot_psd( + self, + fmin=0, + fmax=np.inf, + tmin=None, + tmax=None, + picks=None, + proj=False, + *, + method="auto", + average=False, + dB=True, + estimate="auto", + xscale="linear", + area_mode="std", + area_alpha=0.33, + color="black", + line_alpha=None, + spatial_colors=True, + sphere=None, + exclude="bads", + ax=None, + show=True, + n_jobs=1, + verbose=None, + **method_kw, + ): """%(plot_psd_doc)s. Parameters @@ -819,17 +1160,44 @@ def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None, %(notes_plot_psd_meth)s """ return super().plot_psd( - fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, picks=picks, proj=proj, - reject_by_annotation=False, method=method, average=average, dB=dB, - estimate=estimate, xscale=xscale, area_mode=area_mode, - area_alpha=area_alpha, color=color, line_alpha=line_alpha, - spatial_colors=spatial_colors, sphere=sphere, exclude=exclude, - ax=ax, show=show, n_jobs=n_jobs, verbose=verbose, **method_kw) + fmin=fmin, + fmax=fmax, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + reject_by_annotation=False, + method=method, + average=average, + dB=dB, + estimate=estimate, + xscale=xscale, + area_mode=area_mode, + area_alpha=area_alpha, + color=color, + line_alpha=line_alpha, + spatial_colors=spatial_colors, + sphere=sphere, + exclude=exclude, + ax=ax, + show=show, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) @verbose - def to_data_frame(self, picks=None, index=None, - scalings=None, copy=True, long_format=False, - time_format=None, *, verbose=None): + def to_data_frame( + self, + picks=None, + index=None, + scalings=None, + copy=True, + long_format=False, + time_format=None, + *, + verbose=None, + ): """Export data in tabular structure as a pandas DataFrame. Channels are converted to columns in the DataFrame. 
By default, @@ -856,12 +1224,12 @@ def to_data_frame(self, picks=None, index=None, # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa # arg checking - valid_index_args = ['time'] - valid_time_formats = ['ms', 'timedelta'] + valid_index_args = ["time"] + valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) time_format = _check_time_format(time_format, valid_time_formats) # get data - picks = _picks_to_idx(self.info, picks, 'all', exclude=()) + picks = _picks_to_idx(self.info, picks, "all", exclude=()) data = self.data[picks, :] times = self.times data = data.T @@ -871,10 +1239,11 @@ def to_data_frame(self, picks=None, index=None, # prepare extra columns / multiindex mindex = list() times = _convert_times(self, times, time_format) - mindex.append(('time', times)) + mindex.append(("time", times)) # build DataFrame - df = _build_data_frame(self, data, picks, long_format, mindex, index, - default_index=['time']) + df = _build_data_frame( + self, data, picks, long_format, mindex, index, default_index=["time"] + ) return df @@ -919,26 +1288,40 @@ class EvokedArray(Evoked): """ @verbose - def __init__(self, data, info, tmin=0., comment='', nave=1, kind='average', - baseline=None, *, verbose=None): # noqa: D102 + def __init__( + self, + data, + info, + tmin=0.0, + comment="", + nave=1, + kind="average", + baseline=None, + *, + verbose=None, + ): # noqa: D102 dtype = np.complex128 if np.iscomplexobj(data) else np.float64 data = np.asanyarray(data, dtype=dtype) if data.ndim != 2: - raise ValueError('Data must be a 2D array of shape (n_channels, ' - 'n_samples), got shape %s' % (data.shape,)) + raise ValueError( + "Data must be a 2D array of shape (n_channels, " + "n_samples), got shape %s" % (data.shape,) + ) - if len(info['ch_names']) != np.shape(data)[0]: - raise ValueError('Info (%s) and data (%s) must have same number ' - 'of channels.' % (len(info['ch_names']), - np.shape(data)[0])) + if len(info["ch_names"]) != np.shape(data)[0]: + raise ValueError( + "Info (%s) and data (%s) must have same number " + "of channels." 
% (len(info["ch_names"]), np.shape(data)[0]) + ) self.data = data - self.first = int(round(tmin * info['sfreq'])) + self.first = int(round(tmin * info["sfreq"])) self.last = self.first + np.shape(data)[-1] - 1 - self._set_times(np.arange(self.first, self.last + 1, - dtype=np.float64) / info['sfreq']) + self._set_times( + np.arange(self.first, self.last + 1, dtype=np.float64) / info["sfreq"] + ) self._raw_times = self.times.copy() self._decim = 1 self.info = info.copy() # do not modify original info @@ -950,8 +1333,10 @@ def __init__(self, data, info, tmin=0., comment='', nave=1, kind='average', self._projector = None _validate_type(self.kind, "str", "kind") if self.kind not in _aspect_dict: - raise ValueError('unknown kind "%s", should be "average" or ' - '"standard_error"' % (self.kind,)) + raise ValueError( + 'unknown kind "%s", should be "average" or ' + '"standard_error"' % (self.kind,) + ) self._aspect_kind = _aspect_dict[self.kind] self.baseline = baseline @@ -964,16 +1349,16 @@ def _get_entries(fid, evoked_node, allow_maxshield=False): comments = list() aspect_kinds = list() for ev in evoked_node: - for k in range(ev['nent']): - my_kind = ev['directory'][k].kind - pos = ev['directory'][k].pos + for k in range(ev["nent"]): + my_kind = ev["directory"][k].kind + pos = ev["directory"][k].pos if my_kind == FIFF.FIFF_COMMENT: tag = read_tag(fid, pos) comments.append(tag.data) my_aspect = _get_aspect(ev, allow_maxshield)[0] - for k in range(my_aspect['nent']): - my_kind = my_aspect['directory'][k].kind - pos = my_aspect['directory'][k].pos + for k in range(my_aspect["nent"]): + my_kind = my_aspect["directory"][k].kind + pos = my_aspect["directory"][k].pos if my_kind == FIFF.FIFF_ASPECT_KIND: tag = read_tag(fid, pos) aspect_kinds.append(int(tag.data.item())) @@ -981,11 +1366,10 @@ def _get_entries(fid, evoked_node, allow_maxshield=False): aspect_kinds = np.atleast_1d(aspect_kinds) if len(comments) != len(aspect_kinds) or len(comments) == 0: fid.close() - raise ValueError('Dataset names in FIF file ' - 'could not be found.') + raise ValueError("Dataset names in FIF file " "could not be found.") t = [_aspect_rev[a] for a in aspect_kinds] - t = ['"' + c + '" (' + tt + ')' for tt, c in zip(t, comments)] - t = '\n'.join(t) + t = ['"' + c + '" (' + tt + ")" for tt, c in zip(t, comments)] + t = "\n".join(t) return comments, aspect_kinds, t @@ -998,7 +1382,7 @@ def _get_aspect(evoked, allow_maxshield): aspect = dir_tree_find(evoked, FIFF.FIFFB_IAS_ASPECT) is_maxshield = True if len(aspect) > 1: - logger.info('Multiple data aspects found. Taking first one.') + logger.info("Multiple data aspects found. Taking first one.") return aspect[0], is_maxshield @@ -1018,16 +1402,17 @@ def _check_evokeds_ch_names_times(all_evoked): if ev.ch_names != ch_names: if set(ev.ch_names) != set(ch_names): raise ValueError( - "%s and %s do not contain the same channels." % (evoked, - ev)) + "%s and %s do not contain the same channels." 
% (evoked, ev) + ) else: warn("Order of channels differs, reordering channels ...") ev = ev.copy() ev.reorder_channels(ch_names) all_evoked[ii + 1] = ev if not np.max(np.abs(ev.times - evoked.times)) < 1e-7: - raise ValueError("%s and %s do not contain the same time instants" - % (evoked, ev)) + raise ValueError( + "%s and %s do not contain the same time instants" % (evoked, ev) + ) return all_evoked @@ -1066,8 +1451,8 @@ def combine_evoked(all_evoked, weights): """ naves = np.array([evk.nave for evk in all_evoked], float) if isinstance(weights, str): - _check_option('weights', weights, ['nave', 'equal']) - if weights == 'nave': + _check_option("weights", weights, ["nave", "equal"]) + if weights == "nave": weights = naves / naves.sum() else: weights = np.ones_like(naves) / len(naves) @@ -1075,7 +1460,7 @@ def combine_evoked(all_evoked, weights): weights = np.array(weights, float) if weights.ndim != 1 or weights.size != len(all_evoked): - raise ValueError('weights must be the same size as all_evoked') + raise ValueError("weights must be the same size as all_evoked") # cf. https://en.wikipedia.org/wiki/Weighted_arithmetic_mean, section on # "weighted sample variance". The variance of a weighted sample mean is: @@ -1087,7 +1472,7 @@ def combine_evoked(all_evoked, weights): # σ² = w₁² / nave₁ + w₂² / nave₂ + ... + wₙ² / naveₙ # # And our resulting nave is the reciprocal of this: - new_nave = 1. / np.sum(weights ** 2 / naves) + new_nave = 1.0 / np.sum(weights**2 / naves) # This general formula is equivalent to formulae in Matti's manual # (pp 128-129), where: # new_nave = sum(naves) when weights='nave' and @@ -1097,37 +1482,44 @@ def combine_evoked(all_evoked, weights): evoked = all_evoked[0].copy() # use union of bad channels - bads = list(set(b for e in all_evoked for b in e.info['bads'])) - evoked.info['bads'] = bads + bads = list(set(b for e in all_evoked for b in e.info["bads"])) + evoked.info["bads"] = bads evoked.data = sum(w * e.data for w, e in zip(weights, all_evoked)) evoked.nave = new_nave - comment = '' + comment = "" for idx, (w, e) in enumerate(zip(weights, all_evoked)): # pick sign - sign = '' if w >= 0 else '-' + sign = "" if w >= 0 else "-" # format weight - weight = '' if np.isclose(abs(w), 1.) else f'{abs(w):0.3f}' + weight = "" if np.isclose(abs(w), 1.0) else f"{abs(w):0.3f}" # format multiplier - multiplier = ' × ' if weight else '' + multiplier = " × " if weight else "" # format comment - if e.comment is not None and ' + ' in e.comment: # multiple conditions - this_comment = f'({e.comment})' + if e.comment is not None and " + " in e.comment: # multiple conditions + this_comment = f"({e.comment})" else: this_comment = f'{e.comment or "unknown"}' # assemble everything if idx == 0: - comment += f'{sign}{weight}{multiplier}{this_comment}' + comment += f"{sign}{weight}{multiplier}{this_comment}" else: comment += f' {sign or "+"} {weight}{multiplier}{this_comment}' # special-case: combine_evoked([e1, -e2], [1, -1]) - evoked.comment = comment.replace(' - - ', ' + ') + evoked.comment = comment.replace(" - - ", " + ") return evoked @verbose -def read_evokeds(fname, condition=None, baseline=None, kind='average', - proj=True, allow_maxshield=False, verbose=None): +def read_evokeds( + fname, + condition=None, + baseline=None, + kind="average", + proj=True, + allow_maxshield=False, + verbose=None, +): """Read evoked dataset(s). Parameters @@ -1182,9 +1574,8 @@ def read_evokeds(fname, condition=None, baseline=None, kind='average', reading. 
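# [Editor's aside, not part of the patch] A minimal sketch of the
# effective-nave arithmetic that combine_evoked() above implements: for
# weights w_i and per-input nave_i, the combined nave is
# 1 / sum(w_i**2 / nave_i). The nave values below are made up.
import numpy as np

naves = np.array([40.0, 60.0])              # nave of each input Evoked
w_nave = naves / naves.sum()                # weights='nave'
print(1.0 / np.sum(w_nave**2 / naves))      # -> 100.0 (== naves.sum(), up to float rounding)
w_equal = np.ones_like(naves) / len(naves)  # weights='equal'
print(1.0 / np.sum(w_equal**2 / naves))     # -> 96.0 (== N**2 / sum(1 / naves))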
""" fname = str(_check_fname(fname, overwrite="read", must_exist=True)) - check_fname(fname, 'evoked', ('-ave.fif', '-ave.fif.gz', - '_ave.fif', '_ave.fif.gz')) - logger.info('Reading %s ...' % fname) + check_fname(fname, "evoked", ("-ave.fif", "-ave.fif.gz", "_ave.fif", "_ave.fif.gz")) + logger.info("Reading %s ..." % fname) return_list = True if condition is None: evoked_node = _get_evoked_node(fname) @@ -1195,16 +1586,23 @@ def read_evokeds(fname, condition=None, baseline=None, kind='average', out = [] for c in condition: - evoked = Evoked(fname, c, kind=kind, proj=proj, - allow_maxshield=allow_maxshield, - verbose=verbose) + evoked = Evoked( + fname, + c, + kind=kind, + proj=proj, + allow_maxshield=allow_maxshield, + verbose=verbose, + ) if baseline is None and evoked.baseline is None: logger.info(_log_rescale(None)) elif baseline is None and evoked.baseline is not None: # Don't touch an existing baseline bmin, bmax = evoked.baseline - logger.info(f'Loaded Evoked data is baseline-corrected ' - f'(baseline: [{bmin:g}, {bmax:g}] s)') + logger.info( + f"Loaded Evoked data is baseline-corrected " + f"(baseline: [{bmin:g}, {bmax:g}] s)" + ) else: evoked.apply_baseline(baseline) out.append(evoked) @@ -1212,10 +1610,10 @@ def read_evokeds(fname, condition=None, baseline=None, kind='average', return out if return_list else out[0] -def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): +def _read_evoked(fname, condition=None, kind="average", allow_maxshield=False): """Read evoked data from a FIF file.""" if fname is None: - raise ValueError('No evoked filename specified') + raise ValueError("No evoked filename specified") f, tree, _ = fiff_open(fname) with f as fid: @@ -1225,47 +1623,47 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): # Locate the data of interest processed = dir_tree_find(meas, FIFF.FIFFB_PROCESSED_DATA) if len(processed) == 0: - raise ValueError('Could not find processed data') + raise ValueError("Could not find processed data") evoked_node = dir_tree_find(meas, FIFF.FIFFB_EVOKED) if len(evoked_node) == 0: - raise ValueError('Could not find evoked data') + raise ValueError("Could not find evoked data") # find string-based entry if isinstance(condition, str): if kind not in _aspect_dict.keys(): - raise ValueError('kind must be "average" or ' - '"standard_error"') + raise ValueError('kind must be "average" or ' '"standard_error"') - comments, aspect_kinds, t = _get_entries(fid, evoked_node, - allow_maxshield) - goods = (np.in1d(comments, [condition]) & - np.in1d(aspect_kinds, [_aspect_dict[kind]])) + comments, aspect_kinds, t = _get_entries(fid, evoked_node, allow_maxshield) + goods = np.in1d(comments, [condition]) & np.in1d( + aspect_kinds, [_aspect_dict[kind]] + ) found_cond = np.where(goods)[0] if len(found_cond) != 1: - raise ValueError('condition "%s" (%s) not found, out of ' - 'found datasets:\n%s' - % (condition, kind, t)) + raise ValueError( + 'condition "%s" (%s) not found, out of ' + "found datasets:\n%s" % (condition, kind, t) + ) condition = found_cond[0] elif condition is None: if len(evoked_node) > 1: - _, _, conditions = _get_entries(fid, evoked_node, - allow_maxshield) - raise TypeError("Evoked file has more than one " - "condition, the condition parameters " - "must be specified from:\n%s" % conditions) + _, _, conditions = _get_entries(fid, evoked_node, allow_maxshield) + raise TypeError( + "Evoked file has more than one " + "condition, the condition parameters " + "must be specified from:\n%s" % 
conditions + ) else: condition = 0 if condition >= len(evoked_node) or condition < 0: - raise ValueError('Data set selector out of range') + raise ValueError("Data set selector out of range") my_evoked = evoked_node[condition] # Identify the aspects with info._unlock(): - my_aspect, info['maxshield'] = _get_aspect(my_evoked, - allow_maxshield) + my_aspect, info["maxshield"] = _get_aspect(my_evoked, allow_maxshield) # Now find the data in the evoked block nchan = 0 @@ -1273,9 +1671,9 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): chs = [] baseline = bmin = bmax = None comment = last = first = first_time = nsamp = None - for k in range(my_evoked['nent']): - my_kind = my_evoked['directory'][k].kind - pos = my_evoked['directory'][k].pos + for k in range(my_evoked["nent"]): + my_kind = my_evoked["directory"][k].kind + pos = my_evoked["directory"][k].pos if my_kind == FIFF.FIFF_COMMENT: tag = read_tag(fid, pos) comment = tag.data @@ -1308,7 +1706,7 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): bmax = float(tag.data.item()) if comment is None: - comment = 'No comment' + comment = "No comment" if bmin is not None or bmax is not None: # None's should've been replaced with floats @@ -1318,27 +1716,31 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): # Local channel information? if nchan > 0: if chs is None: - raise ValueError('Local channel information was not found ' - 'when it was expected.') + raise ValueError( + "Local channel information was not found " "when it was expected." + ) if len(chs) != nchan: - raise ValueError('Number of channels and number of ' - 'channel definitions are different') + raise ValueError( + "Number of channels and number of " + "channel definitions are different" + ) ch_names_mapping = _read_extended_ch_info(chs, my_evoked, fid) - info['chs'] = chs - info['bads'][:] = _rename_list(info['bads'], ch_names_mapping) - logger.info(' Found channel information in evoked data. ' - 'nchan = %d' % nchan) + info["chs"] = chs + info["bads"][:] = _rename_list(info["bads"], ch_names_mapping) + logger.info( + " Found channel information in evoked data. 
" "nchan = %d" % nchan + ) if sfreq > 0: - info['sfreq'] = sfreq + info["sfreq"] = sfreq # Read the data in the aspect block nave = 1 epoch = [] - for k in range(my_aspect['nent']): - kind = my_aspect['directory'][k].kind - pos = my_aspect['directory'][k].pos + for k in range(my_aspect["nent"]): + kind = my_aspect["directory"][k].kind + pos = my_aspect["directory"][k].pos if kind == FIFF.FIFF_COMMENT: tag = read_tag(fid, pos) comment = tag.data @@ -1353,16 +1755,17 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): epoch.append(tag) nepoch = len(epoch) - if nepoch != 1 and nepoch != info['nchan']: - raise ValueError('Number of epoch tags is unreasonable ' - '(nepoch = %d nchan = %d)' - % (nepoch, info['nchan'])) + if nepoch != 1 and nepoch != info["nchan"]: + raise ValueError( + "Number of epoch tags is unreasonable " + "(nepoch = %d nchan = %d)" % (nepoch, info["nchan"]) + ) if nepoch == 1: # Only one epoch data = epoch[0].data # May need a transpose if the number of channels is one - if data.shape[1] == 1 and info['nchan'] == 1: + if data.shape[1] == 1 and info["nchan"] == 1: data = data.T else: # Put the old style epochs together @@ -1373,37 +1776,43 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False): data = data.astype(np.complex128) if first_time is not None and nsamp is not None: - times = first_time + np.arange(nsamp) / info['sfreq'] + times = first_time + np.arange(nsamp) / info["sfreq"] elif first is not None: nsamp = last - first + 1 - times = np.arange(first, last + 1) / info['sfreq'] + times = np.arange(first, last + 1) / info["sfreq"] else: - raise RuntimeError('Could not read time parameters') + raise RuntimeError("Could not read time parameters") del first, last if nsamp is not None and data.shape[1] != nsamp: - raise ValueError('Incorrect number of samples (%d instead of ' - ' %d)' % (data.shape[1], nsamp)) - logger.info(' Found the data of interest:') - logger.info(' t = %10.2f ... %10.2f ms (%s)' - % (1000 * times[0], 1000 * times[-1], comment)) - if info['comps'] is not None: - logger.info(' %d CTF compensation matrices available' - % len(info['comps'])) - logger.info(' nave = %d - aspect type = %d' - % (nave, aspect_kind)) + raise ValueError( + "Incorrect number of samples (%d instead of " + " %d)" % (data.shape[1], nsamp) + ) + logger.info(" Found the data of interest:") + logger.info( + " t = %10.2f ... %10.2f ms (%s)" + % (1000 * times[0], 1000 * times[-1], comment) + ) + if info["comps"] is not None: + logger.info( + " %d CTF compensation matrices available" % len(info["comps"]) + ) + logger.info(" nave = %d - aspect type = %d" % (nave, aspect_kind)) # Calibrate - cals = np.array([info['chs'][k]['cal'] * - info['chs'][k].get('scale', 1.0) - for k in range(info['nchan'])]) + cals = np.array( + [ + info["chs"][k]["cal"] * info["chs"][k].get("scale", 1.0) + for k in range(info["nchan"]) + ] + ) data *= cals[:, np.newaxis] return info, nave, aspect_kind, comment, times, data, baseline @verbose -def write_evokeds(fname, evoked, *, on_mismatch='raise', overwrite=False, - verbose=None): +def write_evokeds(fname, evoked, *, on_mismatch="raise", overwrite=False, verbose=None): """Write an evoked dataset to a file. 
Parameters @@ -1436,15 +1845,15 @@ def write_evokeds(fname, evoked, *, on_mismatch='raise', overwrite=False, _write_evokeds(fname, evoked, on_mismatch=on_mismatch, overwrite=overwrite) -def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise', - overwrite=False): +def _write_evokeds(fname, evoked, check=True, *, on_mismatch="raise", overwrite=False): """Write evoked data.""" from .dipole import DipoleFixed # avoid circular import fname = _check_fname(fname=fname, overwrite=overwrite) if check: - check_fname(fname, 'evoked', ('-ave.fif', '-ave.fif.gz', - '_ave.fif', '_ave.fif.gz')) + check_fname( + fname, "evoked", ("-ave.fif", "-ave.fif.gz", "_ave.fif", "_ave.fif.gz") + ) if not isinstance(evoked, (list, tuple)): evoked = [evoked] @@ -1452,11 +1861,10 @@ def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise', warned = False # Create the file and save the essentials with start_and_end_file(fname) as fid: - start_block(fid, FIFF.FIFFB_MEAS) write_id(fid, FIFF.FIFF_BLOCK_ID) - if evoked[0].info['meas_id'] is not None: - write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, evoked[0].info['meas_id']) + if evoked[0].info["meas_id"] is not None: + write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, evoked[0].info["meas_id"]) # Write measurement info write_meas_info(fid, evoked[0].info) @@ -1465,9 +1873,12 @@ def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise', start_block(fid, FIFF.FIFFB_PROCESSED_DATA) for ei, e in enumerate(evoked): if ei: - _ensure_infos_match(info1=evoked[0].info, info2=e.info, - name=f'evoked[{ei}]', - on_mismatch=on_mismatch) + _ensure_infos_match( + info1=evoked[0].info, + info2=e.info, + name=f"evoked[{ei}]", + on_mismatch=on_mismatch, + ) start_block(fid, FIFF.FIFFB_EVOKED) # Comment is optional @@ -1487,7 +1898,7 @@ def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise', write_float(fid, FIFF.FIFF_MNE_BASELINE_MAX, bmax) # The evoked data itself - if e.info.get('maxshield'): + if e.info.get("maxshield"): aspect = FIFF.FIFFB_IAS_ASPECT else: aspect = FIFF.FIFFB_ASPECT @@ -1497,17 +1908,20 @@ def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise', # convert nave to integer to comply with FIFF spec nave_int = int(round(e.nave)) if nave_int != e.nave and not warned: - warn('converting "nave" to integer before saving evoked; this ' - 'can have a minor effect on the scale of source ' - 'estimates that are computed using "nave".') + warn( + 'converting "nave" to integer before saving evoked; this ' + "can have a minor effect on the scale of source " + 'estimates that are computed using "nave".' + ) warned = True write_int(fid, FIFF.FIFF_NAVE, nave_int) del nave_int - decal = np.zeros((e.info['nchan'], 1)) - for k in range(e.info['nchan']): - decal[k] = 1.0 / (e.info['chs'][k]['cal'] * - e.info['chs'][k].get('scale', 1.0)) + decal = np.zeros((e.info["nchan"], 1)) + for k in range(e.info["nchan"]): + decal[k] = 1.0 / ( + e.info["chs"][k]["cal"] * e.info["chs"][k].get("scale", 1.0) + ) if np.iscomplexobj(e.data): write_function = write_complex_float_matrix @@ -1522,7 +1936,7 @@ def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise', end_block(fid, FIFF.FIFFB_MEAS) -def _get_peak(data, times, tmin=None, tmax=None, mode='abs'): +def _get_peak(data, times, tmin=None, tmax=None, mode="abs"): """Get feature-index and time of maximum signal from 2D array. Note. This is a 'getter', not a 'finder'. 
For non-evoked type @@ -1553,7 +1967,7 @@ def _get_peak(data, times, tmin=None, tmax=None, mode='abs'): max_amp : float Amplitude of the maximum response. """ - _check_option('mode', mode, ['abs', 'neg', 'pos']) + _check_option("mode", mode, ["abs", "neg", "pos"]) if tmin is None: tmin = times[0] @@ -1562,36 +1976,37 @@ def _get_peak(data, times, tmin=None, tmax=None, mode='abs'): if tmin < times.min() or tmax > times.max(): if tmin < times.min(): - param_name = 'tmin' + param_name = "tmin" param_val = tmin else: - param_name = 'tmax' + param_name = "tmax" param_val = tmax raise ValueError( - f'{param_name} ({param_val}) is out of bounds. It must be ' - f'between {times.min()} and {times.max()}' + f"{param_name} ({param_val}) is out of bounds. It must be " + f"between {times.min()} and {times.max()}" ) elif tmin > tmax: - raise ValueError(f'tmin ({tmin}) must be <= tmax ({tmax})') + raise ValueError(f"tmin ({tmin}) must be <= tmax ({tmax})") time_win = (times >= tmin) & (times <= tmax) mask = np.ones_like(data).astype(bool) mask[:, time_win] = False maxfun = np.argmax - if mode == 'pos': + if mode == "pos": if not np.any(data[~mask] > 0): - raise ValueError('No positive values encountered. Cannot ' - 'operate in pos mode.') - elif mode == 'neg': + raise ValueError( + "No positive values encountered. Cannot " "operate in pos mode." + ) + elif mode == "neg": if not np.any(data[~mask] < 0): - raise ValueError('No negative values encountered. Cannot ' - 'operate in neg mode.') + raise ValueError( + "No negative values encountered. Cannot " "operate in neg mode." + ) maxfun = np.argmin - masked_index = np.ma.array(np.abs(data) if mode == 'abs' else data, - mask=mask) + masked_index = np.ma.array(np.abs(data) if mode == "abs" else data, mask=mask) max_loc, max_time = np.unravel_index(maxfun(masked_index), data.shape) diff --git a/mne/export/_brainvision.py b/mne/export/_brainvision.py index 91e0c08b94d..ff61ee939fb 100644 --- a/mne/export/_brainvision.py +++ b/mne/export/_brainvision.py @@ -5,6 +5,7 @@ import os from ..utils import _check_pybv_installed + _check_pybv_installed() from pybv._export import _export_mne_raw # noqa: E402 diff --git a/mne/export/_edf.py b/mne/export/_edf.py index 8a6b1370470..3666aae30fe 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -6,6 +6,7 @@ import numpy as np from ..utils import _check_edflib_installed, warn + _check_edflib_installed() from EDFlib.edfwriter import EDFwriter # noqa: E402 @@ -14,7 +15,7 @@ def _try_to_set_value(header, key, value, channel_index=None): """Set key/value pairs in EDF header.""" # all EDFLib set functions are set # for example "setPatientName()" - func_name = f'set{key}' + func_name = f"set{key}" func = getattr(header, func_name) # some setter functions are indexed by channels @@ -25,9 +26,9 @@ def _try_to_set_value(header, key, value, channel_index=None): # a nonzero return value indicates an error if return_val != 0: - raise RuntimeError(f"Setting {key} with {value} " - f"returned an error value " - f"{return_val}.") + raise RuntimeError( + f"Setting {key} with {value} " f"returned an error value " f"{return_val}." + ) @contextmanager @@ -49,11 +50,12 @@ def _export_raw(fname, raw, physical_range, add_ch_type): technician information, allow writing those here. 
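# [Editor's aside, not part of the patch] A worked example of the
# non-integer sampling-rate handling in _export_raw() below. EDF stores an
# integer number of samples per data record, so the rate is floored and the
# record duration (later passed to setDataRecordDuration; the 1e6 factor
# suggests microsecond-scale units) is stretched to compensate. The input
# rate of 512.5 Hz is hypothetical.
import numpy as np

sfreq = 512.5
out_sfreq = np.floor(sfreq).astype(int)  # 512 samples per record
data_record_duration = int(np.around(out_sfreq / sfreq, decimals=6) * 1e6)
print(out_sfreq, data_record_duration)   # 512 and ~999024, i.e. ~0.999024 s per record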
""" # scale to save data in EDF - phys_dims = 'uV' + phys_dims = "uV" # get EEG-related data in uV - units = dict(eeg='uV', ecog='uV', seeg='uV', eog='uV', ecg='uV', emg='uV', - bio='uV', dbs='uV') + units = dict( + eeg="uV", ecog="uV", seeg="uV", eog="uV", ecg="uV", emg="uV", bio="uV", dbs="uV" + ) digital_min = -32767 digital_max = 32767 @@ -65,8 +67,8 @@ def _export_raw(fname, raw, physical_range, add_ch_type): # remove extra STI channels orig_ch_types = raw.get_channel_types() drop_chs = [] - if 'stim' in orig_ch_types: - stim_index = np.argwhere(np.array(orig_ch_types) == 'stim') + if "stim" in orig_ch_types: + stim_index = np.argwhere(np.array(orig_ch_types) == "stim") stim_index = np.atleast_1d(stim_index.squeeze()).tolist() drop_chs.extend([raw.ch_names[idx] for idx in stim_index]) @@ -77,17 +79,19 @@ def _export_raw(fname, raw, physical_range, add_ch_type): # Note: we can write these other channels, such as 'misc' # but these are simply a "catch all" for unknown or undesired # channels. - voltage_types = list(units) + ['stim', 'misc'] + voltage_types = list(units) + ["stim", "misc"] non_voltage_ch = [ch not in voltage_types for ch in orig_ch_types] if any(non_voltage_ch): - warn(f"Non-voltage channels detected: {non_voltage_ch}. MNE-Python's " - 'EDF exporter only supports voltage-based channels, because the ' - 'EDF format cannot accommodate much of the accompanying data ' - 'necessary for channel types like MEG and fNIRS (channel ' - 'orientations, coordinate frame transforms, etc). You can ' - 'override this restriction by setting those channel types to ' - '"misc" but no guarantees are made of the fidelity of that ' - 'approach.') + warn( + f"Non-voltage channels detected: {non_voltage_ch}. MNE-Python's " + "EDF exporter only supports voltage-based channels, because the " + "EDF format cannot accommodate much of the accompanying data " + "necessary for channel types like MEG and fNIRS (channel " + "orientations, coordinate frame transforms, etc). You can " + "override this restriction by setting those channel types to " + '"misc" but no guarantees are made of the fidelity of that ' + "approach." + ) ch_names = [ch for ch in raw.ch_names if ch not in drop_chs] ch_types = np.array(raw.get_channel_types(picks=ch_names)) @@ -97,28 +101,29 @@ def _export_raw(fname, raw, physical_range, add_ch_type): # Sampling frequency in EDF only supports integers, so to allow for # float sampling rates from Raw, we adjust the output sampling rate # for all channels and the data record duration. - sfreq = raw.info['sfreq'] + sfreq = raw.info["sfreq"] if float(sfreq).is_integer(): out_sfreq = int(sfreq) data_record_duration = None else: out_sfreq = np.floor(sfreq).astype(int) - data_record_duration = int(np.around( - out_sfreq / sfreq, decimals=6) * 1e6) + data_record_duration = int(np.around(out_sfreq / sfreq, decimals=6) * 1e6) - warn(f'Data has a non-integer sampling rate of {sfreq}; writing to ' - 'EDF format may cause a small change to sample times.') + warn( + f"Data has a non-integer sampling rate of {sfreq}; writing to " + "EDF format may cause a small change to sample times." 
+ ) # get any filter information applied to the data - lowpass = raw.info['lowpass'] - highpass = raw.info['highpass'] - linefreq = raw.info['line_freq'] + lowpass = raw.info["lowpass"] + highpass = raw.info["highpass"] + linefreq = raw.info["line_freq"] filter_str_info = f"HP:{highpass}Hz LP:{lowpass}Hz N:{linefreq}Hz" # get the entire dataset in uV data = raw.get_data(units=units, picks=ch_names) - if physical_range == 'auto': + if physical_range == "auto": # get max and min for each channel type data ch_types_phys_max = dict() ch_types_phys_min = dict() @@ -156,54 +161,60 @@ def _export_raw(fname, raw, physical_range, add_ch_type): # set channel data for idx, ch in enumerate(ch_names): ch_type = ch_types[idx] - signal_label = f'{ch_type.upper()} {ch}' if add_ch_type else ch + signal_label = f"{ch_type.upper()} {ch}" if add_ch_type else ch if len(signal_label) > 16: - raise RuntimeError(f'Signal label for {ch} ({ch_type}) is ' - f'longer than 16 characters, which is not ' - f'supported in EDF. Please shorten the ' - f'channel name before exporting to EDF.') - - if physical_range == 'auto': + raise RuntimeError( + f"Signal label for {ch} ({ch_type}) is " + f"longer than 16 characters, which is not " + f"supported in EDF. Please shorten the " + f"channel name before exporting to EDF." + ) + + if physical_range == "auto": # take the channel type minimum and maximum pmin = ch_types_phys_min[ch_type] pmax = ch_types_phys_max[ch_type] - for key, val in [('PhysicalMaximum', pmax), - ('PhysicalMinimum', pmin), - ('DigitalMaximum', digital_max), - ('DigitalMinimum', digital_min), - ('PhysicalDimension', phys_dims), - ('SampleFrequency', out_sfreq), - ('SignalLabel', signal_label), - ('PreFilter', filter_str_info)]: + for key, val in [ + ("PhysicalMaximum", pmax), + ("PhysicalMinimum", pmin), + ("DigitalMaximum", digital_max), + ("DigitalMinimum", digital_min), + ("PhysicalDimension", phys_dims), + ("SampleFrequency", out_sfreq), + ("SignalLabel", signal_label), + ("PreFilter", filter_str_info), + ]: _try_to_set_value(hdl, key, val, channel_index=idx) # set patient info - subj_info = raw.info.get('subject_info') + subj_info = raw.info.get("subject_info") if subj_info is not None: - birthday = subj_info.get('birthday') + birthday = subj_info.get("birthday") # get the full name of subject if available - first_name = subj_info.get('first_name') - last_name = subj_info.get('last_name') - first_name = first_name or '' - last_name = last_name or '' - joiner = '' + first_name = subj_info.get("first_name") + last_name = subj_info.get("last_name") + first_name = first_name or "" + last_name = last_name or "" + joiner = "" if len(first_name) and len(last_name): - joiner = ' ' + joiner = " " name = joiner.join([first_name, last_name]) - hand = subj_info.get('hand') - sex = subj_info.get('sex') + hand = subj_info.get("hand") + sex = subj_info.get("sex") if birthday is not None: - if hdl.setPatientBirthDate(birthday[0], birthday[1], - birthday[2]) != 0: + if hdl.setPatientBirthDate(birthday[0], birthday[1], birthday[2]) != 0: raise RuntimeError( f"Setting patient birth date to {birthday} " - f"returned an error") - for key, val in [('PatientName', name), - ('PatientGender', sex), - ('AdditionalPatientInfo', f'hand={hand}')]: + f"returned an error" + ) + for key, val in [ + ("PatientName", name), + ("PatientGender", sex), + ("AdditionalPatientInfo", f"hand={hand}"), + ]: # EDFwriter compares integer encodings of sex and will # raise a TypeError if value is None as returned by # subj_info.get(key) if key is 
missing. @@ -211,25 +222,33 @@ def _export_raw(fname, raw, physical_range, add_ch_type): _try_to_set_value(hdl, key, val) # set measurement date - meas_date = raw.info['meas_date'] + meas_date = raw.info["meas_date"] if meas_date: subsecond = int(meas_date.microsecond / 100) - if hdl.setStartDateTime(year=meas_date.year, month=meas_date.month, - day=meas_date.day, hour=meas_date.hour, - minute=meas_date.minute, - second=meas_date.second, - subsecond=subsecond) != 0: - raise RuntimeError(f"Setting start date time {meas_date} " - f"returned an error") - - device_info = raw.info.get('device_info') + if ( + hdl.setStartDateTime( + year=meas_date.year, + month=meas_date.month, + day=meas_date.day, + hour=meas_date.hour, + minute=meas_date.minute, + second=meas_date.second, + subsecond=subsecond, + ) + != 0 + ): + raise RuntimeError( + f"Setting start date time {meas_date} " f"returned an error" + ) + + device_info = raw.info.get("device_info") if device_info is not None: - device_type = device_info.get('type') - _try_to_set_value(hdl, 'Equipment', device_type) + device_type = device_info.get("type") + _try_to_set_value(hdl, "Equipment", device_type) # set data record duration if data_record_duration is not None: - _try_to_set_value(hdl, 'DataRecordDuration', data_record_duration) + _try_to_set_value(hdl, "DataRecordDuration", data_record_duration) # compute number of data records to loop over n_blocks = np.ceil(n_times / out_sfreq).astype(int) @@ -260,29 +279,36 @@ def _export_raw(fname, raw, physical_range, add_ch_type): ch_data = data[jdx, start_samp:end_samp] # assign channel data to the buffer and write to EDF - buf[:len(ch_data)] = ch_data + buf[: len(ch_data)] = ch_data err = hdl.writeSamples(buf) if err != 0: raise RuntimeError( f"writeSamples() for channel{ch_names[jdx]} " - f"returned error: {err}") + f"returned error: {err}" + ) # there was an incomplete datarecord if len(ch_data) != len(buf): - warn(f'EDF format requires equal-length data blocks, ' - f'so {(len(buf) - len(ch_data)) / sfreq} seconds of ' - 'zeros were appended to all channels when writing the ' - 'final block.') + warn( + f"EDF format requires equal-length data blocks, " + f"so {(len(buf) - len(ch_data)) / sfreq} seconds of " + "zeros were appended to all channels when writing the " + "final block." + ) # write annotations if annots is not None: - for desc, onset, duration in zip(raw.annotations.description, - raw.annotations.onset, - raw.annotations.duration): + for desc, onset, duration in zip( + raw.annotations.description, + raw.annotations.onset, + raw.annotations.duration, + ): # annotations are written in terms of 100 microseconds onset = onset * 10000 duration = duration * 10000 if hdl.writeAnnotation(onset, duration, desc) != 0: - raise RuntimeError(f'writeAnnotation() returned an error ' - f'trying to write {desc} at {onset} ' - f'for {duration} seconds.') + raise RuntimeError( + f"writeAnnotation() returned an error " + f"trying to write {desc} at {onset} " + f"for {duration} seconds." 
+ ) diff --git a/mne/export/_eeglab.py b/mne/export/_eeglab.py index 00d566c13fe..3fd1cc55902 100644 --- a/mne/export/_eeglab.py +++ b/mne/export/_eeglab.py @@ -5,6 +5,7 @@ import numpy as np from ..utils import _check_eeglabio_installed + _check_eeglabio_installed() import eeglabio.raw # noqa: E402 import eeglabio.epochs # noqa: E402 @@ -15,20 +16,27 @@ def _export_raw(fname, raw): raw.load_data() # remove extra epoc and STI channels - drop_chs = ['epoc'] + drop_chs = ["epoc"] # filenames attribute of RawArray is filled with None - if raw.filenames[0] and not (raw.filenames[0].endswith('.fif')): - drop_chs.append('STI 014') + if raw.filenames[0] and not (raw.filenames[0].endswith(".fif")): + drop_chs.append("STI 014") ch_names = [ch for ch in raw.ch_names if ch not in drop_chs] - cart_coords = _get_als_coords_from_chs(raw.info['chs'], drop_chs) + cart_coords = _get_als_coords_from_chs(raw.info["chs"], drop_chs) - annotations = [raw.annotations.description, - raw.annotations.onset, - raw.annotations.duration] + annotations = [ + raw.annotations.description, + raw.annotations.onset, + raw.annotations.duration, + ] eeglabio.raw.export_set( - fname, data=raw.get_data(picks=ch_names), sfreq=raw.info['sfreq'], - ch_names=ch_names, ch_locs=cart_coords, annotations=annotations) + fname, + data=raw.get_data(picks=ch_names), + sfreq=raw.info["sfreq"], + ch_names=ch_names, + ch_locs=cart_coords, + annotations=annotations, + ) def _export_epochs(fname, epochs): @@ -37,21 +45,31 @@ def _export_epochs(fname, epochs): epochs.load_data() # remove extra epoc and STI channels - drop_chs = ['epoc', 'STI 014'] + drop_chs = ["epoc", "STI 014"] ch_names = [ch for ch in epochs.ch_names if ch not in drop_chs] - cart_coords = _get_als_coords_from_chs(epochs.info['chs'], drop_chs) + cart_coords = _get_als_coords_from_chs(epochs.info["chs"], drop_chs) if epochs.annotations: - annot = [epochs.annotations.description, epochs.annotations.onset, - epochs.annotations.duration] + annot = [ + epochs.annotations.description, + epochs.annotations.onset, + epochs.annotations.duration, + ] else: annot = None eeglabio.epochs.export_set( - fname, data=epochs.get_data(picks=ch_names), - sfreq=epochs.info['sfreq'], events=epochs.events, - tmin=epochs.tmin, tmax=epochs.tmax, ch_names=ch_names, - event_id=epochs.event_id, ch_locs=cart_coords, annotations=annot) + fname, + data=epochs.get_data(picks=ch_names), + sfreq=epochs.info["sfreq"], + events=epochs.events, + tmin=epochs.tmin, + tmax=epochs.tmax, + ch_names=ch_names, + event_id=epochs.event_id, + ch_locs=cart_coords, + annotations=annot, + ) def _get_als_coords_from_chs(chs, drop_chs=None): @@ -63,8 +81,7 @@ def _get_als_coords_from_chs(chs, drop_chs=None): """ if drop_chs is None: drop_chs = [] - cart_coords = np.array([d['loc'][:3] for d in chs - if d['ch_name'] not in drop_chs]) + cart_coords = np.array([d["loc"][:3] for d in chs if d["ch_name"] not in drop_chs]) if cart_coords.any(): # has coordinates # (-y x z) to (x y z) cart_coords[:, 0] = -cart_coords[:, 0] # -y to y diff --git a/mne/export/_egimff.py b/mne/export/_egimff.py index 65418d35d6c..2fc1e66ef9e 100644 --- a/mne/export/_egimff.py +++ b/mne/export/_egimff.py @@ -15,8 +15,7 @@ @verbose -def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, - verbose=None): +def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, verbose=None): """Export evoked dataset to MFF. %(export_warning)s @@ -49,18 +48,22 @@ def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, (e.g. 
'HydroCel GSN 256 1.0'). This field is automatically populated when using MFF read functions. """ - mffpy = _import_mffpy('Export evokeds to MFF.') + mffpy = _import_mffpy("Export evokeds to MFF.") import pytz + info = evoked[0].info - if np.round(info['sfreq']) != info['sfreq']: - raise ValueError('Sampling frequency must be a whole number. ' - f'sfreq: {info["sfreq"]}') - sampling_rate = int(info['sfreq']) + if np.round(info["sfreq"]) != info["sfreq"]: + raise ValueError( + "Sampling frequency must be a whole number. " f'sfreq: {info["sfreq"]}' + ) + sampling_rate = int(info["sfreq"]) # check for unapplied projectors - if any(not proj['active'] for proj in evoked[0].info['projs']): - warn('Evoked instance has unapplied projectors. Consider applying ' - 'them before exporting with evoked.apply_proj().') + if any(not proj["active"] for proj in evoked[0].info["projs"]): + warn( + "Evoked instance has unapplied projectors. Consider applying " + "them before exporting with evoked.apply_proj()." + ) # Initialize writer # Future changes: conditions based on version or mffpy requirement if @@ -70,11 +73,11 @@ def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, os.remove(fname) if op.isfile(fname) else shutil.rmtree(fname) writer = mffpy.Writer(fname) current_time = pytz.utc.localize(datetime.datetime.utcnow()) - writer.addxml('fileInfo', recordTime=current_time) + writer.addxml("fileInfo", recordTime=current_time) try: - device = info['device_info']['type'] + device = info["device_info"]["type"] except (TypeError, KeyError): - raise ValueError('No device type. Cannot determine sensor layout.') + raise ValueError("No device type. Cannot determine sensor layout.") writer.add_coordinates_and_sensor_layout(device) # Add EEG data @@ -88,11 +91,11 @@ def export_evokeds_mff(fname, evoked, history=None, *, overwrite=False, # Add categories categories_content = _categories_content_from_evokeds(evoked) - writer.addxml('categories', categories=categories_content) + writer.addxml("categories", categories=categories_content) # Add history if history: - writer.addxml('historyEntries', entries=history) + writer.addxml("historyEntries", entries=history) writer.write() @@ -103,14 +106,20 @@ def _categories_content_from_evokeds(evoked): begin_time = 0 for ave in evoked: # Times are converted to microseconds - sfreq = ave.info['sfreq'] + sfreq = ave.info["sfreq"] duration = np.round(len(ave.times) / sfreq * 1e6).astype(int) end_time = begin_time + duration event_time = begin_time - np.round(ave.tmin * 1e6).astype(int) eeg_bads = _get_bad_eeg_channels(ave.info) content[ave.comment] = [ - _build_segment_content(begin_time, end_time, event_time, eeg_bads, - name='Average', nsegs=ave.nave) + _build_segment_content( + begin_time, + end_time, + event_time, + eeg_bads, + name="Average", + nsegs=ave.nave, + ) ] begin_time += duration return content @@ -122,49 +131,47 @@ def _get_bad_eeg_channels(info): Given a list of only the EEG channels in file, return the indices of this list (starting at 1) that correspond to bad channels. 
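# [Editor's aside, not part of the patch] The 1-based indexing computed by
# _get_bad_eeg_channels() below, sketched with plain index arrays standing
# in for the pick_types() / pick_channels() results.
import numpy as np

eeg_channels = np.array([0, 1, 4, 5])  # positions of EEG channels in info
bad_channels = np.array([1, 5])        # positions of channels in info['bads']
bads_elementwise = np.isin(eeg_channels, bad_channels)
print(list(np.flatnonzero(bads_elementwise) + 1))  # [2, 4]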
""" - if len(info['bads']) == 0: + if len(info["bads"]) == 0: return [] eeg_channels = pick_types(info, eeg=True, exclude=[]) - bad_channels = pick_channels(info['ch_names'], info['bads']) + bad_channels = pick_channels(info["ch_names"], info["bads"]) bads_elementwise = np.isin(eeg_channels, bad_channels) return list(np.flatnonzero(bads_elementwise) + 1) -def _build_segment_content(begin_time, end_time, event_time, eeg_bads, - status='unedited', name=None, pns_bads=None, - nsegs=None): +def _build_segment_content( + begin_time, + end_time, + event_time, + eeg_bads, + status="unedited", + name=None, + pns_bads=None, + nsegs=None, +): """Build content for a single segment in categories.xml. Segments are sorted into categories in categories.xml. In a segmented MFF each category can contain multiple segments, but in an averaged MFF each category only contains one segment (the average). """ - channel_status = [{ - 'signalBin': 1, - 'exclusion': 'badChannels', - 'channels': eeg_bads - }] + channel_status = [ + {"signalBin": 1, "exclusion": "badChannels", "channels": eeg_bads} + ] if pns_bads: - channel_status.append({ - 'signalBin': 2, - 'exclusion': 'badChannels', - 'channels': pns_bads - }) + channel_status.append( + {"signalBin": 2, "exclusion": "badChannels", "channels": pns_bads} + ) content = { - 'status': status, - 'beginTime': begin_time, - 'endTime': end_time, - 'evtBegin': event_time, - 'evtEnd': event_time, - 'channelStatus': channel_status, + "status": status, + "beginTime": begin_time, + "endTime": end_time, + "evtBegin": event_time, + "evtEnd": event_time, + "channelStatus": channel_status, } if name: - content['name'] = name + content["name"] = name if nsegs: - content['keys'] = { - '#seg': { - 'type': 'long', - 'data': nsegs - } - } + content["keys"] = {"#seg": {"type": "long", "data": nsegs}} return content diff --git a/mne/export/_export.py b/mne/export/_export.py index c26927d1755..5afa420540c 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -9,8 +9,16 @@ @verbose -def export_raw(fname, raw, fmt='auto', physical_range='auto', - add_ch_type=False, *, overwrite=False, verbose=None): +def export_raw( + fname, + raw, + fmt="auto", + physical_range="auto", + add_ch_type=False, + *, + overwrite=False, + verbose=None, +): """Export Raw to external formats. %(export_fmt_support_raw)s @@ -40,30 +48,39 @@ def export_raw(fname, raw, fmt='auto', physical_range='auto', """ fname = str(_check_fname(fname, overwrite=overwrite)) supported_export_formats = { # format : (extensions,) - 'eeglab': ('set',), - 'edf': ('edf',), - 'brainvision': ('eeg', 'vmrk', 'vhdr',) + "eeglab": ("set",), + "edf": ("edf",), + "brainvision": ( + "eeg", + "vmrk", + "vhdr", + ), } fmt = _infer_check_export_fmt(fmt, fname, supported_export_formats) # check for unapplied projectors - if any(not proj['active'] for proj in raw.info['projs']): - warn('Raw instance has unapplied projectors. Consider applying ' - 'them before exporting with raw.apply_proj().') + if any(not proj["active"] for proj in raw.info["projs"]): + warn( + "Raw instance has unapplied projectors. Consider applying " + "them before exporting with raw.apply_proj()." 
+ ) - if fmt == 'eeglab': + if fmt == "eeglab": from ._eeglab import _export_raw + _export_raw(fname, raw) - elif fmt == 'edf': + elif fmt == "edf": from ._edf import _export_raw + _export_raw(fname, raw, physical_range, add_ch_type) - elif fmt == 'brainvision': + elif fmt == "brainvision": from ._brainvision import _export_raw + _export_raw(fname, raw, overwrite) @verbose -def export_epochs(fname, epochs, fmt='auto', *, overwrite=False, verbose=None): +def export_epochs(fname, epochs, fmt="auto", *, overwrite=False, verbose=None): """Export Epochs to external formats. %(export_fmt_support_epochs)s @@ -90,23 +107,25 @@ def export_epochs(fname, epochs, fmt='auto', *, overwrite=False, verbose=None): """ fname = str(_check_fname(fname, overwrite=overwrite)) supported_export_formats = { - 'eeglab': ('set',), + "eeglab": ("set",), } fmt = _infer_check_export_fmt(fmt, fname, supported_export_formats) # check for unapplied projectors - if any(not proj['active'] for proj in epochs.info['projs']): - warn('Epochs instance has unapplied projectors. Consider applying ' - 'them before exporting with epochs.apply_proj().') + if any(not proj["active"] for proj in epochs.info["projs"]): + warn( + "Epochs instance has unapplied projectors. Consider applying " + "them before exporting with epochs.apply_proj()." + ) - if fmt == 'eeglab': + if fmt == "eeglab": from ._eeglab import _export_epochs + _export_epochs(fname, epochs) @verbose -def export_evokeds(fname, evoked, fmt='auto', *, overwrite=False, - verbose=None): +def export_evokeds(fname, evoked, fmt="auto", *, overwrite=False, verbose=None): """Export evoked dataset to external formats. This function is a wrapper for format-specific export functions. The export @@ -143,16 +162,16 @@ def export_evokeds(fname, evoked, fmt='auto', *, overwrite=False, """ fname = str(_check_fname(fname, overwrite=overwrite)) supported_export_formats = { - 'mff': ('mff',), + "mff": ("mff",), } fmt = _infer_check_export_fmt(fmt, fname, supported_export_formats) if not isinstance(evoked, list): evoked = [evoked] - logger.info(f'Exporting evoked dataset to {fname}...') + logger.info(f"Exporting evoked dataset to {fname}...") - if fmt == 'mff': + if fmt == "mff": export_evokeds_mff(fname, evoked, overwrite=overwrite) @@ -174,26 +193,30 @@ def _infer_check_export_fmt(fmt, fname, supported_formats): Dictionary containing supported formats (as keys) and each format's corresponding file extensions in a tuple (e.g., {'eeglab': ('set',)}) """ - _validate_type(fmt, str, 'fmt') + _validate_type(fmt, str, "fmt") fmt = fmt.lower() if fmt == "auto": fmt = op.splitext(fname)[1] if fmt: fmt = fmt[1:].lower() # find fmt in supported formats dict's tuples - fmt = next((k for k, v in supported_formats.items() if fmt in v), - fmt) # default to original fmt for raising error later + fmt = next( + (k for k, v in supported_formats.items() if fmt in v), fmt + ) # default to original fmt for raising error later else: - raise ValueError(f"Couldn't infer format from filename {fname}" - " (no extension found)") + raise ValueError( + f"Couldn't infer format from filename {fname}" " (no extension found)" + ) if fmt not in supported_formats: supported = [] for format, extensions in supported_formats.items(): - ext_str = ', '.join(f'*.{ext}' for ext in extensions) - supported.append(f'{format} ({ext_str})') - - supported_str = ', '.join(supported) - raise ValueError(f"Format '{fmt}' is not supported. 
" - f"Supported formats are {supported_str}.") + ext_str = ", ".join(f"*.{ext}" for ext in extensions) + supported.append(f"{format} ({ext_str})") + + supported_str = ", ".join(supported) + raise ValueError( + f"Format '{fmt}' is not supported. " + f"Supported formats are {supported_str}." + ) return fmt diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index 27e29ab343f..4aeada34543 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -10,39 +10,49 @@ import pytest import numpy as np -from numpy.testing import (assert_allclose, assert_array_almost_equal, - assert_array_equal) +from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal -from mne import (read_epochs_eeglab, Epochs, read_evokeds, read_evokeds_mff, - Annotations) +from mne import read_epochs_eeglab, Epochs, read_evokeds, read_evokeds_mff, Annotations from mne.datasets import testing, misc from mne.export import export_evokeds, export_evokeds_mff from mne.fixes import _compare_version -from mne.io import (RawArray, read_raw_fif, read_raw_eeglab, read_raw_edf, - read_raw_brainvision) +from mne.io import ( + RawArray, + read_raw_fif, + read_raw_eeglab, + read_raw_edf, + read_raw_brainvision, +) from mne.io.meas_info import create_info -from mne.utils import (_check_eeglabio_installed, requires_version, - object_diff, _check_edflib_installed, _resource_path, - _check_pybv_installed, _record_warnings) +from mne.utils import ( + _check_eeglabio_installed, + requires_version, + object_diff, + _check_edflib_installed, + _resource_path, + _check_pybv_installed, + _record_warnings, +) from mne.tests.test_epochs import _get_data -fname_evoked = _resource_path('mne.io.tests.data', 'test-ave.fif') -fname_raw = _resource_path('mne.io.tests.data', 'test_raw.fif') +fname_evoked = _resource_path("mne.io.tests.data", "test-ave.fif") +fname_raw = _resource_path("mne.io.tests.data", "test_raw.fif") data_path = testing.data_path(download=False) egi_evoked_fname = data_path / "EGI" / "test_egi_evoked.mff" misc_path = misc.data_path(download=False) -@pytest.mark.skipif(not _check_pybv_installed(strict=False), - reason='pybv not installed') +@pytest.mark.skipif( + not _check_pybv_installed(strict=False), reason="pybv not installed" +) @pytest.mark.parametrize( - ['meas_date', 'orig_time', 'ext'], [ - [None, None, '.vhdr'], - [datetime(2022, 12, 3, 19, 1, 10, 720100, tzinfo=timezone.utc), - None, - '.eeg'], - ]) + ["meas_date", "orig_time", "ext"], + [ + [None, None, ".vhdr"], + [datetime(2022, 12, 3, 19, 1, 10, 720100, tzinfo=timezone.utc), None, ".eeg"], + ], +) def test_export_raw_pybv(tmp_path, meas_date, orig_time, ext): """Test saving a Raw instance to BrainVision format via pybv.""" raw = read_raw_fif(fname_raw, preload=True) @@ -66,39 +76,39 @@ def test_export_raw_pybv(tmp_path, meas_date, orig_time, ext): ) raw.set_annotations(annots) - temp_fname = tmp_path / ('test' + ext) + temp_fname = tmp_path / ("test" + ext) with pytest.warns(RuntimeWarning, match="'short' format. 
Converting"): raw.export(temp_fname) - raw_read = read_raw_brainvision(str(temp_fname).replace('.eeg', '.vhdr')) + raw_read = read_raw_brainvision(str(temp_fname).replace(".eeg", ".vhdr")) assert raw.ch_names == raw_read.ch_names assert_allclose(raw.times, raw_read.times) assert_allclose(raw.get_data(), raw_read.get_data()) -@requires_version('pymatreader') -@pytest.mark.skipif(not _check_eeglabio_installed(strict=False), - reason='eeglabio not installed') +@requires_version("pymatreader") +@pytest.mark.skipif( + not _check_eeglabio_installed(strict=False), reason="eeglabio not installed" +) def test_export_raw_eeglab(tmp_path): """Test saving a Raw instance to EEGLAB's set format.""" raw = read_raw_fif(fname_raw, preload=True) raw.apply_proj() temp_fname = tmp_path / "test.set" raw.export(temp_fname) - raw.drop_channels([ch for ch in ['epoc'] - if ch in raw.ch_names]) + raw.drop_channels([ch for ch in ["epoc"] if ch in raw.ch_names]) - with pytest.warns(RuntimeWarning, match='is above the 99th percentile'): - raw_read = read_raw_eeglab(temp_fname, preload=True, montage_units='m') + with pytest.warns(RuntimeWarning, match="is above the 99th percentile"): + raw_read = read_raw_eeglab(temp_fname, preload=True, montage_units="m") assert raw.ch_names == raw_read.ch_names - cart_coords = np.array([d['loc'][:3] for d in raw.info['chs']]) # just xyz - cart_coords_read = np.array([d['loc'][:3] for d in raw_read.info['chs']]) + cart_coords = np.array([d["loc"][:3] for d in raw.info["chs"]]) # just xyz + cart_coords_read = np.array([d["loc"][:3] for d in raw_read.info["chs"]]) assert_allclose(cart_coords, cart_coords_read) assert_allclose(raw.times, raw_read.times) assert_allclose(raw.get_data(), raw_read.get_data()) # test overwrite - with pytest.raises(FileExistsError, match='Destination file exists'): + with pytest.raises(FileExistsError, match="Destination file exists"): raw.export(temp_fname, overwrite=False) raw.export(temp_fname, overwrite=True) @@ -107,29 +117,41 @@ def test_export_raw_eeglab(tmp_path): # test warning with unapplied projectors raw = read_raw_fif(fname_raw, preload=True) - with pytest.warns(RuntimeWarning, - match='Raw instance has unapplied projectors.'): + with pytest.warns(RuntimeWarning, match="Raw instance has unapplied projectors."): raw.export(temp_fname, overwrite=True) -@pytest.mark.skipif(not _check_edflib_installed(strict=False), - reason='edflib-python not installed') +@pytest.mark.skipif( + not _check_edflib_installed(strict=False), reason="edflib-python not installed" +) def test_double_export_edf(tmp_path): """Test exporting an EDF file multiple times.""" rng = np.random.RandomState(123456) - format = 'edf' - ch_types = ['eeg', 'eeg', 'stim', 'ecog', 'ecog', 'seeg', 'eog', 'ecg', - 'emg', 'dbs', 'bio'] + format = "edf" + ch_types = [ + "eeg", + "eeg", + "stim", + "ecog", + "ecog", + "seeg", + "eog", + "ecg", + "emg", + "dbs", + "bio", + ] info = create_info(len(ch_types), sfreq=1000, ch_types=ch_types) data = rng.random(size=(len(ch_types), 1000)) * 1e-5 # include subject info and measurement date - info['subject_info'] = dict(first_name='mne', last_name='python', - birthday=(1992, 1, 20), sex=1, hand=3) + info["subject_info"] = dict( + first_name="mne", last_name="python", birthday=(1992, 1, 20), sex=1, hand=3 + ) raw = RawArray(data, info) # export once - temp_fname = tmp_path / f'test.{format}' + temp_fname = tmp_path / f"test.{format}" raw.export(temp_fname, add_ch_type=True) raw_read = read_raw_edf(temp_fname, infer_types=True, preload=True) @@ -139,15 
+161,15 @@ def test_double_export_edf(tmp_path): raw_read = read_raw_edf(temp_fname, infer_types=True, preload=True) # stim channel should be dropped - raw.drop_channels('2') + raw.drop_channels("2") assert raw.ch_names == raw_read.ch_names # only compare the original length, since extra zeros are appended orig_raw_len = len(raw) assert_array_almost_equal( - raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4) - assert_allclose( - raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4 + ) + assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) # check channel types except for 'bio', which loses its type orig_ch_types = raw.get_channel_types() @@ -155,28 +177,41 @@ def test_double_export_edf(tmp_path): assert_array_equal(orig_ch_types, read_ch_types) # check handling of missing subject metadata - del info['subject_info']['sex'] + del info["subject_info"]["sex"] raw_2 = RawArray(data, info) raw_2.export(temp_fname, add_ch_type=True, overwrite=True) -@pytest.mark.skipif(not _check_edflib_installed(strict=False), - reason='edflib-python not installed') +@pytest.mark.skipif( + not _check_edflib_installed(strict=False), reason="edflib-python not installed" +) def test_export_edf_annotations(tmp_path): """Test that exporting EDF preserves annotations.""" rng = np.random.RandomState(123456) - format = 'edf' - ch_types = ['eeg', 'eeg', 'stim', 'ecog', 'ecog', 'seeg', - 'eog', 'ecg', 'emg', 'dbs', 'bio'] + format = "edf" + ch_types = [ + "eeg", + "eeg", + "stim", + "ecog", + "ecog", + "seeg", + "eog", + "ecg", + "emg", + "dbs", + "bio", + ] ch_names = np.arange(len(ch_types)).astype(str).tolist() - info = create_info(ch_names, sfreq=1000, - ch_types=ch_types) - data = rng.random(size=(len(ch_names), 2000)) * 1.e-5 + info = create_info(ch_names, sfreq=1000, ch_types=ch_types) + data = rng.random(size=(len(ch_names), 2000)) * 1.0e-5 raw = RawArray(data, info) annotations = Annotations( - onset=[0.01, 0.05, 0.90, 1.05], duration=[0, 1, 0, 0], - description=['test1', 'test2', 'test3', 'test4']) + onset=[0.01, 0.05, 0.90, 1.05], + duration=[0, 1, 0, 0], + description=["test1", "test2", "test3", "test4"], + ) raw.set_annotations(annotations) # export @@ -187,33 +222,37 @@ def test_export_edf_annotations(tmp_path): raw_read = read_raw_edf(temp_fname, preload=True) assert_array_equal(raw.annotations.onset, raw_read.annotations.onset) assert_array_equal(raw.annotations.duration, raw_read.annotations.duration) - assert_array_equal(raw.annotations.description, - raw_read.annotations.description) + assert_array_equal(raw.annotations.description, raw_read.annotations.description) -@pytest.mark.skipif(not _check_edflib_installed(strict=False), - reason='edflib-python not installed') +@pytest.mark.skipif( + not _check_edflib_installed(strict=False), reason="edflib-python not installed" +) def test_rawarray_edf(tmp_path): """Test saving a Raw array with integer sfreq to EDF.""" rng = np.random.RandomState(12345) - format = 'edf' - ch_types = ['eeg', 'eeg', 'stim', 'ecog', 'seeg', 'eog', 'ecg', 'emg', - 'dbs', 'bio'] + format = "edf" + ch_types = ["eeg", "eeg", "stim", "ecog", "seeg", "eog", "ecg", "emg", "dbs", "bio"] ch_names = np.arange(len(ch_types)).astype(str).tolist() - info = create_info(ch_names, sfreq=1000, - ch_types=ch_types) + info = create_info(ch_names, sfreq=1000, ch_types=ch_types) data = rng.random(size=(len(ch_names), 1000)) * 1e-5 # include subject info and measurement date - subject_info 
= dict(first_name='mne', last_name='python', - birthday=(1992, 1, 20), sex=1, hand=3) - info['subject_info'] = subject_info + subject_info = dict( + first_name="mne", last_name="python", birthday=(1992, 1, 20), sex=1, hand=3 + ) + info["subject_info"] = subject_info raw = RawArray(data, info) time_now = datetime.now() - meas_date = datetime(year=time_now.year, month=time_now.month, - day=time_now.day, hour=time_now.hour, - minute=time_now.minute, second=time_now.second, - tzinfo=timezone.utc) + meas_date = datetime( + year=time_now.year, + month=time_now.month, + day=time_now.day, + hour=time_now.hour, + minute=time_now.minute, + second=time_now.second, + tzinfo=timezone.utc, + ) raw.set_meas_date(meas_date) temp_fname = tmp_path / f"test.{format}" @@ -221,82 +260,84 @@ def test_rawarray_edf(tmp_path): raw_read = read_raw_edf(temp_fname, infer_types=True, preload=True) # stim channel should be dropped - raw.drop_channels('2') + raw.drop_channels("2") assert raw.ch_names == raw_read.ch_names # only compare the original length, since extra zeros are appended orig_raw_len = len(raw) assert_array_almost_equal( - raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4) - assert_allclose( - raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4 + ) + assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) # check channel types except for 'bio', which loses its type orig_ch_types = raw.get_channel_types() read_ch_types = raw_read.get_channel_types() assert_array_equal(orig_ch_types, read_ch_types) - assert raw.info['meas_date'] == raw_read.info['meas_date'] + assert raw.info["meas_date"] == raw_read.info["meas_date"] # channel name can't be longer than 16 characters with the type added raw_bad = raw.copy() - raw_bad.rename_channels({'1': 'abcdefghijklmnopqrstuvwxyz'}) - with pytest.raises(RuntimeError, match='Signal label'), \ - pytest.warns(RuntimeWarning, match='Data has a non-integer'): + raw_bad.rename_channels({"1": "abcdefghijklmnopqrstuvwxyz"}) + with pytest.raises(RuntimeError, match="Signal label"), pytest.warns( + RuntimeWarning, match="Data has a non-integer" + ): raw_bad.export(temp_fname, overwrite=True) # include bad birthday that is non-EDF compliant bad_info = info.copy() - bad_info['subject_info']['birthday'] = (1700, 1, 20) + bad_info["subject_info"]["birthday"] = (1700, 1, 20) raw = RawArray(data, bad_info) - with pytest.raises(RuntimeError, match='Setting patient birth date'): + with pytest.raises(RuntimeError, match="Setting patient birth date"): raw.export(temp_fname, overwrite=True) # include bad measurement date that is non-EDF compliant raw = RawArray(data, info) meas_date = datetime(year=1984, month=1, day=1, tzinfo=timezone.utc) raw.set_meas_date(meas_date) - with pytest.raises(RuntimeError, match='Setting start date time'): + with pytest.raises(RuntimeError, match="Setting start date time"): raw.export(temp_fname, overwrite=True) # test that warning is raised if there are non-voltage based channels raw = RawArray(data, info) - raw.set_channel_types({'9': 'hbr'}, on_unit_change='ignore') - with pytest.warns(RuntimeWarning, match='Non-voltage channels'): + raw.set_channel_types({"9": "hbr"}, on_unit_change="ignore") + with pytest.warns(RuntimeWarning, match="Non-voltage channels"): raw.export(temp_fname, overwrite=True) # data should match up to the non-accepted channel raw_read = read_raw_edf(temp_fname, preload=True) orig_raw_len = len(raw) assert_array_almost_equal( 
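The EDF round trips in these tests all follow one pattern: build a Raw, export it, read it back, and compare up to the 16-bit quantization and the zero padding EDF appends to fill the last data record. A minimal sketch under the same assumptions as the tests (edflib-python installed; the file name is illustrative):

    import numpy as np
    from mne import create_info
    from mne.io import RawArray, read_raw_edf

    # two EEG channels of low-amplitude noise, as in the tests
    info = create_info(["0", "1"], sfreq=1000, ch_types="eeg")
    data = np.random.RandomState(0).random((2, 1000)) * 1e-5
    raw = RawArray(data, info)

    raw.export("example.edf", add_ch_type=True)  # requires edflib-python
    raw_read = read_raw_edf("example.edf", infer_types=True, preload=True)

    # compare only the original length; EDF zero-pads the last record
    n = len(raw.times)
    np.testing.assert_array_almost_equal(
        raw.get_data(), raw_read.get_data()[:, :n], decimal=4
    )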
- raw.get_data()[:-1, :], raw_read.get_data()[:, :orig_raw_len], - decimal=4) - assert_allclose( - raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + raw.get_data()[:-1, :], raw_read.get_data()[:, :orig_raw_len], decimal=4 + ) + assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) # the data should still match though raw_read = read_raw_edf(temp_fname, preload=True) - raw.drop_channels('2') + raw.drop_channels("2") assert raw.ch_names == raw_read.ch_names orig_raw_len = len(raw) assert_array_almost_equal( - raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4) - assert_allclose( - raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4 + ) + assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) -@pytest.mark.skipif(not _check_edflib_installed(strict=False), - reason='edflib-python not installed') +@pytest.mark.skipif( + not _check_edflib_installed(strict=False), reason="edflib-python not installed" +) @pytest.mark.parametrize( - ['dataset', 'format'], [ - ['test', 'edf'], - pytest.param('misc', 'edf', marks=[pytest.mark.slowtest, - misc._pytest_mark()]), - ]) + ["dataset", "format"], + [ + ["test", "edf"], + pytest.param("misc", "edf", marks=[pytest.mark.slowtest, misc._pytest_mark()]), + ], +) def test_export_raw_edf(tmp_path, dataset, format): """Test saving a Raw instance to EDF format.""" - if dataset == 'test': + if dataset == "test": raw = read_raw_fif(fname_raw) - elif dataset == 'misc': + elif dataset == "misc": fname = misc_path / "ecog" / "sample_ecog_ieeg.fif" raw = read_raw_fif(fname) @@ -309,31 +350,27 @@ def test_export_raw_edf(tmp_path, dataset, format): # test runtime errors with pytest.warns() as record: raw.export(temp_fname, physical_range=(-1e6, 0)) - if dataset == 'test': - assert any( - "Data has a non-integer" in str(rec.message) for rec in record - ) + if dataset == "test": + assert any("Data has a non-integer" in str(rec.message) for rec in record) assert any("The maximum" in str(rec.message) for rec in record) remove(temp_fname) with pytest.warns() as record: raw.export(temp_fname, physical_range=(0, 1e6)) - if dataset == 'test': - assert any( - "Data has a non-integer" in str(rec.message) for rec in record - ) + if dataset == "test": + assert any("Data has a non-integer" in str(rec.message) for rec in record) assert any("The minimum" in str(rec.message) for rec in record) remove(temp_fname) - if dataset == 'test': - with pytest.warns(RuntimeWarning, match='Data has a non-integer'): + if dataset == "test": + with pytest.warns(RuntimeWarning, match="Data has a non-integer"): raw.export(temp_fname) - elif dataset == 'misc': - with pytest.warns(RuntimeWarning, match='EDF format requires'): + elif dataset == "misc": + with pytest.warns(RuntimeWarning, match="EDF format requires"): raw.export(temp_fname) - if 'epoc' in raw.ch_names: - raw.drop_channels(['epoc']) + if "epoc" in raw.ch_names: + raw.drop_channels(["epoc"]) raw_read = read_raw_edf(temp_fname, preload=True) assert orig_ch_names == raw_read.ch_names @@ -346,7 +383,8 @@ def test_export_raw_edf(tmp_path, dataset, format): # will result in a resolution of 0.09 uV. This resolution # though is acceptable for most EEG manufacturers. 
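The 0.09 uV figure in the comment above falls directly out of the 16-bit sample width: the amplitude resolution is the physical span divided by the digital span. A quick check with a hypothetical physical range:

    # EDF stores samples as 16-bit integers
    phys_min, phys_max = -3000.0, 3000.0     # uV, illustrative physical_range
    dig_span = 2**16 - 1                     # 65535 digital steps
    print((phys_max - phys_min) / dig_span)  # ~0.0916 uV per step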
assert_array_almost_equal( - raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4) + raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=4 + ) # Due to the data record duration limitations of EDF files, one # cannot store arbitrary float sampling rate exactly. Usually this @@ -354,46 +392,43 @@ def test_export_raw_edf(tmp_path, dataset, format): # decimal points. This for practical purposes does not matter # but will result in an error when say the number of time points # is very very large. - assert_allclose( - raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) + assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) -@pytest.mark.xfail(reason='eeglabio (usage?) bugs that should be fixed') -@requires_version('pymatreader') -@pytest.mark.skipif(not _check_eeglabio_installed(strict=False), - reason='eeglabio not installed') -@pytest.mark.parametrize('preload', (True, False)) +@pytest.mark.xfail(reason="eeglabio (usage?) bugs that should be fixed") +@requires_version("pymatreader") +@pytest.mark.skipif( + not _check_eeglabio_installed(strict=False), reason="eeglabio not installed" +) +@pytest.mark.parametrize("preload", (True, False)) def test_export_epochs_eeglab(tmp_path, preload): """Test saving an Epochs instance to EEGLAB's set format.""" import eeglabio + raw, events = _get_data()[:2] raw.load_data() epochs = Epochs(raw, events, preload=preload) temp_fname = tmp_path / "test.set" # TODO: eeglabio 0.2 warns about invalid events - if _compare_version(eeglabio.__version__, '==', '0.0.2-1'): + if _compare_version(eeglabio.__version__, "==", "0.0.2-1"): ctx = _record_warnings else: ctx = nullcontext with ctx(): epochs.export(temp_fname) - epochs.drop_channels([ch for ch in ['epoc', 'STI 014'] - if ch in epochs.ch_names]) + epochs.drop_channels([ch for ch in ["epoc", "STI 014"] if ch in epochs.ch_names]) epochs_read = read_epochs_eeglab(temp_fname) assert epochs.ch_names == epochs_read.ch_names - cart_coords = np.array([d['loc'][:3] - for d in epochs.info['chs']]) # just xyz - cart_coords_read = np.array([d['loc'][:3] - for d in epochs_read.info['chs']]) + cart_coords = np.array([d["loc"][:3] for d in epochs.info["chs"]]) # just xyz + cart_coords_read = np.array([d["loc"][:3] for d in epochs_read.info["chs"]]) assert_allclose(cart_coords, cart_coords_read) - assert_array_equal(epochs.events[:, 0], - epochs_read.events[:, 0]) # latency + assert_array_equal(epochs.events[:, 0], epochs_read.events[:, 0]) # latency assert epochs.event_id.keys() == epochs_read.event_id.keys() # just keys assert_allclose(epochs.times, epochs_read.times) assert_allclose(epochs.get_data(), epochs_read.get_data()) # test overwrite - with pytest.raises(FileExistsError, match='Destination file exists'): + with pytest.raises(FileExistsError, match="Destination file exists"): epochs.export(temp_fname, overwrite=False) with ctx(): epochs.export(temp_fname, overwrite=True) @@ -404,45 +439,46 @@ def test_export_epochs_eeglab(tmp_path, preload): # test warning with unapplied projectors epochs = Epochs(raw, events, preload=preload, proj=False) - with pytest.warns(RuntimeWarning, - match='Epochs instance has unapplied projectors.'): + with pytest.warns( + RuntimeWarning, match="Epochs instance has unapplied projectors." 
+ ): epochs.export(Path(temp_fname), overwrite=True) -@pytest.mark.filterwarnings('ignore::FutureWarning') -@requires_version('mffpy', '0.5.7') +@pytest.mark.filterwarnings("ignore::FutureWarning") +@requires_version("mffpy", "0.5.7") @testing.requires_testing_data -@pytest.mark.parametrize('fmt', ('auto', 'mff')) -@pytest.mark.parametrize('do_history', (True, False)) +@pytest.mark.parametrize("fmt", ("auto", "mff")) +@pytest.mark.parametrize("do_history", (True, False)) def test_export_evokeds_to_mff(tmp_path, fmt, do_history): """Test exporting evoked dataset to MFF.""" evoked = read_evokeds_mff(egi_evoked_fname) export_fname = tmp_path / "evoked.mff" history = [ { - 'name': 'Test Segmentation', - 'method': 'Segmentation', - 'settings': ['Setting 1', 'Setting 2'], - 'results': ['Result 1', 'Result 2'] + "name": "Test Segmentation", + "method": "Segmentation", + "settings": ["Setting 1", "Setting 2"], + "results": ["Result 1", "Result 2"], }, { - 'name': 'Test Averaging', - 'method': 'Averaging', - 'settings': ['Setting 1', 'Setting 2'], - 'results': ['Result 1', 'Result 2'] - } + "name": "Test Averaging", + "method": "Averaging", + "settings": ["Setting 1", "Setting 2"], + "results": ["Result 1", "Result 2"], + }, ] if do_history: export_evokeds_mff(export_fname, evoked, history=history) else: export_evokeds(export_fname, evoked, fmt=fmt) # Drop non-EEG channels - evoked = [ave.drop_channels(['ECG', 'EMG']) for ave in evoked] + evoked = [ave.drop_channels(["ECG", "EMG"]) for ave in evoked] evoked_exported = read_evokeds_mff(export_fname) assert len(evoked) == len(evoked_exported) for ave, ave_exported in zip(evoked, evoked_exported): # Compare infos - assert object_diff(ave_exported.info, ave.info) == '' + assert object_diff(ave_exported.info, ave.info) == "" # Compare data assert_allclose(ave_exported.data, ave.data) # Compare properties @@ -452,16 +488,14 @@ def test_export_evokeds_to_mff(tmp_path, fmt, do_history): assert_allclose(ave_exported.times, ave.times) # test overwrite - with pytest.raises(FileExistsError, match='Destination file exists'): + with pytest.raises(FileExistsError, match="Destination file exists"): if do_history: - export_evokeds_mff(export_fname, evoked, history=history, - overwrite=False) + export_evokeds_mff(export_fname, evoked, history=history, overwrite=False) else: export_evokeds(export_fname, evoked, overwrite=False) if do_history: - export_evokeds_mff(export_fname, evoked, history=history, - overwrite=True) + export_evokeds_mff(export_fname, evoked, history=history, overwrite=True) else: export_evokeds(export_fname, evoked, overwrite=True) @@ -469,35 +503,33 @@ def test_export_evokeds_to_mff(tmp_path, fmt, do_history): evoked[0].export(export_fname, overwrite=True) -@pytest.mark.filterwarnings('ignore::FutureWarning') -@requires_version('mffpy', '0.5.7') +@pytest.mark.filterwarnings("ignore::FutureWarning") +@requires_version("mffpy", "0.5.7") @testing.requires_testing_data def test_export_to_mff_no_device(): """Test no device type throws ValueError.""" - evoked = read_evokeds_mff(egi_evoked_fname, condition='Category 1') - evoked.info['device_info'] = None - with pytest.raises(ValueError, match='No device type.'): - export_evokeds('output.mff', evoked) + evoked = read_evokeds_mff(egi_evoked_fname, condition="Category 1") + evoked.info["device_info"] = None + with pytest.raises(ValueError, match="No device type."): + export_evokeds("output.mff", evoked) -@pytest.mark.filterwarnings('ignore::FutureWarning') -@requires_version('mffpy', '0.5.7') 
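For reference, the MFF export path these tests exercise takes an optional list of free-form history records; a minimal sketch (mffpy installed; paths and record contents are illustrative):

    from mne import read_evokeds_mff
    from mne.export import export_evokeds_mff

    evoked = read_evokeds_mff("input.mff")  # hypothetical MFF average file
    history = [
        {
            "name": "Averaging",
            "method": "Averaging",
            "settings": ["Setting 1"],
            "results": ["Result 1"],
        },
    ]
    export_evokeds_mff("evoked_out.mff", evoked, history=history)

As the neighboring tests check, the export raises a ValueError when info["device_info"] is missing or the sampling frequency is not a whole number.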
+@pytest.mark.filterwarnings("ignore::FutureWarning") +@requires_version("mffpy", "0.5.7") def test_export_to_mff_incompatible_sfreq(): """Test non-whole number sampling frequency throws ValueError.""" evoked = read_evokeds(fname_evoked) with pytest.raises(ValueError, match=f'sfreq: {evoked[0].info["sfreq"]}'): - export_evokeds('output.mff', evoked) + export_evokeds("output.mff", evoked) -@pytest.mark.parametrize('fmt,ext', [ - ('EEGLAB', 'set'), - ('EDF', 'edf'), - ('BrainVision', 'vhdr'), - ('auto', 'vhdr') -]) +@pytest.mark.parametrize( + "fmt,ext", + [("EEGLAB", "set"), ("EDF", "edf"), ("BrainVision", "vhdr"), ("auto", "vhdr")], +) def test_export_evokeds_unsupported_format(fmt, ext): """Test exporting evoked dataset to non-supported formats.""" evoked = read_evokeds(fname_evoked) errstr = fmt.lower() if fmt != "auto" else "vhdr" with pytest.raises(ValueError, match=f"Format '{errstr}' is not .*"): - export_evokeds(f'output.{ext}', evoked, fmt=fmt) + export_evokeds(f"output.{ext}", evoked, fmt=fmt) diff --git a/mne/filter.py b/mne/filter.py index 2fc0d10b2c4..5277a1fd502 100644 --- a/mne/filter.py +++ b/mne/filter.py @@ -8,11 +8,25 @@ from .annotations import _annotations_starts_stops from .io.pick import _picks_to_idx -from .cuda import (_setup_cuda_fft_multiply_repeated, _fft_multiply_repeated, - _setup_cuda_fft_resample, _fft_resample, _smart_pad) +from .cuda import ( + _setup_cuda_fft_multiply_repeated, + _fft_multiply_repeated, + _setup_cuda_fft_resample, + _fft_resample, + _smart_pad, +) from .parallel import parallel_func -from .utils import (logger, verbose, sum_squared, warn, _pl, - _check_preload, _validate_type, _check_option, _ensure_int) +from .utils import ( + logger, + verbose, + sum_squared, + warn, + _pl, + _check_preload, + _validate_type, + _check_option, + _ensure_int, +) from ._ola import _COLA # These values from Ifeachor and Jervis. @@ -66,20 +80,178 @@ def next_fast_len(target): Copied from SciPy with minor modifications. 
""" from bisect import bisect_left - hams = (8, 9, 10, 12, 15, 16, 18, 20, 24, 25, 27, 30, 32, 36, 40, 45, 48, - 50, 54, 60, 64, 72, 75, 80, 81, 90, 96, 100, 108, 120, 125, 128, - 135, 144, 150, 160, 162, 180, 192, 200, 216, 225, 240, 243, 250, - 256, 270, 288, 300, 320, 324, 360, 375, 384, 400, 405, 432, 450, - 480, 486, 500, 512, 540, 576, 600, 625, 640, 648, 675, 720, 729, - 750, 768, 800, 810, 864, 900, 960, 972, 1000, 1024, 1080, 1125, - 1152, 1200, 1215, 1250, 1280, 1296, 1350, 1440, 1458, 1500, 1536, - 1600, 1620, 1728, 1800, 1875, 1920, 1944, 2000, 2025, 2048, 2160, - 2187, 2250, 2304, 2400, 2430, 2500, 2560, 2592, 2700, 2880, 2916, - 3000, 3072, 3125, 3200, 3240, 3375, 3456, 3600, 3645, 3750, 3840, - 3888, 4000, 4050, 4096, 4320, 4374, 4500, 4608, 4800, 4860, 5000, - 5120, 5184, 5400, 5625, 5760, 5832, 6000, 6075, 6144, 6250, 6400, - 6480, 6561, 6750, 6912, 7200, 7290, 7500, 7680, 7776, 8000, 8100, - 8192, 8640, 8748, 9000, 9216, 9375, 9600, 9720, 10000) + + hams = ( + 8, + 9, + 10, + 12, + 15, + 16, + 18, + 20, + 24, + 25, + 27, + 30, + 32, + 36, + 40, + 45, + 48, + 50, + 54, + 60, + 64, + 72, + 75, + 80, + 81, + 90, + 96, + 100, + 108, + 120, + 125, + 128, + 135, + 144, + 150, + 160, + 162, + 180, + 192, + 200, + 216, + 225, + 240, + 243, + 250, + 256, + 270, + 288, + 300, + 320, + 324, + 360, + 375, + 384, + 400, + 405, + 432, + 450, + 480, + 486, + 500, + 512, + 540, + 576, + 600, + 625, + 640, + 648, + 675, + 720, + 729, + 750, + 768, + 800, + 810, + 864, + 900, + 960, + 972, + 1000, + 1024, + 1080, + 1125, + 1152, + 1200, + 1215, + 1250, + 1280, + 1296, + 1350, + 1440, + 1458, + 1500, + 1536, + 1600, + 1620, + 1728, + 1800, + 1875, + 1920, + 1944, + 2000, + 2025, + 2048, + 2160, + 2187, + 2250, + 2304, + 2400, + 2430, + 2500, + 2560, + 2592, + 2700, + 2880, + 2916, + 3000, + 3072, + 3125, + 3200, + 3240, + 3375, + 3456, + 3600, + 3645, + 3750, + 3840, + 3888, + 4000, + 4050, + 4096, + 4320, + 4374, + 4500, + 4608, + 4800, + 4860, + 5000, + 5120, + 5184, + 5400, + 5625, + 5760, + 5832, + 6000, + 6075, + 6144, + 6250, + 6400, + 6480, + 6561, + 6750, + 6912, + 7200, + 7290, + 7500, + 7680, + 7776, + 8000, + 8100, + 8192, + 8640, + 8748, + 9000, + 9216, + 9375, + 9600, + 9720, + 10000, + ) if target <= 6: return target @@ -92,7 +264,7 @@ def next_fast_len(target): if target <= hams[-1]: return hams[bisect_left(hams, target)] - match = float('inf') # Anything found will be smaller + match = float("inf") # Anything found will be smaller p5 = 1 while p5 < target: p35 = p5 @@ -121,8 +293,16 @@ def next_fast_len(target): return match -def _overlap_add_filter(x, h, n_fft=None, phase='zero', picks=None, - n_jobs=None, copy=True, pad='reflect_limited'): +def _overlap_add_filter( + x, + h, + n_fft=None, + phase="zero", + picks=None, + n_jobs=None, + copy=True, + pad="reflect_limited", +): """Filter the signal x using h with overlap-add FFTs. 
Parameters @@ -162,12 +342,12 @@ def _overlap_add_filter(x, h, n_fft=None, phase='zero', picks=None, # response _check_zero_phase_length(len(h), phase) if len(h) == 1: - return x * h ** 2 if phase == 'zero-double' else x * h + return x * h**2 if phase == "zero-double" else x * h n_edge = max(min(len(h), x.shape[1]) - 1, 0) - logger.debug('Smart-padding with: %s samples on each edge' % n_edge) + logger.debug("Smart-padding with: %s samples on each edge" % n_edge) n_x = x.shape[1] + 2 * n_edge - if phase == 'zero-double': + if phase == "zero-double": h = np.convolve(h, h[::-1]) # Determine FFT length to use @@ -176,10 +356,14 @@ def _overlap_add_filter(x, h, n_fft=None, phase='zero', picks=None, max_fft = n_x if max_fft >= min_fft: # cost function based on number of multiplications - N = 2 ** np.arange(np.ceil(np.log2(min_fft)), - np.ceil(np.log2(max_fft)) + 1, dtype=int) - cost = (np.ceil(n_x / (N - len(h) + 1).astype(np.float64)) * - N * (np.log2(N) + 1)) + N = 2 ** np.arange( + np.ceil(np.log2(min_fft)), np.ceil(np.log2(max_fft)) + 1, dtype=int + ) + cost = ( + np.ceil(n_x / (N - len(h) + 1).astype(np.float64)) + * N + * (np.log2(N) + 1) + ) # add a heuristic term to prevent too-long FFT's which are slow # (not predicted by mult. cost alone, 4e-5 exp. determined) @@ -189,10 +373,12 @@ def _overlap_add_filter(x, h, n_fft=None, phase='zero', picks=None, else: # Use only a single block n_fft = next_fast_len(min_fft) - logger.debug('FFT block length: %s' % n_fft) + logger.debug("FFT block length: %s" % n_fft) if n_fft < min_fft: - raise ValueError('n_fft is too short, has to be at least ' - '2 * len(h) - 1 (%s), got %s' % (min_fft, n_fft)) + raise ValueError( + "n_fft is too short, has to be at least " + "2 * len(h) - 1 (%s), got %s" % (min_fft, n_fft) + ) # Figure out if we should use CUDA n_jobs, cuda_dict = _setup_cuda_fft_multiply_repeated(n_jobs, h, n_fft) @@ -202,11 +388,13 @@ def _overlap_add_filter(x, h, n_fft=None, phase='zero', picks=None, parallel, p_fun, _ = parallel_func(_1d_overlap_filter, n_jobs) if n_jobs == 1: for p in picks: - x[p] = _1d_overlap_filter(x[p], len(h), n_edge, phase, - cuda_dict, pad, n_fft) + x[p] = _1d_overlap_filter( + x[p], len(h), n_edge, phase, cuda_dict, pad, n_fft + ) else: - data_new = parallel(p_fun(x[p], len(h), n_edge, phase, - cuda_dict, pad, n_fft) for p in picks) + data_new = parallel( + p_fun(x[p], len(h), n_edge, phase, cuda_dict, pad, n_fft) for p in picks + ) for pp, p in enumerate(picks): x[p] = data_new[pp] @@ -223,7 +411,7 @@ def _1d_overlap_filter(x, n_h, n_edge, phase, cuda_dict, pad, n_fft): n_seg = n_fft - n_h + 1 n_segments = int(np.ceil(n_x / float(n_seg))) - shift = ((n_h - 1) // 2 if phase.startswith('zero') else 0) + n_edge + shift = ((n_h - 1) // 2 if phase.startswith("zero") else 0) + n_edge # Now the actual filtering step is identical for zero-phase (filtfilt-like) # or single-pass @@ -242,13 +430,14 @@ def _1d_overlap_filter(x, n_h, n_edge, phase, cuda_dict, pad, n_fft): x_filtered[start_filt:stop_filt] += prod[start_prod:stop_prod] # Remove mirrored edges that we added and cast (n_edge can be zero) - x_filtered = x_filtered[:n_x - 2 * n_edge].astype(x.dtype) + x_filtered = x_filtered[: n_x - 2 * n_edge].astype(x.dtype) return x_filtered def _filter_attenuation(h, freq, gain): """Compute minimum attenuation at stop frequency.""" from scipy.signal import freqz + _, filt_resp = freqz(h.ravel(), worN=np.pi * freq) filt_resp = np.abs(filt_resp) # use amplitude response filt_resp[np.where(gain == 1)] = 0 @@ -269,12 +458,13 @@ def 
_prep_for_filtering(x, copy, picks=None): x.shape = (np.prod(x.shape[:-1]), x.shape[-1]) if len(orig_shape) == 3: n_epochs, n_channels, n_times = orig_shape - offset = np.repeat(np.arange(0, n_channels * n_epochs, n_channels), - len(picks)) + offset = np.repeat(np.arange(0, n_channels * n_epochs, n_channels), len(picks)) picks = np.tile(picks, n_epochs) + offset elif len(orig_shape) > 3: - raise ValueError('picks argument is not supported for data with more' - ' than three dimensions') + raise ValueError( + "picks argument is not supported for data with more" + " than three dimensions" + ) assert all(0 <= pick < x.shape[0] for pick in picks) # guaranteed by above return x, orig_shape, picks @@ -283,6 +473,7 @@ def _prep_for_filtering(x, copy, picks=None): def _firwin_design(N, freq, gain, window, sfreq): """Construct a FIR filter using firwin.""" from scipy.signal import firwin + assert freq[0] == 0 assert len(freq) > 1 assert len(freq) == len(gain) @@ -297,30 +488,37 @@ def _firwin_design(N, freq, gain, window, sfreq): assert this_gain in (0, 1) if this_gain != prev_gain: # Get the correct N to satistify the requested transition bandwidth - transition = (prev_freq - this_freq) / 2. + transition = (prev_freq - this_freq) / 2.0 this_N = int(round(_length_factors[window] / transition)) - this_N += (1 - this_N % 2) # make it odd + this_N += 1 - this_N % 2 # make it odd if this_N > N: - raise ValueError('The requested filter length %s is too short ' - 'for the requested %0.2f Hz transition band, ' - 'which requires %s samples' - % (N, transition * sfreq / 2., this_N)) + raise ValueError( + "The requested filter length %s is too short " + "for the requested %0.2f Hz transition band, " + "which requires %s samples" % (N, transition * sfreq / 2.0, this_N) + ) # Construct a lowpass - this_h = firwin(this_N, (prev_freq + this_freq) / 2., - window=window, pass_zero=True, fs=freq[-1] * 2) + this_h = firwin( + this_N, + (prev_freq + this_freq) / 2.0, + window=window, + pass_zero=True, + fs=freq[-1] * 2, + ) assert this_h.shape == (this_N,) offset = (N - this_N) // 2 if this_gain == 0: - h[offset:N - offset] -= this_h + h[offset : N - offset] -= this_h else: - h[offset:N - offset] += this_h + h[offset : N - offset] += this_h prev_gain = this_gain prev_freq = this_freq return h -def _construct_fir_filter(sfreq, freq, gain, filter_length, phase, fir_window, - fir_design): +def _construct_fir_filter( + sfreq, freq, gain, filter_length, phase, fir_window, fir_design +): """Filter signal using gain control points in the frequency domain. The filter impulse response is constructed from a Hann window (window @@ -358,50 +556,53 @@ def _construct_fir_filter(sfreq, freq, gain, filter_length, phase, fir_window, Filter coefficients. """ assert freq[0] == 0 - if fir_design == 'firwin2': + if fir_design == "firwin2": from scipy.signal import firwin2 as fir_design else: - assert fir_design == 'firwin' + assert fir_design == "firwin" fir_design = partial(_firwin_design, sfreq=sfreq) from scipy.signal import minimum_phase # issue a warning if attenuation is less than this - min_att_db = 12 if phase == 'minimum' else 20 + min_att_db = 12 if phase == "minimum" else 20 # normalize frequencies - freq = np.array(freq) / (sfreq / 2.) 
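The length check in _firwin_design encodes the usual window design rule: a Hamming window needs roughly 3.3 / transition taps, with the transition band expressed in normalized frequency. An illustration with hypothetical numbers:

    from scipy.signal import firwin

    sfreq, transition = 1000.0, 10.0         # Hz; illustrative values
    trans_norm = transition / (sfreq / 2.0)  # normalized by Nyquist
    n_taps = int(round(3.3 / trans_norm))    # Hamming length factor ~3.3
    n_taps += 1 - n_taps % 2                 # odd length for phase="zero"
    h = firwin(n_taps, 40.0, window="hamming", pass_zero=True, fs=sfreq)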
+ freq = np.array(freq) / (sfreq / 2.0) if freq[0] != 0 or freq[-1] != 1: - raise ValueError('freq must start at 0 and end an Nyquist (%s), got %s' - % (sfreq / 2., freq)) + raise ValueError( + "freq must start at 0 and end an Nyquist (%s), got %s" % (sfreq / 2.0, freq) + ) gain = np.array(gain) # Use overlap-add filter with a fixed length N = _check_zero_phase_length(filter_length, phase, gain[-1]) # construct symmetric (linear phase) filter - if phase == 'minimum': + if phase == "minimum": h = fir_design(N * 2 - 1, freq, gain, window=fir_window) h = minimum_phase(h) else: h = fir_design(N, freq, gain, window=fir_window) assert h.size == N att_db, att_freq = _filter_attenuation(h, freq, gain) - if phase == 'zero-double': + if phase == "zero-double": att_db += 6 if att_db < min_att_db: - att_freq *= sfreq / 2. - warn('Attenuation at stop frequency %0.2f Hz is only %0.2f dB. ' - 'Increase filter_length for higher attenuation.' - % (att_freq, att_db)) + att_freq *= sfreq / 2.0 + warn( + "Attenuation at stop frequency %0.2f Hz is only %0.2f dB. " + "Increase filter_length for higher attenuation." % (att_freq, att_db) + ) return h def _check_zero_phase_length(N, phase, gain_nyq=0): N = int(N) if N % 2 == 0: - if phase == 'zero': - raise RuntimeError('filter_length must be odd if phase="zero", ' - 'got %s' % N) - elif phase == 'zero-double' and gain_nyq == 1: + if phase == "zero": + raise RuntimeError( + 'filter_length must be odd if phase="zero", ' "got %s" % N + ) + elif phase == "zero-double" and gain_nyq == 1: N += 1 return N @@ -410,39 +611,43 @@ def _check_coefficients(system): """Check for filter stability.""" if isinstance(system, tuple): from scipy.signal import tf2zpk + z, p, k = tf2zpk(*system) else: # sos from scipy.signal import sos2zpk + z, p, k = sos2zpk(system) if np.any(np.abs(p) > 1.0): - raise RuntimeError('Filter poles outside unit circle, filter will be ' - 'unstable. Consider using different filter ' - 'coefficients.') + raise RuntimeError( + "Filter poles outside unit circle, filter will be " + "unstable. Consider using different filter " + "coefficients." 
+ ) -def _iir_filter(x, iir_params, picks, n_jobs, copy, phase='zero'): +def _iir_filter(x, iir_params, picks, n_jobs, copy, phase="zero"): """Call filtfilt or lfilter.""" # set up array for filtering, reshape to 2D, operate on last axis from scipy.signal import filtfilt, sosfiltfilt, lfilter, sosfilt + x, orig_shape, picks = _prep_for_filtering(x, copy, picks) - if phase in ('zero', 'zero-double'): - padlen = min(iir_params['padlen'], x.shape[-1] - 1) - if 'sos' in iir_params: - fun = partial(sosfiltfilt, sos=iir_params['sos'], padlen=padlen, - axis=-1) - _check_coefficients(iir_params['sos']) + if phase in ("zero", "zero-double"): + padlen = min(iir_params["padlen"], x.shape[-1] - 1) + if "sos" in iir_params: + fun = partial(sosfiltfilt, sos=iir_params["sos"], padlen=padlen, axis=-1) + _check_coefficients(iir_params["sos"]) else: - fun = partial(filtfilt, b=iir_params['b'], a=iir_params['a'], - padlen=padlen, axis=-1) - _check_coefficients((iir_params['b'], iir_params['a'])) + fun = partial( + filtfilt, b=iir_params["b"], a=iir_params["a"], padlen=padlen, axis=-1 + ) + _check_coefficients((iir_params["b"], iir_params["a"])) else: - if 'sos' in iir_params: - fun = partial(sosfilt, sos=iir_params['sos'], axis=-1) - _check_coefficients(iir_params['sos']) + if "sos" in iir_params: + fun = partial(sosfilt, sos=iir_params["sos"], axis=-1) + _check_coefficients(iir_params["sos"]) else: - fun = partial(lfilter, b=iir_params['b'], a=iir_params['a'], - axis=-1) - _check_coefficients((iir_params['b'], iir_params['a'])) + fun = partial(lfilter, b=iir_params["b"], a=iir_params["a"], axis=-1) + _check_coefficients((iir_params["b"], iir_params["a"])) parallel, p_fun, n_jobs = parallel_func(fun, n_jobs) if n_jobs == 1: for p in picks: @@ -472,14 +677,15 @@ def estimate_ringing_samples(system, max_try=100000): The approximate ringing. """ from scipy import signal + if isinstance(system, tuple): # TF - kind = 'ba' + kind = "ba" b, a = system - zi = [0.] * (len(a) - 1) + zi = [0.0] * (len(a) - 1) else: - kind = 'sos' + kind = "sos" sos = system - zi = [[0.] * 2] * len(sos) + zi = [[0.0] * 2] * len(sos) n_per_chunk = 1000 n_chunks_max = int(np.ceil(max_try / float(n_per_chunk))) x = np.zeros(n_per_chunk) @@ -487,7 +693,7 @@ def estimate_ringing_samples(system, max_try=100000): last_good = n_per_chunk thresh_val = 0 for ii in range(n_chunks_max): - if kind == 'ba': + if kind == "ba": h, zi = signal.lfilter(b, a, x, zi=zi) else: h, zi = signal.sosfilt(sos, x, zi=zi) @@ -501,24 +707,32 @@ def estimate_ringing_samples(system, max_try=100000): idx = (ii - 1) * n_per_chunk + last_good break else: - warn('Could not properly estimate ringing for the filter') + warn("Could not properly estimate ringing for the filter") idx = n_per_chunk * n_chunks_max return idx _ftype_dict = { - 'butter': 'Butterworth', - 'cheby1': 'Chebyshev I', - 'cheby2': 'Chebyshev II', - 'ellip': 'Cauer/elliptic', - 'bessel': 'Bessel/Thomson', + "butter": "Butterworth", + "cheby1": "Chebyshev I", + "cheby2": "Chebyshev II", + "ellip": "Cauer/elliptic", + "bessel": "Bessel/Thomson", } @verbose -def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None, - btype=None, return_copy=True, *, phase='zero', - verbose=None): +def construct_iir_filter( + iir_params, + f_pass=None, + f_stop=None, + sfreq=None, + btype=None, + return_copy=True, + *, + phase="zero", + verbose=None, +): """Use IIR parameters to get filtering coefficients. 
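In practice this function is driven entirely by the iir_params dict; a minimal sketch of the order-based design path (all values illustrative):

    from mne.filter import construct_iir_filter

    # 4th-order Butterworth low-pass at 40 Hz for 1 kHz data, SOS output
    iir_params = dict(order=4, ftype="butter", output="sos")
    iir_params = construct_iir_filter(
        iir_params, f_pass=40.0, sfreq=1000.0, btype="low"
    )
    # the returned dict now also carries "sos" coefficients and a "padlen"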
This function works like a wrapper for iirdesign and iirfilter in @@ -636,136 +850,190 @@ def construct_iir_filter(iir_params, f_pass=None, f_stop=None, sfreq=None, :ref:`disc-filtering` and :ref:`tut-filter-resample`. """ # noqa: E501 from scipy.signal import iirfilter, iirdesign, freqz, sosfreqz - known_filters = ('bessel', 'butter', 'butterworth', 'cauer', 'cheby1', - 'cheby2', 'chebyshev1', 'chebyshev2', 'chebyshevi', - 'chebyshevii', 'ellip', 'elliptic') + + known_filters = ( + "bessel", + "butter", + "butterworth", + "cauer", + "cheby1", + "cheby2", + "chebyshev1", + "chebyshev2", + "chebyshevi", + "chebyshevii", + "ellip", + "elliptic", + ) if not isinstance(iir_params, dict): - raise TypeError('iir_params must be a dict, got %s' % type(iir_params)) + raise TypeError("iir_params must be a dict, got %s" % type(iir_params)) # if the filter has been designed, we're good to go Wp = None - if 'sos' in iir_params: - system = iir_params['sos'] - output = 'sos' - elif 'a' in iir_params and 'b' in iir_params: - system = (iir_params['b'], iir_params['a']) - output = 'ba' + if "sos" in iir_params: + system = iir_params["sos"] + output = "sos" + elif "a" in iir_params and "b" in iir_params: + system = (iir_params["b"], iir_params["a"]) + output = "ba" else: - output = iir_params.get('output', 'sos') - _check_option('output', output, ('ba', 'sos')) + output = iir_params.get("output", "sos") + _check_option("output", output, ("ba", "sos")) # ensure we have a valid ftype - if 'ftype' not in iir_params: - raise RuntimeError('ftype must be an entry in iir_params if ''b'' ' - 'and ''a'' are not specified') - ftype = iir_params['ftype'] + if "ftype" not in iir_params: + raise RuntimeError( + "ftype must be an entry in iir_params if " + "b" + " " + "and " + "a" + " are not specified" + ) + ftype = iir_params["ftype"] if ftype not in known_filters: - raise RuntimeError('ftype must be in filter_dict from ' - 'scipy.signal (e.g., butter, cheby1, etc.) not ' - '%s' % ftype) + raise RuntimeError( + "ftype must be in filter_dict from " + "scipy.signal (e.g., butter, cheby1, etc.) 
not " + "%s" % ftype + ) # use order-based design f_pass = np.atleast_1d(f_pass) if f_pass.ndim > 1: - raise ValueError('frequencies must be 1D, got %dD' % f_pass.ndim) - edge_freqs = ', '.join('%0.2f' % (f,) for f in f_pass) + raise ValueError("frequencies must be 1D, got %dD" % f_pass.ndim) + edge_freqs = ", ".join("%0.2f" % (f,) for f in f_pass) Wp = f_pass / (float(sfreq) / 2) # IT will de designed ftype_nice = _ftype_dict.get(ftype, ftype) - _validate_type(phase, str, 'phase') - _check_option('phase', phase, ('zero', 'zero-double', 'forward')) - if phase in ('zero-double', 'zero'): - ptype = 'zero-phase (two-pass forward and reverse) non-causal' + _validate_type(phase, str, "phase") + _check_option("phase", phase, ("zero", "zero-double", "forward")) + if phase in ("zero-double", "zero"): + ptype = "zero-phase (two-pass forward and reverse) non-causal" else: - ptype = 'non-linear phase (one-pass forward) causal' - logger.info('') - logger.info('IIR filter parameters') - logger.info('---------------------') - logger.info(f'{ftype_nice} {btype} {ptype} filter:') + ptype = "non-linear phase (one-pass forward) causal" + logger.info("") + logger.info("IIR filter parameters") + logger.info("---------------------") + logger.info(f"{ftype_nice} {btype} {ptype} filter:") # SciPy designs forward for -3dB, so forward-backward is -6dB - if 'order' in iir_params: - singleton = btype in ('low', 'lowpass', 'high', 'highpass') + if "order" in iir_params: + singleton = btype in ("low", "lowpass", "high", "highpass") use_Wp = Wp.item() if singleton else Wp - kwargs = dict(N=iir_params['order'], Wn=use_Wp, btype=btype, - ftype=ftype, output=output) - for key in ('rp', 'rs'): + kwargs = dict( + N=iir_params["order"], + Wn=use_Wp, + btype=btype, + ftype=ftype, + output=output, + ) + for key in ("rp", "rs"): if key in iir_params: kwargs[key] = iir_params[key] system = iirfilter(**kwargs) - if phase in ('zero', 'zero-double'): - ptype, pmul = '(effective, after forward-backward)', 2 + if phase in ("zero", "zero-double"): + ptype, pmul = "(effective, after forward-backward)", 2 else: - ptype, pmul = '(forward)', 1 - logger.info('- Filter order %d %s' - % (pmul * iir_params['order'] * len(Wp), ptype)) + ptype, pmul = "(forward)", 1 + logger.info( + "- Filter order %d %s" % (pmul * iir_params["order"] * len(Wp), ptype) + ) else: # use gpass / gstop design Ws = np.asanyarray(f_stop) / (float(sfreq) / 2) - if 'gpass' not in iir_params or 'gstop' not in iir_params: - raise ValueError('iir_params must have at least ''gstop'' and' - ' ''gpass'' (or ''N'') entries') - system = iirdesign(Wp, Ws, iir_params['gpass'], - iir_params['gstop'], ftype=ftype, output=output) + if "gpass" not in iir_params or "gstop" not in iir_params: + raise ValueError( + "iir_params must have at least " + "gstop" + " and" + " " + "gpass" + " (or " + "N" + ") entries" + ) + system = iirdesign( + Wp, + Ws, + iir_params["gpass"], + iir_params["gstop"], + ftype=ftype, + output=output, + ) if system is None: - raise RuntimeError('coefficients could not be created from iir_params') + raise RuntimeError("coefficients could not be created from iir_params") # do some sanity checks _check_coefficients(system) # get the gains at the cutoff frequencies if Wp is not None: - if output == 'sos': + if output == "sos": cutoffs = sosfreqz(system, worN=Wp * np.pi)[1] else: cutoffs = freqz(system[0], system[1], worN=Wp * np.pi)[1] cutoffs = 20 * np.log10(np.abs(cutoffs)) # 2 * 20 here because we do forward-backward filtering - if phase in ('zero', 
'zero-double'): + if phase in ("zero", "zero-double"): cutoffs *= 2 - cutoffs = ', '.join(['%0.2f' % (c,) for c in cutoffs]) - logger.info('- Cutoff%s at %s Hz: %s dB' - % (_pl(f_pass), edge_freqs, cutoffs)) + cutoffs = ", ".join(["%0.2f" % (c,) for c in cutoffs]) + logger.info("- Cutoff%s at %s Hz: %s dB" % (_pl(f_pass), edge_freqs, cutoffs)) # now deal with padding - if 'padlen' not in iir_params: + if "padlen" not in iir_params: padlen = estimate_ringing_samples(system) else: - padlen = iir_params['padlen'] + padlen = iir_params["padlen"] if return_copy: iir_params = deepcopy(iir_params) iir_params.update(dict(padlen=padlen)) - if output == 'sos': + if output == "sos": iir_params.update(sos=system) else: iir_params.update(b=system[0], a=system[1]) - logger.info('') + logger.info("") return iir_params def _check_method(method, iir_params, extra_types=()): """Parse method arguments.""" - allowed_types = ['iir', 'fir', 'fft'] + list(extra_types) - _validate_type(method, 'str', 'method') - _check_option('method', method, allowed_types) - if method == 'fft': - method = 'fir' # use the better name - if method == 'iir': + allowed_types = ["iir", "fir", "fft"] + list(extra_types) + _validate_type(method, "str", "method") + _check_option("method", method, allowed_types) + if method == "fft": + method = "fir" # use the better name + if method == "iir": if iir_params is None: iir_params = dict() - if len(iir_params) == 0 or (len(iir_params) == 1 and - 'output' in iir_params): - iir_params = dict(order=4, ftype='butter', - output=iir_params.get('output', 'sos')) + if len(iir_params) == 0 or (len(iir_params) == 1 and "output" in iir_params): + iir_params = dict( + order=4, ftype="butter", output=iir_params.get("output", "sos") + ) elif iir_params is not None: raise ValueError('iir_params must be None if method != "iir"') return iir_params, method @verbose -def filter_data(data, sfreq, l_freq, h_freq, picks=None, filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto', - n_jobs=None, method='fir', iir_params=None, copy=True, - phase='zero', fir_window='hamming', fir_design='firwin', - pad='reflect_limited', *, verbose=None): +def filter_data( + data, + sfreq, + l_freq, + h_freq, + picks=None, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + n_jobs=None, + method="fir", + iir_params=None, + copy=True, + phase="zero", + fir_window="hamming", + fir_design="firwin", + pad="reflect_limited", + *, + verbose=None, +): """Filter a subset of channels. 
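Typical usage of this pair: filter_data designs and applies in one call, while create_filter (below) only designs, and accepts data=None to skip the signal-length sanity checks. A hedged sketch with illustrative parameters:

    import numpy as np
    from mne.filter import create_filter, filter_data

    sfreq = 1000.0
    data = np.random.RandomState(0).randn(2, 5000)  # must be float64

    # design and apply a 1-40 Hz band-pass with default FIR settings
    filtered = filter_data(data, sfreq, l_freq=1.0, h_freq=40.0)

    # design only; data=None skips the length checks
    h = create_filter(None, sfreq, l_freq=1.0, h_freq=40.0)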
Parameters @@ -834,21 +1102,42 @@ def filter_data(data, sfreq, l_freq, h_freq, picks=None, filter_length='auto', data = _check_filterable(data) iir_params, method = _check_method(method, iir_params) filt = create_filter( - data, sfreq, l_freq, h_freq, filter_length, l_trans_bandwidth, - h_trans_bandwidth, method, iir_params, phase, fir_window, fir_design) - if method in ('fir', 'fft'): - data = _overlap_add_filter(data, filt, None, phase, picks, n_jobs, - copy, pad) + data, + sfreq, + l_freq, + h_freq, + filter_length, + l_trans_bandwidth, + h_trans_bandwidth, + method, + iir_params, + phase, + fir_window, + fir_design, + ) + if method in ("fir", "fft"): + data = _overlap_add_filter(data, filt, None, phase, picks, n_jobs, copy, pad) else: data = _iir_filter(data, filt, picks, n_jobs, copy, phase) return data @verbose -def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto', - method='fir', iir_params=None, phase='zero', - fir_window='hamming', fir_design='firwin', verbose=None): +def create_filter( + data, + sfreq, + l_freq, + h_freq, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + method="fir", + iir_params=None, + phase="zero", + fir_window="hamming", + fir_design="firwin", + verbose=None, +): r"""Create a FIR or IIR filter. ``l_freq`` and ``h_freq`` are the frequencies below which and above @@ -967,61 +1256,127 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto', """ sfreq = float(sfreq) if sfreq < 0: - raise ValueError('sfreq must be positive') + raise ValueError("sfreq must be positive") # If no data specified, sanity checking will be skipped if data is None: - logger.info('No data specified. Sanity checks related to the length of' - ' the signal relative to the filter order will be' - ' skipped.') + logger.info( + "No data specified. Sanity checks related to the length of" + " the signal relative to the filter order will be" + " skipped." + ) if h_freq is not None: h_freq = np.array(h_freq, float).ravel() - if (h_freq > (sfreq / 2.)).any(): - raise ValueError('h_freq (%s) must be less than the Nyquist ' - 'frequency %s' % (h_freq, sfreq / 2.)) + if (h_freq > (sfreq / 2.0)).any(): + raise ValueError( + "h_freq (%s) must be less than the Nyquist " + "frequency %s" % (h_freq, sfreq / 2.0) + ) if l_freq is not None: l_freq = np.array(l_freq, float).ravel() if (l_freq == 0).all(): l_freq = None iir_params, method = _check_method(method, iir_params) if l_freq is None and h_freq is None: - data, sfreq, _, _, _, _, filter_length, phase, fir_window, \ - fir_design = _triage_filter_params( - data, sfreq, None, None, None, None, - filter_length, method, phase, fir_window, fir_design) - if method == 'iir': + ( + data, + sfreq, + _, + _, + _, + _, + filter_length, + phase, + fir_window, + fir_design, + ) = _triage_filter_params( + data, + sfreq, + None, + None, + None, + None, + filter_length, + method, + phase, + fir_window, + fir_design, + ) + if method == "iir": out = dict() if iir_params is None else deepcopy(iir_params) - out.update(b=np.array([1.]), a=np.array([1.])) + out.update(b=np.array([1.0]), a=np.array([1.0])) else: - freq = [0, sfreq / 2.] - gain = [1., 1.] 
+ freq = [0, sfreq / 2.0] + gain = [1.0, 1.0] if l_freq is None and h_freq is not None: h_freq = h_freq.item() - logger.info('Setting up low-pass filter at %0.2g Hz' % (h_freq,)) - data, sfreq, _, f_p, _, f_s, filter_length, phase, fir_window, \ - fir_design = _triage_filter_params( - data, sfreq, None, h_freq, None, h_trans_bandwidth, - filter_length, method, phase, fir_window, fir_design) - if method == 'iir': - out = construct_iir_filter(iir_params, f_p, f_s, sfreq, 'lowpass', - phase=phase) + logger.info("Setting up low-pass filter at %0.2g Hz" % (h_freq,)) + ( + data, + sfreq, + _, + f_p, + _, + f_s, + filter_length, + phase, + fir_window, + fir_design, + ) = _triage_filter_params( + data, + sfreq, + None, + h_freq, + None, + h_trans_bandwidth, + filter_length, + method, + phase, + fir_window, + fir_design, + ) + if method == "iir": + out = construct_iir_filter( + iir_params, f_p, f_s, sfreq, "lowpass", phase=phase + ) else: # 'fir' freq = [0, f_p, f_s] gain = [1, 1, 0] - if f_s != sfreq / 2.: - freq += [sfreq / 2.] + if f_s != sfreq / 2.0: + freq += [sfreq / 2.0] gain += [0] elif l_freq is not None and h_freq is None: l_freq = l_freq.item() - logger.info('Setting up high-pass filter at %0.2g Hz' % (l_freq,)) - data, sfreq, pass_, _, stop, _, filter_length, phase, fir_window, \ - fir_design = _triage_filter_params( - data, sfreq, l_freq, None, l_trans_bandwidth, None, - filter_length, method, phase, fir_window, fir_design) - if method == 'iir': - out = construct_iir_filter(iir_params, pass_, stop, sfreq, - 'highpass', phase=phase) + logger.info("Setting up high-pass filter at %0.2g Hz" % (l_freq,)) + ( + data, + sfreq, + pass_, + _, + stop, + _, + filter_length, + phase, + fir_window, + fir_design, + ) = _triage_filter_params( + data, + sfreq, + l_freq, + None, + l_trans_bandwidth, + None, + filter_length, + method, + phase, + fir_window, + fir_design, + ) + if method == "iir": + out = construct_iir_filter( + iir_params, pass_, stop, sfreq, "highpass", phase=phase + ) else: # 'fir' - freq = [stop, pass_, sfreq / 2.] + freq = [stop, pass_, sfreq / 2.0] gain = [0, 1, 1] if stop != 0: freq = [0] + freq @@ -1029,22 +1384,47 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto', elif l_freq is not None and h_freq is not None: if (l_freq < h_freq).any(): l_freq, h_freq = l_freq.item(), h_freq.item() - logger.info('Setting up band-pass filter from %0.2g - %0.2g Hz' - % (l_freq, h_freq)) - data, sfreq, f_p1, f_p2, f_s1, f_s2, filter_length, phase, \ - fir_window, fir_design = _triage_filter_params( - data, sfreq, l_freq, h_freq, l_trans_bandwidth, - h_trans_bandwidth, filter_length, method, phase, - fir_window, fir_design) - if method == 'iir': - out = construct_iir_filter(iir_params, [f_p1, f_p2], - [f_s1, f_s2], sfreq, 'bandpass', - phase=phase) + logger.info( + "Setting up band-pass filter from %0.2g - %0.2g Hz" % (l_freq, h_freq) + ) + ( + data, + sfreq, + f_p1, + f_p2, + f_s1, + f_s2, + filter_length, + phase, + fir_window, + fir_design, + ) = _triage_filter_params( + data, + sfreq, + l_freq, + h_freq, + l_trans_bandwidth, + h_trans_bandwidth, + filter_length, + method, + phase, + fir_window, + fir_design, + ) + if method == "iir": + out = construct_iir_filter( + iir_params, + [f_p1, f_p2], + [f_s1, f_s2], + sfreq, + "bandpass", + phase=phase, + ) else: # 'fir' freq = [f_s1, f_p1, f_p2, f_s2] gain = [0, 1, 1, 0] - if f_s2 != sfreq / 2.: - freq += [sfreq / 2.] 
+ if f_s2 != sfreq / 2.0: + freq += [sfreq / 2.0] gain += [0] if f_s1 != 0: freq = [0] + freq @@ -1053,54 +1433,100 @@ def create_filter(data, sfreq, l_freq, h_freq, filter_length='auto', # This could possibly be removed after 0.14 release, but might # as well leave it in to sanity check notch_filter if len(l_freq) != len(h_freq): - raise ValueError('l_freq and h_freq must be the same length') - msg = 'Setting up band-stop filter' + raise ValueError("l_freq and h_freq must be the same length") + msg = "Setting up band-stop filter" if len(l_freq) == 1: l_freq, h_freq = l_freq.item(), h_freq.item() - msg += ' from %0.2g - %0.2g Hz' % (h_freq, l_freq) + msg += " from %0.2g - %0.2g Hz" % (h_freq, l_freq) logger.info(msg) # Note: order of outputs is intentionally switched here! - data, sfreq, f_s1, f_s2, f_p1, f_p2, filter_length, phase, \ - fir_window, fir_design = _triage_filter_params( - data, sfreq, h_freq, l_freq, h_trans_bandwidth, - l_trans_bandwidth, filter_length, method, phase, - fir_window, fir_design, bands='arr', reverse=True) - if method == 'iir': + ( + data, + sfreq, + f_s1, + f_s2, + f_p1, + f_p2, + filter_length, + phase, + fir_window, + fir_design, + ) = _triage_filter_params( + data, + sfreq, + h_freq, + l_freq, + h_trans_bandwidth, + l_trans_bandwidth, + filter_length, + method, + phase, + fir_window, + fir_design, + bands="arr", + reverse=True, + ) + if method == "iir": if len(f_p1) != 1: - raise ValueError('Multiple stop-bands can only be used ' - 'with FIR filtering') - out = construct_iir_filter(iir_params, [f_p1[0], f_p2[0]], - [f_s1[0], f_s2[0]], sfreq, - 'bandstop', phase=phase) + raise ValueError( + "Multiple stop-bands can only be used " "with FIR filtering" + ) + out = construct_iir_filter( + iir_params, + [f_p1[0], f_p2[0]], + [f_s1[0], f_s2[0]], + sfreq, + "bandstop", + phase=phase, + ) else: # 'fir' freq = np.r_[f_p1, f_s1, f_s2, f_p2] - gain = np.r_[np.ones_like(f_p1), np.zeros_like(f_s1), - np.zeros_like(f_s2), np.ones_like(f_p2)] + gain = np.r_[ + np.ones_like(f_p1), + np.zeros_like(f_s1), + np.zeros_like(f_s2), + np.ones_like(f_p2), + ] order = np.argsort(freq) freq = freq[order] gain = gain[order] if freq[0] != 0: - freq = np.r_[[0.], freq] - gain = np.r_[[1.], gain] - if freq[-1] != sfreq / 2.: - freq = np.r_[freq, [sfreq / 2.]] - gain = np.r_[gain, [1.]] + freq = np.r_[[0.0], freq] + gain = np.r_[[1.0], gain] + if freq[-1] != sfreq / 2.0: + freq = np.r_[freq, [sfreq / 2.0]] + gain = np.r_[gain, [1.0]] if np.any(np.abs(np.diff(gain, 2)) > 1): - raise ValueError('Stop bands are not sufficiently ' - 'separated.') - if method == 'fir': - out = _construct_fir_filter(sfreq, freq, gain, filter_length, phase, - fir_window, fir_design) + raise ValueError("Stop bands are not sufficiently " "separated.") + if method == "fir": + out = _construct_fir_filter( + sfreq, freq, gain, filter_length, phase, fir_window, fir_design + ) return out @verbose -def notch_filter(x, Fs, freqs, filter_length='auto', notch_widths=None, - trans_bandwidth=1, method='fir', iir_params=None, - mt_bandwidth=None, p_value=0.05, picks=None, n_jobs=None, - copy=True, phase='zero', fir_window='hamming', - fir_design='firwin', pad='reflect_limited', *, - verbose=None): +def notch_filter( + x, + Fs, + freqs, + filter_length="auto", + notch_widths=None, + trans_bandwidth=1, + method="fir", + iir_params=None, + mt_bandwidth=None, + p_value=0.05, + picks=None, + n_jobs=None, + copy=True, + phase="zero", + fir_window="hamming", + fir_design="firwin", + pad="reflect_limited", + *, + verbose=None, +): 
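From the public API the FIR notch path looks like this (sampling rate and line frequencies are illustrative); each frequency gets a narrow band-stop whose width defaults to freq / 200:

    import numpy as np
    from mne.filter import notch_filter

    sfreq = 1000.0
    x = np.random.RandomState(0).randn(2, 10000)  # float64 data

    # FIR notches at 60 Hz and harmonics up to 240 Hz
    x_fir = notch_filter(x, sfreq, freqs=np.arange(60, 241, 60))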
r"""Notch filter for the signal x. Applies a zero-phase notch filter to the signal x, operating on the last @@ -1184,42 +1610,65 @@ def notch_filter(x, Fs, freqs, filter_length='auto', notch_widths=None, & Hemant Bokil, Oxford University Press, New York, 2008. Please cite this in publications if method 'spectrum_fit' is used. """ - x = _check_filterable(x, 'notch filtered', 'notch_filter') - iir_params, method = _check_method(method, iir_params, ['spectrum_fit']) + x = _check_filterable(x, "notch filtered", "notch_filter") + iir_params, method = _check_method(method, iir_params, ["spectrum_fit"]) if freqs is not None: freqs = np.atleast_1d(freqs) - elif method != 'spectrum_fit': - raise ValueError('freqs=None can only be used with method ' - 'spectrum_fit') + elif method != "spectrum_fit": + raise ValueError("freqs=None can only be used with method " "spectrum_fit") # Only have to deal with notch_widths for non-autodetect if freqs is not None: if notch_widths is None: notch_widths = freqs / 200.0 elif np.any(notch_widths < 0): - raise ValueError('notch_widths must be >= 0') + raise ValueError("notch_widths must be >= 0") else: notch_widths = np.atleast_1d(notch_widths) if len(notch_widths) == 1: notch_widths = notch_widths[0] * np.ones_like(freqs) elif len(notch_widths) != len(freqs): - raise ValueError('notch_widths must be None, scalar, or the ' - 'same length as freqs') + raise ValueError( + "notch_widths must be None, scalar, or the " "same length as freqs" + ) - if method in ('fir', 'iir'): + if method in ("fir", "iir"): # Speed this up by computing the fourier coefficients once tb_2 = trans_bandwidth / 2.0 - lows = [freq - nw / 2.0 - tb_2 - for freq, nw in zip(freqs, notch_widths)] - highs = [freq + nw / 2.0 + tb_2 - for freq, nw in zip(freqs, notch_widths)] - xf = filter_data(x, Fs, highs, lows, picks, filter_length, tb_2, tb_2, - n_jobs, method, iir_params, copy, phase, fir_window, - fir_design, pad=pad) - elif method == 'spectrum_fit': - xf = _mt_spectrum_proc(x, Fs, freqs, notch_widths, mt_bandwidth, - p_value, picks, n_jobs, copy, filter_length) + lows = [freq - nw / 2.0 - tb_2 for freq, nw in zip(freqs, notch_widths)] + highs = [freq + nw / 2.0 + tb_2 for freq, nw in zip(freqs, notch_widths)] + xf = filter_data( + x, + Fs, + highs, + lows, + picks, + filter_length, + tb_2, + tb_2, + n_jobs, + method, + iir_params, + copy, + phase, + fir_window, + fir_design, + pad=pad, + ) + elif method == "spectrum_fit": + xf = _mt_spectrum_proc( + x, + Fs, + freqs, + notch_widths, + mt_bandwidth, + p_value, + picks, + n_jobs, + copy, + filter_length, + ) return xf @@ -1230,26 +1679,37 @@ def _get_window_thresh(n_times, sfreq, mt_bandwidth, p_value): # figure out what tapers to use window_fun, _, _ = _compute_mt_params( - n_times, sfreq, mt_bandwidth, False, False, verbose=False) + n_times, sfreq, mt_bandwidth, False, False, verbose=False + ) # F-stat of 1-p point threshold = stats.f.ppf(1 - p_value / n_times, 2, 2 * len(window_fun) - 2) return window_fun, threshold -def _mt_spectrum_proc(x, sfreq, line_freqs, notch_widths, mt_bandwidth, - p_value, picks, n_jobs, copy, filter_length): +def _mt_spectrum_proc( + x, + sfreq, + line_freqs, + notch_widths, + mt_bandwidth, + p_value, + picks, + n_jobs, + copy, + filter_length, +): """Call _mt_spectrum_remove.""" # set up array for filtering, reshape to 2D, operate on last axis x, orig_shape, picks = _prep_for_filtering(x, copy, picks) - if isinstance(filter_length, str) and filter_length == 'auto': - filter_length = '10s' + if 
isinstance(filter_length, str) and filter_length == "auto": + filter_length = "10s" if filter_length is None: filter_length = x.shape[-1] - filter_length = min(_to_samples(filter_length, sfreq, '', ''), x.shape[-1]) + filter_length = min(_to_samples(filter_length, sfreq, "", ""), x.shape[-1]) get_wt = partial( - _get_window_thresh, sfreq=sfreq, mt_bandwidth=mt_bandwidth, - p_value=p_value) + _get_window_thresh, sfreq=sfreq, mt_bandwidth=mt_bandwidth, p_value=p_value + ) window_fun, threshold = get_wt(filter_length) parallel, p_fun, n_jobs = parallel_func(_mt_spectrum_remove_win, n_jobs) if n_jobs == 1: @@ -1257,34 +1717,41 @@ def _mt_spectrum_proc(x, sfreq, line_freqs, notch_widths, mt_bandwidth, for ii, x_ in enumerate(x): if ii in picks: x[ii], f = _mt_spectrum_remove_win( - x_, sfreq, line_freqs, notch_widths, window_fun, threshold, - get_wt) + x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_wt + ) freq_list.append(f) else: - data_new = parallel(p_fun(x_, sfreq, line_freqs, notch_widths, - window_fun, threshold, get_wt) - for xi, x_ in enumerate(x) - if xi in picks) + data_new = parallel( + p_fun(x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_wt) + for xi, x_ in enumerate(x) + if xi in picks + ) freq_list = [d[1] for d in data_new] data_new = np.array([d[0] for d in data_new]) x[picks, :] = data_new # report found frequencies, but do some sanitizing first by binning into # 1 Hz bins - counts = Counter(sum((np.unique(np.round(ff)).tolist() - for f in freq_list for ff in f), list())) - kind = 'Detected' if line_freqs is None else 'Removed' - found_freqs = '\n'.join(f' {freq:6.2f} : ' - f'{counts[freq]:4d} window{_pl(counts[freq])}' - for freq in sorted(counts)) or ' None' - logger.info(f'{kind} notch frequencies (Hz):\n{found_freqs}') + counts = Counter( + sum((np.unique(np.round(ff)).tolist() for f in freq_list for ff in f), list()) + ) + kind = "Detected" if line_freqs is None else "Removed" + found_freqs = ( + "\n".join( + f" {freq:6.2f} : " f"{counts[freq]:4d} window{_pl(counts[freq])}" + for freq in sorted(counts) + ) + or " None" + ) + logger.info(f"{kind} notch frequencies (Hz):\n{found_freqs}") x.shape = orig_shape return x -def _mt_spectrum_remove_win(x, sfreq, line_freqs, notch_widths, - window_fun, threshold, get_thresh): +def _mt_spectrum_remove_win( + x, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh +): n_times = x.shape[-1] n_samples = window_fun.shape[1] n_overlap = (n_samples + 1) // 2 @@ -1295,31 +1762,32 @@ def _mt_spectrum_remove_win(x, sfreq, line_freqs, notch_widths, # Define how to process a chunk of data def process(x_): out = _mt_spectrum_remove( - x_, sfreq, line_freqs, notch_widths, window_fun, threshold, - get_thresh) + x_, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh + ) rm_freqs.append(out[1]) return (out[0],) # must return a tuple # Define how to store a chunk of fully processed data (it's trivial) def store(x_): stop = idx[0] + x_.shape[-1] - x_out[..., idx[0]:stop] += x_ + x_out[..., idx[0] : stop] += x_ idx[0] = stop - _COLA(process, store, n_times, n_samples, n_overlap, sfreq, - verbose=False).feed(x) + _COLA(process, store, n_times, n_samples, n_overlap, sfreq, verbose=False).feed(x) assert idx[0] == n_times return x_out, rm_freqs -def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths, - window_fun, threshold, get_thresh): +def _mt_spectrum_remove( + x, sfreq, line_freqs, notch_widths, window_fun, threshold, get_thresh +): """Use MT-spectrum to remove line frequencies. 
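This windowed multitaper routine is what method="spectrum_fit" dispatches to; with freqs=None the sinusoids are detected by the F-test instead of being specified. A sketch (signal and rate illustrative):

    import numpy as np
    from mne.filter import notch_filter

    sfreq = 1000.0
    t = np.arange(20000) / sfreq
    x = np.random.RandomState(0).randn(1, t.size) + np.sin(2 * np.pi * 60 * t)

    # detect and remove sinusoidal components in 10-s windows (the default)
    x_clean = notch_filter(x, sfreq, freqs=None, method="spectrum_fit")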
Based on Chronux. If line_freqs is specified, all freqs within notch_width of each line_freq is set to zero. """ from .time_frequency.multitaper import _mt_spectra + assert x.ndim == 1 if x.shape[-1] != window_fun.shape[-1]: window_fun, threshold = get_thresh(x.shape[-1]) @@ -1342,8 +1810,7 @@ def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths, x_p, freqs = _mt_spectra(x[np.newaxis, :], window_fun, sfreq) # sum of the product of x_p and H0 across tapers (1, n_freqs) - x_p_H0 = np.sum(x_p[:, tapers_odd, :] * - H0[np.newaxis, :, np.newaxis], axis=1) + x_p_H0 = np.sum(x_p[:, tapers_odd, :] * H0[np.newaxis, :, np.newaxis], axis=1) # resulting calculated amplitudes for all freqs A = x_p_H0 / H0_sq @@ -1357,8 +1824,9 @@ def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths, # numerator for F-statistic num = (n_tapers - 1) * (A * A.conj()).real * H0_sq # denominator for F-statistic - den = (np.sum(np.abs(x_p[:, tapers_odd, :] - x_hat) ** 2, 1) + - np.sum(np.abs(x_p[:, tapers_even, :]) ** 2, 1)) + den = np.sum(np.abs(x_p[:, tapers_odd, :] - x_hat) ** 2, 1) + np.sum( + np.abs(x_p[:, tapers_even, :]) ** 2, 1 + ) den[den == 0] = np.inf f_stat = num / den @@ -1367,10 +1835,11 @@ def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths, rm_freqs = freqs[indices] else: # specify frequencies - indices_1 = np.unique([np.argmin(np.abs(freqs - lf)) - for lf in line_freqs]) - indices_2 = [np.logical_and(freqs > lf - nw / 2., freqs < lf + nw / 2.) - for lf, nw in zip(line_freqs, notch_widths)] + indices_1 = np.unique([np.argmin(np.abs(freqs - lf)) for lf in line_freqs]) + indices_2 = [ + np.logical_and(freqs > lf - nw / 2.0, freqs < lf + nw / 2.0) + for lf, nw in zip(line_freqs, notch_widths) + ] indices_2 = np.where(np.any(np.array(indices_2), axis=0))[0] indices = np.unique(np.r_[indices_1, indices_2]) rm_freqs = freqs[indices] @@ -1390,7 +1859,7 @@ def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths, return x - datafit, rm_freqs -def _check_filterable(x, kind='filtered', alternative='filter'): +def _check_filterable(x, kind="filtered", alternative="filter"): # Let's be fairly strict about this -- users can easily coerce to ndarray # at their end, and we already should do it internally any time we are # using these low-level functions. At the same time, let's @@ -1399,6 +1868,7 @@ def _check_filterable(x, kind='filtered', alternative='filter'): from .io.base import BaseRaw from .epochs import BaseEpochs from .evoked import Evoked + if isinstance(x, (BaseRaw, BaseEpochs, Evoked)): try: name = x.__class__.__name__ @@ -1406,15 +1876,21 @@ def _check_filterable(x, kind='filtered', alternative='filter'): pass else: raise TypeError( - 'This low-level function only operates on np.ndarray ' - f'instances. To get a {kind} {name} instance, use a method ' - f'like `inst_new = inst.copy().{alternative}(...)` ' - 'instead.') - _validate_type(x, (np.ndarray, list, tuple), f'Data to be {kind}') + "This low-level function only operates on np.ndarray " + f"instances. To get a {kind} {name} instance, use a method " + f"like `inst_new = inst.copy().{alternative}(...)` " + "instead." 
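# A small sketch of the F-test threshold driving the detection branch above,
# mirroring _get_window_thresh: with K tapers the statistic has (2, 2K - 2)
# degrees of freedom, and the 1 - p point is Bonferroni-corrected by the
# number of time points. The values are arbitrary examples.
from scipy import stats

n_times, n_tapers, p_value = 10000, 7, 0.05
threshold = stats.f.ppf(1 - p_value / n_times, 2, 2 * n_tapers - 2)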
+ ) + _validate_type(x, (np.ndarray, list, tuple), f"Data to be {kind}") x = np.asanyarray(x) if x.dtype != np.float64: - raise ValueError('Data to be %s must be real floating, got %s' - % (kind, x.dtype,)) + raise ValueError( + "Data to be %s must be real floating, got %s" + % ( + kind, + x.dtype, + ) + ) return x @@ -1424,8 +1900,18 @@ def _resamp_ratio_len(up, down, n): @verbose -def resample(x, up=1., down=1., npad=100, axis=-1, window='boxcar', - n_jobs=None, pad='reflect_limited', *, verbose=None): +def resample( + x, + up=1.0, + down=1.0, + npad=100, + axis=-1, + window="boxcar", + n_jobs=None, + pad="reflect_limited", + *, + verbose=None, +): """Resample an array. Operates along the last dimension of the array. @@ -1469,16 +1955,19 @@ def resample(x, up=1., down=1., npad=100, axis=-1, window='boxcar', """ from scipy.signal import get_window from scipy.fft import ifftshift, fftfreq + # check explicitly for backwards compatibility if not isinstance(axis, int): - err = ("The axis parameter needs to be an integer (got %s). " - "The axis parameter was missing from this function for a " - "period of time, you might be intending to specify the " - "subsequent window parameter." % repr(axis)) + err = ( + "The axis parameter needs to be an integer (got %s). " + "The axis parameter was missing from this function for a " + "period of time, you might be intending to specify the " + "subsequent window parameter." % repr(axis) + ) raise TypeError(err) # make sure our arithmetic will work - x = _check_filterable(x, 'resampled', 'resample') + x = _check_filterable(x, "resampled", "resample") ratio, final_len = _resamp_ratio_len(up, down, x.shape[axis]) del up, down if axis < 0: @@ -1489,11 +1978,11 @@ def resample(x, up=1., down=1., npad=100, axis=-1, window='boxcar', orig_shape = x.shape x_len = orig_shape[-1] if x_len == 0: - warn('x has zero length along last axis, returning a copy of x') + warn("x has zero length along last axis, returning a copy of x") return x.copy() bad_msg = 'npad must be "auto" or an integer' if isinstance(npad, str): - if npad != 'auto': + if npad != "auto": raise ValueError(bad_msg) # Figure out reasonable pad that gets us to a power of 2 min_add = min(x_len // 8, 100) * 2 @@ -1520,14 +2009,13 @@ def resample(x, up=1., down=1., npad=100, axis=-1, window='boxcar', if window is not None: if callable(window): W = window(fftfreq(orig_len)) - elif isinstance(window, np.ndarray) and \ - window.shape == (orig_len,): + elif isinstance(window, np.ndarray) and window.shape == (orig_len,): W = window else: W = ifftshift(get_window(window, orig_len)) else: W = np.ones(orig_len) - W *= (float(new_len) / float(orig_len)) + W *= float(new_len) / float(orig_len) # figure out if we should use CUDA n_jobs, cuda_dict = _setup_cuda_fft_resample(n_jobs, W, new_len) @@ -1538,11 +2026,11 @@ def resample(x, up=1., down=1., npad=100, axis=-1, window='boxcar', if n_jobs == 1: y = np.zeros((len(x_flat), new_len - to_removes.sum()), dtype=x.dtype) for xi, x_ in enumerate(x_flat): - y[xi] = _fft_resample(x_, new_len, npads, to_removes, - cuda_dict, pad) + y[xi] = _fft_resample(x_, new_len, npads, to_removes, cuda_dict, pad) else: - y = parallel(p_fun(x_, new_len, npads, to_removes, cuda_dict, pad) - for x_ in x_flat) + y = parallel( + p_fun(x_, new_len, npads, to_removes, cuda_dict, pad) for x_ in x_flat + ) y = np.array(y) # Restore the original array shape (modified for resampling) @@ -1588,8 +2076,7 @@ def _resample_stim_channels(stim_data, up, down): # out-of-bounds, which can happen (having 
one sample more than # expected) due to padding sample_picks = np.minimum( - (np.arange(resampled_n_samples) / ratio).astype(int), - n_samples - 1 + (np.arange(resampled_n_samples) / ratio).astype(int), n_samples - 1 ) # Create windows starting from sample_picks[i], ending at sample_picks[i+1] @@ -1598,7 +2085,7 @@ def _resample_stim_channels(stim_data, up, down): # Use the first non-zero value in each window for window_i, window in enumerate(windows): for stim_num, stim in enumerate(stim_data): - nonzero = stim[window[0]:window[1]].nonzero()[0] + nonzero = stim[window[0] : window[1]].nonzero()[0] if len(nonzero) > 0: val = stim[window[0] + nonzero[0]] else: @@ -1637,14 +2124,15 @@ def detrend(x, order=1, axis=-1): True """ from scipy.signal import detrend + if axis > len(x.shape): - raise ValueError('x does not have %d axes' % axis) + raise ValueError("x does not have %d axes" % axis) if order == 0: - fit = 'constant' + fit = "constant" elif order == 1: - fit = 'linear' + fit = "linear" else: - raise ValueError('order must be 0 or 1') + raise ValueError("order must be 0 or 1") y = detrend(x, axis=axis, type=fit) @@ -1659,31 +2147,33 @@ def detrend(x, order=1, axis=-1): # (Hamming) then δs = 10 ** (53 / -20.), which means that the passband # deviation should be 20 * np.log10(1 + 10 ** (53 / -20.)) == 0.0194. _fir_window_dict = { - 'hann': dict(name='Hann', ripple=0.0546, attenuation=44), - 'hamming': dict(name='Hamming', ripple=0.0194, attenuation=53), - 'blackman': dict(name='Blackman', ripple=0.0017, attenuation=74), + "hann": dict(name="Hann", ripple=0.0546, attenuation=44), + "hamming": dict(name="Hamming", ripple=0.0194, attenuation=53), + "blackman": dict(name="Blackman", ripple=0.0017, attenuation=74), } _known_fir_windows = tuple(sorted(_fir_window_dict.keys())) -_known_phases_fir = ('linear', 'zero', 'zero-double', 'minimum') -_known_phases_iir = ('zero', 'zero-double', 'forward') -_known_fir_designs = ('firwin', 'firwin2') +_known_phases_fir = ("linear", "zero", "zero-double", "minimum") +_known_phases_iir = ("zero", "zero-double", "forward") +_known_fir_designs = ("firwin", "firwin2") _fir_design_dict = { - 'firwin': 'Windowed time-domain', - 'firwin2': 'Windowed frequency-domain', + "firwin": "Windowed time-domain", + "firwin2": "Windowed frequency-domain", } def _to_samples(filter_length, sfreq, phase, fir_design): - _validate_type(filter_length, (str, 'int-like'), 'filter_length') + _validate_type(filter_length, (str, "int-like"), "filter_length") if isinstance(filter_length, str): filter_length = filter_length.lower() - err_msg = ('filter_length, if a string, must be a ' - 'human-readable time, e.g. "10s", or "auto", not ' - '"%s"' % filter_length) - if filter_length.lower().endswith('ms'): + err_msg = ( + "filter_length, if a string, must be a " + 'human-readable time, e.g. 
"10s", or "auto", not ' + '"%s"' % filter_length + ) + if filter_length.lower().endswith("ms"): mult_fact = 1e-3 filter_length = filter_length[:-2] - elif filter_length[-1].lower() == 's': + elif filter_length[-1].lower() == "s": mult_fact = 1 filter_length = filter_length[:-1] else: @@ -1693,54 +2183,62 @@ def _to_samples(filter_length, sfreq, phase, fir_design): filter_length = float(filter_length) except ValueError: raise ValueError(err_msg) - filter_length = max(int(np.ceil(filter_length * mult_fact * - sfreq)), 1) - if fir_design == 'firwin': + filter_length = max(int(np.ceil(filter_length * mult_fact * sfreq)), 1) + if fir_design == "firwin": filter_length += (filter_length - 1) % 2 - filter_length = _ensure_int(filter_length, 'filter_length') + filter_length = _ensure_int(filter_length, "filter_length") return filter_length -def _triage_filter_params(x, sfreq, l_freq, h_freq, - l_trans_bandwidth, h_trans_bandwidth, - filter_length, method, phase, fir_window, - fir_design, bands='scalar', reverse=False): +def _triage_filter_params( + x, + sfreq, + l_freq, + h_freq, + l_trans_bandwidth, + h_trans_bandwidth, + filter_length, + method, + phase, + fir_window, + fir_design, + bands="scalar", + reverse=False, +): """Validate and automate filter parameter selection.""" - _validate_type(phase, 'str', 'phase') - if method == 'fir': - _check_option('phase', phase, _known_phases_fir, - extra='when FIR filtering') + _validate_type(phase, "str", "phase") + if method == "fir": + _check_option("phase", phase, _known_phases_fir, extra="when FIR filtering") else: - _check_option('phase', phase, _known_phases_iir, - extra='when IIR filtering') - _validate_type(fir_window, 'str', 'fir_window') - _check_option('fir_window', fir_window, _known_fir_windows) - _validate_type(fir_design, 'str', 'fir_design') - _check_option('fir_design', fir_design, _known_fir_designs) + _check_option("phase", phase, _known_phases_iir, extra="when IIR filtering") + _validate_type(fir_window, "str", "fir_window") + _check_option("fir_window", fir_window, _known_fir_windows) + _validate_type(fir_design, "str", "fir_design") + _check_option("fir_design", fir_design, _known_fir_designs) # Helpers for reporting - report_phase = 'non-linear phase' if phase == 'minimum' else 'zero-phase' - causality = 'causal' if phase == 'minimum' else 'non-causal' - if phase == 'zero-double': - report_pass = 'two-pass forward and reverse' + report_phase = "non-linear phase" if phase == "minimum" else "zero-phase" + causality = "causal" if phase == "minimum" else "non-causal" + if phase == "zero-double": + report_pass = "two-pass forward and reverse" else: - report_pass = 'one-pass' + report_pass = "one-pass" if l_freq is not None: if h_freq is not None: - kind = 'bandstop' if reverse else 'bandpass' + kind = "bandstop" if reverse else "bandpass" else: - kind = 'highpass' + kind = "highpass" assert not reverse elif h_freq is not None: - kind = 'lowpass' + kind = "lowpass" assert not reverse else: - kind = 'allpass' + kind = "allpass" def float_array(c): return np.array(c, float).ravel() - if bands == 'arr': + if bands == "arr": cast = float_array else: cast = float @@ -1748,164 +2246,193 @@ def float_array(c): if l_freq is not None: l_freq = cast(l_freq) if np.any(l_freq <= 0): - raise ValueError('highpass frequency %s must be greater than zero' - % (l_freq,)) + raise ValueError( + "highpass frequency %s must be greater than zero" % (l_freq,) + ) if h_freq is not None: h_freq = cast(h_freq) - if np.any(h_freq >= sfreq / 2.): - raise 
ValueError('lowpass frequency %s must be less than Nyquist ' - '(%s)' % (h_freq, sfreq / 2.)) + if np.any(h_freq >= sfreq / 2.0): + raise ValueError( + "lowpass frequency %s must be less than Nyquist " + "(%s)" % (h_freq, sfreq / 2.0) + ) dB_cutoff = False # meaning, don't try to compute or report - if bands == 'scalar' or (len(h_freq) == 1 and len(l_freq) == 1): - if phase == 'zero': - dB_cutoff = '-6 dB' - elif phase == 'zero-double': - dB_cutoff = '-12 dB' + if bands == "scalar" or (len(h_freq) == 1 and len(l_freq) == 1): + if phase == "zero": + dB_cutoff = "-6 dB" + elif phase == "zero-double": + dB_cutoff = "-12 dB" # we go to the next power of two when in FIR and zero-double mode - if method == 'iir': + if method == "iir": # Ignore these parameters, effectively l_stop, h_stop = l_freq, h_freq else: # method == 'fir' l_stop = h_stop = None - logger.info('') - logger.info('FIR filter parameters') - logger.info('---------------------') - logger.info('Designing a %s, %s, %s %s filter:' - % (report_pass, report_phase, causality, kind)) - logger.info('- %s design (%s) method' - % (_fir_design_dict[fir_design], fir_design)) + logger.info("") + logger.info("FIR filter parameters") + logger.info("---------------------") + logger.info( + "Designing a %s, %s, %s %s filter:" + % (report_pass, report_phase, causality, kind) + ) + logger.info( + "- %s design (%s) method" % (_fir_design_dict[fir_design], fir_design) + ) this_dict = _fir_window_dict[fir_window] - if fir_design == 'firwin': - logger.info('- {name:s} window with {ripple:0.4f} passband ripple ' - 'and {attenuation:d} dB stopband attenuation' - .format(**this_dict)) + if fir_design == "firwin": + logger.info( + "- {name:s} window with {ripple:0.4f} passband ripple " + "and {attenuation:d} dB stopband attenuation".format(**this_dict) + ) else: - logger.info('- {name:s} window'.format(**this_dict)) + logger.info("- {name:s} window".format(**this_dict)) if l_freq is not None: # high-pass component if isinstance(l_trans_bandwidth, str): - if l_trans_bandwidth != 'auto': - raise ValueError('l_trans_bandwidth must be "auto" if ' - 'string, got "%s"' % l_trans_bandwidth) - l_trans_bandwidth = np.minimum(np.maximum(0.25 * l_freq, 2.), - l_freq) + if l_trans_bandwidth != "auto": + raise ValueError( + 'l_trans_bandwidth must be "auto" if ' + 'string, got "%s"' % l_trans_bandwidth + ) + l_trans_bandwidth = np.minimum(np.maximum(0.25 * l_freq, 2.0), l_freq) l_trans_rep = np.array(l_trans_bandwidth, float) if l_trans_rep.size == 1: - l_trans_rep = f'{l_trans_rep.item():0.2f}' - with np.printoptions(precision=2, floatmode='fixed'): - msg = f'- Lower transition bandwidth: {l_trans_rep} Hz' + l_trans_rep = f"{l_trans_rep.item():0.2f}" + with np.printoptions(precision=2, floatmode="fixed"): + msg = f"- Lower transition bandwidth: {l_trans_rep} Hz" if dB_cutoff: l_freq_rep = np.array(l_freq, float) if l_freq_rep.size == 1: - l_freq_rep = f'{l_freq_rep.item():0.2f}' - cutoff_rep = np.array( - l_freq - l_trans_bandwidth / 2., float) + l_freq_rep = f"{l_freq_rep.item():0.2f}" + cutoff_rep = np.array(l_freq - l_trans_bandwidth / 2.0, float) if cutoff_rep.size == 1: - cutoff_rep = f'{cutoff_rep.item():0.2f}' + cutoff_rep = f"{cutoff_rep.item():0.2f}" # Could be an array - logger.info(f'- Lower passband edge: {l_freq_rep}') - msg += f' ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)' + logger.info(f"- Lower passband edge: {l_freq_rep}") + msg += f" ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)" logger.info(msg) l_trans_bandwidth = cast(l_trans_bandwidth) if 
np.any(l_trans_bandwidth <= 0): - raise ValueError('l_trans_bandwidth must be positive, got %s' - % (l_trans_bandwidth,)) + raise ValueError( + "l_trans_bandwidth must be positive, got %s" % (l_trans_bandwidth,) + ) l_stop = l_freq - l_trans_bandwidth if reverse: # band-stop style l_stop += l_trans_bandwidth l_freq += l_trans_bandwidth if np.any(l_stop < 0): - raise ValueError('Filter specification invalid: Lower stop ' - 'frequency negative (%0.2f Hz). Increase pass' - ' frequency or reduce the transition ' - 'bandwidth (l_trans_bandwidth)' % l_stop) + raise ValueError( + "Filter specification invalid: Lower stop " + "frequency negative (%0.2f Hz). Increase pass" + " frequency or reduce the transition " + "bandwidth (l_trans_bandwidth)" % l_stop + ) if h_freq is not None: # low-pass component if isinstance(h_trans_bandwidth, str): - if h_trans_bandwidth != 'auto': - raise ValueError('h_trans_bandwidth must be "auto" if ' - 'string, got "%s"' % h_trans_bandwidth) - h_trans_bandwidth = np.minimum(np.maximum(0.25 * h_freq, 2.), - sfreq / 2. - h_freq) + if h_trans_bandwidth != "auto": + raise ValueError( + 'h_trans_bandwidth must be "auto" if ' + 'string, got "%s"' % h_trans_bandwidth + ) + h_trans_bandwidth = np.minimum( + np.maximum(0.25 * h_freq, 2.0), sfreq / 2.0 - h_freq + ) h_trans_rep = np.array(h_trans_bandwidth, float) if h_trans_rep.size == 1: - h_trans_rep = f'{h_trans_rep.item():0.2f}' - with np.printoptions(precision=2, floatmode='fixed'): - msg = f'- Upper transition bandwidth: {h_trans_rep} Hz' + h_trans_rep = f"{h_trans_rep.item():0.2f}" + with np.printoptions(precision=2, floatmode="fixed"): + msg = f"- Upper transition bandwidth: {h_trans_rep} Hz" if dB_cutoff: h_freq_rep = np.array(h_freq, float) if h_freq_rep.size == 1: - h_freq_rep = f'{h_freq_rep.item():0.2f}' - cutoff_rep = np.array( - h_freq + h_trans_bandwidth / 2., float) + h_freq_rep = f"{h_freq_rep.item():0.2f}" + cutoff_rep = np.array(h_freq + h_trans_bandwidth / 2.0, float) if cutoff_rep.size == 1: - cutoff_rep = f'{cutoff_rep.item():0.2f}' - logger.info(f'- Upper passband edge: {h_freq_rep} Hz') - msg += f' ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)' + cutoff_rep = f"{cutoff_rep.item():0.2f}" + logger.info(f"- Upper passband edge: {h_freq_rep} Hz") + msg += f" ({dB_cutoff} cutoff frequency: {cutoff_rep} Hz)" logger.info(msg) h_trans_bandwidth = cast(h_trans_bandwidth) if np.any(h_trans_bandwidth <= 0): - raise ValueError('h_trans_bandwidth must be positive, got %s' - % (h_trans_bandwidth,)) + raise ValueError( + "h_trans_bandwidth must be positive, got %s" % (h_trans_bandwidth,) + ) h_stop = h_freq + h_trans_bandwidth if reverse: # band-stop style h_stop -= h_trans_bandwidth h_freq -= h_trans_bandwidth if np.any(h_stop > sfreq / 2): - raise ValueError('Effective band-stop frequency (%s) is too ' - 'high (maximum based on Nyquist is %s)' - % (h_stop, sfreq / 2.)) + raise ValueError( + "Effective band-stop frequency (%s) is too " + "high (maximum based on Nyquist is %s)" % (h_stop, sfreq / 2.0) + ) - if isinstance(filter_length, str) and filter_length.lower() == 'auto': + if isinstance(filter_length, str) and filter_length.lower() == "auto": filter_length = filter_length.lower() h_check = l_check = np.inf if h_freq is not None: h_check = min(np.atleast_1d(h_trans_bandwidth)) if l_freq is not None: l_check = min(np.atleast_1d(l_trans_bandwidth)) - mult_fact = 2. if fir_design == 'firwin2' else 1. 
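# A worked example of the automatic defaults validated above, assuming the
# Hamming entry of 3.3 in _length_factors defined elsewhere in this module;
# the band edges are illustrative values.
import numpy as np

sfreq, l_freq, h_freq = 1000.0, 1.0, 40.0
l_trans = min(max(0.25 * l_freq, 2.0), l_freq)                # -> 1.0 Hz
h_trans = min(max(0.25 * h_freq, 2.0), sfreq / 2.0 - h_freq)  # -> 10.0 Hz
n_taps = int(np.ceil(3.3 / min(l_trans, h_trans) * sfreq))    # the 'auto' length
n_taps += (n_taps - 1) % 2                                    # odd for firwin -> 3301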
- filter_length = '%ss' % (_length_factors[fir_window] * mult_fact / - float(min(h_check, l_check)),) + mult_fact = 2.0 if fir_design == "firwin2" else 1.0 + filter_length = "%ss" % ( + _length_factors[fir_window] * mult_fact / float(min(h_check, l_check)), + ) next_pow_2 = False # disable old behavior else: - next_pow_2 = ( - isinstance(filter_length, str) and phase == 'zero-double') + next_pow_2 = isinstance(filter_length, str) and phase == "zero-double" filter_length = _to_samples(filter_length, sfreq, phase, fir_design) # use correct type of filter (must be odd length for firwin and for # zero phase) - if fir_design == 'firwin' or phase == 'zero': + if fir_design == "firwin" or phase == "zero": filter_length += (filter_length - 1) % 2 - logger.info('- Filter length: %s samples (%0.3f s)' - % (filter_length, filter_length / sfreq)) - logger.info('') + logger.info( + "- Filter length: %s samples (%0.3f s)" + % (filter_length, filter_length / sfreq) + ) + logger.info("") if filter_length <= 0: - raise ValueError('filter_length must be positive, got %s' - % (filter_length,)) + raise ValueError( + "filter_length must be positive, got %s" % (filter_length,) + ) if next_pow_2: filter_length = 2 ** int(np.ceil(np.log2(filter_length))) - if fir_design == 'firwin': + if fir_design == "firwin": filter_length += (filter_length - 1) % 2 # If we have data supplied, do a sanity check if x is not None: x = _check_filterable(x) len_x = x.shape[-1] - if method != 'fir': + if method != "fir": filter_length = len_x if filter_length > len_x and not (l_freq is None and h_freq is None): - warn('filter_length (%s) is longer than the signal (%s), ' - 'distortion is likely. Reduce filter length or filter a ' - 'longer signal.' % (filter_length, len_x)) - - logger.debug('Using filter length: %s' % filter_length) - return (x, sfreq, l_freq, h_freq, l_stop, h_stop, filter_length, phase, - fir_window, fir_design) + warn( + "filter_length (%s) is longer than the signal (%s), " + "distortion is likely. Reduce filter length or filter a " + "longer signal." 
% (filter_length, len_x) + ) + + logger.debug("Using filter length: %s" % filter_length) + return ( + x, + sfreq, + l_freq, + h_freq, + l_stop, + h_stop, + filter_length, + phase, + fir_window, + fir_design, + ) class FilterMixin: @@ -1957,26 +2484,40 @@ def savgol_filter(self, h_freq, verbose=None): >>> evoked.plot() # doctest:+SKIP """ # noqa: E501 from scipy.signal import savgol_filter - _check_preload(self, 'inst.savgol_filter') + + _check_preload(self, "inst.savgol_filter") h_freq = float(h_freq) - if h_freq >= self.info['sfreq'] / 2.: - raise ValueError('h_freq must be less than half the sample rate') + if h_freq >= self.info["sfreq"] / 2.0: + raise ValueError("h_freq must be less than half the sample rate") # savitzky-golay filtering - window_length = (int(np.round(self.info['sfreq'] / - h_freq)) // 2) * 2 + 1 - logger.info('Using savgol length %d' % window_length) - self._data[:] = savgol_filter(self._data, axis=-1, polyorder=5, - window_length=window_length) + window_length = (int(np.round(self.info["sfreq"] / h_freq)) // 2) * 2 + 1 + logger.info("Using savgol length %d" % window_length) + self._data[:] = savgol_filter( + self._data, axis=-1, polyorder=5, window_length=window_length + ) return self @verbose - def filter(self, l_freq, h_freq, picks=None, filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto', n_jobs=None, - method='fir', iir_params=None, phase='zero', - fir_window='hamming', fir_design='firwin', - skip_by_annotation=('edge', 'bad_acq_skip'), pad='edge', *, - verbose=None): + def filter( + self, + l_freq, + h_freq, + picks=None, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + n_jobs=None, + method="fir", + iir_params=None, + phase="zero", + fir_window="hamming", + fir_design="firwin", + skip_by_annotation=("edge", "bad_acq_skip"), + pad="edge", + *, + verbose=None, + ): """Filter a subset of channels. Parameters @@ -2045,38 +2586,62 @@ def filter(self, l_freq, h_freq, picks=None, filter_length='auto', .. 
versionadded:: 0.15 """ from .io.base import BaseRaw - _check_preload(self, 'inst.filter') - if pad is None and method != 'iir': - pad = 'edge' - update_info, picks = _filt_check_picks(self.info, picks, - l_freq, h_freq) + + _check_preload(self, "inst.filter") + if pad is None and method != "iir": + pad = "edge" + update_info, picks = _filt_check_picks(self.info, picks, l_freq, h_freq) if isinstance(self, BaseRaw): # Deal with annotations onsets, ends = _annotations_starts_stops( - self, skip_by_annotation, invert=True) - logger.info('Filtering raw data in %d contiguous segment%s' - % (len(onsets), _pl(onsets))) + self, skip_by_annotation, invert=True + ) + logger.info( + "Filtering raw data in %d contiguous segment%s" + % (len(onsets), _pl(onsets)) + ) else: onsets, ends = np.array([0]), np.array([self._data.shape[1]]) max_idx = (ends - onsets).argmax() for si, (start, stop) in enumerate(zip(onsets, ends)): # Only output filter params once (for info level), and only warn # once about the length criterion (longest segment is too short) - use_verbose = verbose if si == max_idx else 'error' + use_verbose = verbose if si == max_idx else "error" filter_data( - self._data[:, start:stop], self.info['sfreq'], l_freq, h_freq, - picks, filter_length, l_trans_bandwidth, h_trans_bandwidth, - n_jobs, method, iir_params, copy=False, phase=phase, - fir_window=fir_window, fir_design=fir_design, pad=pad, - verbose=use_verbose) + self._data[:, start:stop], + self.info["sfreq"], + l_freq, + h_freq, + picks, + filter_length, + l_trans_bandwidth, + h_trans_bandwidth, + n_jobs, + method, + iir_params, + copy=False, + phase=phase, + fir_window=fir_window, + fir_design=fir_design, + pad=pad, + verbose=use_verbose, + ) # update info if filter is applied to all data channels, # and it's not a band-stop filter _filt_update_info(self.info, update_info, l_freq, h_freq) return self @verbose - def resample(self, sfreq, npad='auto', window='boxcar', n_jobs=None, - pad='edge', *, verbose=None): + def resample( + self, + sfreq, + npad="auto", + window="boxcar", + n_jobs=None, + pad="edge", + *, + verbose=None, + ): """Resample data. If appropriate, an anti-aliasing filter is applied before resampling. @@ -2114,23 +2679,26 @@ def resample(self, sfreq, npad='auto', window='boxcar', n_jobs=None, """ from .epochs import BaseEpochs from .evoked import Evoked + # Should be guaranteed by our inheritance, and the fact that # mne.io.base.BaseRaw overrides this method assert isinstance(self, (BaseEpochs, Evoked)) - _check_preload(self, 'inst.resample') + _check_preload(self, "inst.resample") sfreq = float(sfreq) - o_sfreq = self.info['sfreq'] - self._data = resample(self._data, sfreq, o_sfreq, npad, window=window, - n_jobs=n_jobs, pad=pad) - lowpass = self.info.get('lowpass') + o_sfreq = self.info["sfreq"] + self._data = resample( + self._data, sfreq, o_sfreq, npad, window=window, n_jobs=n_jobs, pad=pad + ) + lowpass = self.info.get("lowpass") lowpass = np.inf if lowpass is None else lowpass with self.info._unlock(): - self.info['lowpass'] = min(lowpass, sfreq / 2.) 
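# A hedged usage sketch for the filter/resample API shown above; the file
# name is a placeholder and the cutoffs/rate are illustrative choices.
import mne

raw = mne.io.read_raw_fif("sample_raw.fif", preload=True)  # hypothetical file
raw.filter(l_freq=1.0, h_freq=40.0)  # zero-phase FIR band-pass, per segment
raw.resample(200.0)                  # anti-aliased; also updates info["sfreq"]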
- self.info['sfreq'] = float(sfreq) - new_times = (np.arange(self._data.shape[-1], dtype=np.float64) / - sfreq + self.times[0]) + self.info["lowpass"] = min(lowpass, sfreq / 2.0) + self.info["sfreq"] = float(sfreq) + new_times = ( + np.arange(self._data.shape[-1], dtype=np.float64) / sfreq + self.times[0] + ) # adjust indirectly affected variables self._set_times(new_times) self._raw_times = self.times @@ -2138,8 +2706,9 @@ def resample(self, sfreq, npad='auto', window='boxcar', n_jobs=None, return self @verbose - def apply_hilbert(self, picks=None, envelope=False, n_jobs=None, - n_fft='auto', *, verbose=None): + def apply_hilbert( + self, picks=None, envelope=False, n_jobs=None, n_fft="auto", *, verbose=None + ): """Compute analytic signal or envelope for a subset of channels. Parameters @@ -2203,18 +2772,22 @@ def apply_hilbert(self, picks=None, envelope=False, n_jobs=None, by computing the analytic signal in sensor space, applying the MNE inverse, and computing the envelope in source space. """ - _check_preload(self, 'inst.apply_hilbert') + _check_preload(self, "inst.apply_hilbert") if n_fft is None: n_fft = len(self.times) elif isinstance(n_fft, str): - if n_fft != 'auto': - raise ValueError('n_fft must be an integer, string, or None, ' - 'got %s' % (type(n_fft),)) + if n_fft != "auto": + raise ValueError( + "n_fft must be an integer, string, or None, " + "got %s" % (type(n_fft),) + ) n_fft = next_fast_len(len(self.times)) n_fft = int(n_fft) if n_fft < len(self.times): - raise ValueError("n_fft (%d) must be at least the number of time " - "points (%d)" % (n_fft, len(self.times))) + raise ValueError( + "n_fft (%d) must be at least the number of time " + "points (%d)" % (n_fft, len(self.times)) + ) dtype = None if envelope else np.complex128 picks = _picks_to_idx(self.info, picks, exclude=(), with_ref_meg=False) args, kwargs = (), dict(n_fft=n_fft, envelope=envelope) @@ -2228,12 +2801,13 @@ def apply_hilbert(self, picks=None, envelope=False, n_jobs=None, # modify data inplace to save memory for idx in picks: self._data[..., idx, :] = _check_fun( - _my_hilbert, data_in[..., idx, :], *args, **kwargs) + _my_hilbert, data_in[..., idx, :], *args, **kwargs + ) else: # use parallel function data_picks_new = parallel( - p_fun(_my_hilbert, data_in[..., p, :], *args, **kwargs) - for p in picks) + p_fun(_my_hilbert, data_in[..., p, :], *args, **kwargs) for p in picks + ) for pp, p in enumerate(picks): self._data[..., p, :] = data_picks_new[pp] return self @@ -2244,10 +2818,11 @@ def _check_fun(fun, d, *args, **kwargs): want_shape = d.shape d = fun(d, *args, **kwargs) if not isinstance(d, np.ndarray): - raise TypeError('Return value must be an ndarray') + raise TypeError("Return value must be an ndarray") if d.shape != want_shape: - raise ValueError('Return data must have shape %s not %s' - % (want_shape, d.shape)) + raise ValueError( + "Return data must have shape %s not %s" % (want_shape, d.shape) + ) return d @@ -2271,6 +2846,7 @@ def _my_hilbert(x, n_fft=None, envelope=False): The hilbert transform of the signal, or the envelope. 
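# What apply_hilbert computes per channel, in plain SciPy terms: pad the FFT
# to a fast length, take the analytic signal, crop, and optionally keep the
# magnitude. The random data is an illustrative stand-in for real channels.
import numpy as np
from scipy.fft import next_fast_len
from scipy.signal import hilbert

x = np.random.RandomState(0).randn(2, 1000)
n_fft = next_fast_len(x.shape[-1])          # the n_fft="auto" choice above
analytic = hilbert(x, N=n_fft, axis=-1)[..., : x.shape[-1]]
envelope = np.abs(analytic)                 # the envelope=True result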
""" from scipy.signal import hilbert + n_x = x.shape[-1] out = hilbert(x, N=n_fft, axis=-1)[..., :n_x] if envelope: @@ -2279,9 +2855,14 @@ def _my_hilbert(x, n_fft=None, envelope=False): @verbose -def design_mne_c_filter(sfreq, l_freq=None, h_freq=40., - l_trans_bandwidth=None, h_trans_bandwidth=5., - verbose=None): +def design_mne_c_filter( + sfreq, + l_freq=None, + h_freq=40.0, + l_trans_bandwidth=None, + h_trans_bandwidth=5.0, + verbose=None, +): """Create a FIR filter like that used by MNE-C. Parameters @@ -2315,39 +2896,39 @@ def design_mne_c_filter(sfreq, l_freq=None, h_freq=40., and ones in the passband, with squared cosine ramps in between. """ from scipy.fft import irfft + n_freqs = (4096 + 2 * 2048) // 2 + 1 freq_resp = np.ones(n_freqs) l_freq = 0 if l_freq is None else float(l_freq) if l_trans_bandwidth is None: l_width = 3 else: - l_width = (int(((n_freqs - 1) * l_trans_bandwidth) / - (0.5 * sfreq)) + 1) // 2 + l_width = (int(((n_freqs - 1) * l_trans_bandwidth) / (0.5 * sfreq)) + 1) // 2 l_start = int(((n_freqs - 1) * l_freq) / (0.5 * sfreq)) - h_freq = sfreq / 2. if h_freq is None else float(h_freq) - h_width = (int(((n_freqs - 1) * h_trans_bandwidth) / - (0.5 * sfreq)) + 1) // 2 + h_freq = sfreq / 2.0 if h_freq is None else float(h_freq) + h_width = (int(((n_freqs - 1) * h_trans_bandwidth) / (0.5 * sfreq)) + 1) // 2 h_start = int(((n_freqs - 1) * h_freq) / (0.5 * sfreq)) - logger.info('filter : %7.3f ... %6.1f Hz bins : %d ... %d of %d ' - 'hpw : %d lpw : %d' % (l_freq, h_freq, l_start, h_start, - n_freqs, l_width, h_width)) + logger.info( + "filter : %7.3f ... %6.1f Hz bins : %d ... %d of %d " + "hpw : %d lpw : %d" + % (l_freq, h_freq, l_start, h_start, n_freqs, l_width, h_width) + ) if l_freq > 0: start = l_start - l_width + 1 stop = start + 2 * l_width - 1 if start < 0 or stop >= n_freqs: - raise RuntimeError('l_freq too low or l_trans_bandwidth too large') - freq_resp[:start] = 0. - k = np.arange(-l_width + 1, l_width) / float(l_width) + 3. - freq_resp[start:stop] = np.cos(np.pi / 4. * k) ** 2 + raise RuntimeError("l_freq too low or l_trans_bandwidth too large") + freq_resp[:start] = 0.0 + k = np.arange(-l_width + 1, l_width) / float(l_width) + 3.0 + freq_resp[start:stop] = np.cos(np.pi / 4.0 * k) ** 2 - if h_freq < sfreq / 2.: + if h_freq < sfreq / 2.0: start = h_start - h_width + 1 stop = start + 2 * h_width - 1 if start < 0 or stop >= n_freqs: - raise RuntimeError('h_freq too high or h_trans_bandwidth too ' - 'large') - k = np.arange(-h_width + 1, h_width) / float(h_width) + 1. - freq_resp[start:stop] *= np.cos(np.pi / 4. * k) ** 2 + raise RuntimeError("h_freq too high or h_trans_bandwidth too " "large") + k = np.arange(-h_width + 1, h_width) / float(h_width) + 1.0 + freq_resp[start:stop] *= np.cos(np.pi / 4.0 * k) ** 2 freq_resp[stop:] = 0.0 # Get the time-domain version of this signal h = irfft(freq_resp, n=2 * len(freq_resp) - 1) @@ -2357,32 +2938,44 @@ def design_mne_c_filter(sfreq, l_freq=None, h_freq=40., def _filt_check_picks(info, picks, h_freq, l_freq): from .io.pick import _picks_to_idx + update_info = False # This will pick *all* data channels - picks = _picks_to_idx(info, picks, 'data_or_ica', exclude=()) + picks = _picks_to_idx(info, picks, "data_or_ica", exclude=()) if h_freq is not None or l_freq is not None: - data_picks = _picks_to_idx(info, None, 'data_or_ica', exclude=(), - allow_empty=True) + data_picks = _picks_to_idx( + info, None, "data_or_ica", exclude=(), allow_empty=True + ) if len(data_picks) == 0: - logger.info('No data channels found. 
The highpass and ' - 'lowpass values in the measurement info will not ' - 'be updated.') + logger.info( + "No data channels found. The highpass and " + "lowpass values in the measurement info will not " + "be updated." + ) elif np.in1d(data_picks, picks).all(): update_info = True else: - logger.info('Filtering a subset of channels. The highpass and ' - 'lowpass values in the measurement info will not ' - 'be updated.') + logger.info( + "Filtering a subset of channels. The highpass and " + "lowpass values in the measurement info will not " + "be updated." + ) return update_info, picks def _filt_update_info(info, update_info, l_freq, h_freq): if update_info: - if h_freq is not None and (l_freq is None or l_freq < h_freq) and \ - (info["lowpass"] is None or h_freq < info['lowpass']): + if ( + h_freq is not None + and (l_freq is None or l_freq < h_freq) + and (info["lowpass"] is None or h_freq < info["lowpass"]) + ): with info._unlock(): - info['lowpass'] = float(h_freq) - if l_freq is not None and (h_freq is None or l_freq < h_freq) and \ - (info["highpass"] is None or l_freq > info['highpass']): + info["lowpass"] = float(h_freq) + if ( + l_freq is not None + and (h_freq is None or l_freq < h_freq) + and (info["highpass"] is None or l_freq > info["highpass"]) + ): with info._unlock(): - info['highpass'] = float(l_freq) + info["highpass"] = float(l_freq) diff --git a/mne/fixes.py b/mne/fixes.py index b59439b3b88..c05dfaec344 100644 --- a/mne/fixes.py +++ b/mne/fixes.py @@ -30,6 +30,7 @@ # from the standard library with the release of Python 3.12. For version # comparisons, we use setuptools's `parse_version` if available. + def _compare_version(version_a, operator, version_b): """Compare two version strings via a user-specified operator. @@ -49,14 +50,16 @@ def _compare_version(version_a, operator, version_b): The result of the version comparison. """ from packaging.version import parse + with warnings.catch_warnings(record=True): - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") return eval(f'parse("{version_a}") {operator} parse("{version_b}")') ############################################################################### # Misc + def _median_complex(data, axis): """Compute marginal median on complex data safely. @@ -65,8 +68,9 @@ def _median_complex(data, axis): """ # np.median must be passed real arrays for the desired result if np.iscomplexobj(data): - data = (np.median(np.real(data), axis=axis) - + 1j * np.median(np.imag(data), axis=axis)) + data = np.median(np.real(data), axis=axis) + 1j * np.median( + np.imag(data), axis=axis + ) else: data = np.median(data, axis=axis) return data @@ -79,19 +83,21 @@ def _safe_svd(A, **kwargs): # For SciPy 0.18 and up, we can work around it by using # lapack_driver='gesvd' instead. 
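# What _compare_version evaluates, written out with packaging directly; the
# version strings are arbitrary examples.
from packaging.version import parse

assert parse("0.53.0") < parse("0.53.1")  # _compare_version("0.53.0", "<", "0.53.1")
assert parse("1.10") > parse("1.9")       # true version ordering, not lexicographic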
from scipy import linalg - if kwargs.get('overwrite_a', False): - raise ValueError('Cannot set overwrite_a=True with this function') + + if kwargs.get("overwrite_a", False): + raise ValueError("Cannot set overwrite_a=True with this function") try: return linalg.svd(A, **kwargs) except np.linalg.LinAlgError as exp: from .utils import warn - warn('SVD error (%s), attempting to use GESVD instead of GESDD' - % (exp,)) - return linalg.svd(A, lapack_driver='gesvd', **kwargs) + + warn("SVD error (%s), attempting to use GESVD instead of GESDD" % (exp,)) + return linalg.svd(A, lapack_driver="gesvd", **kwargs) def _csc_matrix_cast(x): from scipy.sparse import csc_matrix + return csc_matrix(x) @@ -102,25 +108,26 @@ def _csc_matrix_cast(x): def rng_uniform(rng): """Get the unform/randint from the rng.""" # prefer Generator.integers, fall back to RandomState.randint - return getattr(rng, 'integers', getattr(rng, 'randint', None)) + return getattr(rng, "integers", getattr(rng, "randint", None)) def _validate_sos(sos): """Helper to validate a SOS input""" sos = np.atleast_2d(sos) if sos.ndim != 2: - raise ValueError('sos array must be 2D') + raise ValueError("sos array must be 2D") n_sections, m = sos.shape if m != 6: - raise ValueError('sos array must be shape (n_sections, 6)') + raise ValueError("sos array must be shape (n_sections, 6)") if not (sos[:, 3] == 1).all(): - raise ValueError('sos[:, 3] should be all ones') + raise ValueError("sos[:, 3] should be all ones") return sos, n_sections ############################################################################### # Misc utilities + # get_fdata() requires knowing the dtype ahead of time, so let's triage on our # own instead def _get_img_fdata(img): @@ -134,22 +141,30 @@ def _read_volume_info(fobj): versions of nibabel (<=2.1.0) don't have it. 
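# How rng_uniform (above) unifies the two NumPy RNG generations; seeds and
# ranges are arbitrary examples.
import numpy as np

def rng_uniform(rng):
    return getattr(rng, "integers", getattr(rng, "randint", None))

new_draws = rng_uniform(np.random.default_rng(0))(0, 10, size=5)  # Generator.integers
old_draws = rng_uniform(np.random.RandomState(0))(0, 10, size=5)  # RandomState.randint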
""" volume_info = dict() - head = np.fromfile(fobj, '>i4', 1) + head = np.fromfile(fobj, ">i4", 1) if not np.array_equal(head, [20]): # Read two bytes more - head = np.concatenate([head, np.fromfile(fobj, '>i4', 2)]) + head = np.concatenate([head, np.fromfile(fobj, ">i4", 2)]) if not np.array_equal(head, [2, 0, 20]): warnings.warn("Unknown extension code.") return volume_info - volume_info['head'] = head - for key in ['valid', 'filename', 'volume', 'voxelsize', 'xras', 'yras', - 'zras', 'cras']: - pair = fobj.readline().decode('utf-8').split('=') + volume_info["head"] = head + for key in [ + "valid", + "filename", + "volume", + "voxelsize", + "xras", + "yras", + "zras", + "cras", + ]: + pair = fobj.readline().decode("utf-8").split("=") if pair[0].strip() != key or len(pair) != 2: - raise OSError('Error parsing volume info.') - if key in ('valid', 'filename'): + raise OSError("Error parsing volume info.") + if key in ("valid", "filename"): volume_info[key] = pair[1].strip() - elif key == 'volume': + elif key == "volume": volume_info[key] = np.array(pair[1].split()).astype(int) else: volume_info[key] = np.array(pair[1].split()).astype(float) @@ -194,24 +209,24 @@ def is_regressor(estimator): _DEFAULT_TAGS = { - 'non_deterministic': False, - 'requires_positive_X': False, - 'requires_positive_y': False, - 'X_types': ['2darray'], - 'poor_score': False, - 'no_validation': False, - 'multioutput': False, + "non_deterministic": False, + "requires_positive_X": False, + "requires_positive_y": False, + "X_types": ["2darray"], + "poor_score": False, + "no_validation": False, + "multioutput": False, "allow_nan": False, - 'stateless': False, - 'multilabel': False, - '_skip_test': False, - '_xfail_checks': False, - 'multioutput_only': False, - 'binary_only': False, - 'requires_fit': True, - 'preserves_dtype': [np.float64], - 'requires_y': False, - 'pairwise': False, + "stateless": False, + "multilabel": False, + "_skip_test": False, + "_xfail_checks": False, + "multioutput_only": False, + "binary_only": False, + "requires_fit": True, + "preserves_dtype": [np.float64], + "requires_y": False, + "pairwise": False, } @@ -230,7 +245,7 @@ def _get_param_names(cls): """Get parameter names for the estimator""" # fetch the constructor or the original constructor before # deprecation wrapping if any - init = getattr(cls.__init__, 'deprecated_original', cls.__init__) + init = getattr(cls.__init__, "deprecated_original", cls.__init__) if init is object.__init__: # No explicit constructor to introspect return [] @@ -239,16 +254,20 @@ def _get_param_names(cls): # to represent init_signature = inspect.signature(init) # Consider the constructor parameters excluding 'self' - parameters = [p for p in init_signature.parameters.values() - if p.name != 'self' and p.kind != p.VAR_KEYWORD] + parameters = [ + p + for p in init_signature.parameters.values() + if p.name != "self" and p.kind != p.VAR_KEYWORD + ] for p in parameters: if p.kind == p.VAR_POSITIONAL: - raise RuntimeError("scikit-learn estimators should always " - "specify their parameters in the signature" - " of their __init__ (no varargs)." - " %s with constructor %s doesn't " - " follow this convention." - % (cls, init_signature)) + raise RuntimeError( + "scikit-learn estimators should always " + "specify their parameters in the signature" + " of their __init__ (no varargs)." + " %s with constructor %s doesn't " + " follow this convention." 
% (cls, init_signature) + ) # Extract and sort argument names excluding 'self' return sorted([p.name for p in parameters]) @@ -283,9 +302,9 @@ def get_params(self, deep=True): warnings.filters.pop(0) # XXX: should we rather test if instance of estimator? - if deep and hasattr(value, 'get_params'): + if deep and hasattr(value, "get_params"): deep_items = value.get_params().items() - out.update((key + '__' + k, val) for k, val in deep_items) + out.update((key + "__" + k, val) for k, val in deep_items) out[key] = value return out @@ -312,24 +331,27 @@ def set_params(self, **params): return self valid_params = self.get_params(deep=True) for key, value in params.items(): - split = key.split('__', 1) + split = key.split("__", 1) if len(split) > 1: # nested objects case name, sub_name = split if name not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (name, self)) + raise ValueError( + "Invalid parameter %s for estimator %s. " + "Check the list of available parameters " + "with `estimator.get_params().keys()`." % (name, self) + ) sub_object = valid_params[name] sub_object.set_params(**{sub_name: value}) else: # simple objects case if key not in valid_params: - raise ValueError('Invalid parameter %s for estimator %s. ' - 'Check the list of available parameters ' - 'with `estimator.get_params().keys()`.' % - (key, self.__class__.__name__)) + raise ValueError( + "Invalid parameter %s for estimator %s. " + "Check the list of available parameters " + "with `estimator.get_params().keys()`." + % (key, self.__class__.__name__) + ) setattr(self, key, value) return self @@ -338,7 +360,7 @@ def __repr__(self): pprint(self.get_params(deep=False), params) params.seek(0) class_name = self.__class__.__name__ - return '%s(%s)' % (class_name, params.read().strip()) + return "%s(%s)" % (class_name, params.read().strip()) # __getstate__ and __setstate__ are omitted because they only contain # conditionals that are not satisfied by our objects (e.g., @@ -350,7 +372,7 @@ def _more_tags(self): def _get_tags(self): collected_tags = {} for base_class in reversed(inspect.getmro(self.__class__)): - if hasattr(base_class, '_more_tags'): + if hasattr(base_class, "_more_tags"): # need the if because mixins might not have _more_tags # but might do redundant work in estimators # (i.e. calling more tags on BaseEstimator multiple times) @@ -391,21 +413,25 @@ def _check_fit_params(X, fit_params, indices=None): indexing. """ try: - from sklearn.utils.validation import \ - _check_fit_params as _sklearn_check_fit_params + from sklearn.utils.validation import ( + _check_fit_params as _sklearn_check_fit_params, + ) + return _sklearn_check_fit_params(X, fit_params, indices) except ImportError: from sklearn.model_selection import _validation - fit_params_validated = \ - {k: _validation._index_param_value(X, v, indices) - for k, v in fit_params.items()} + fit_params_validated = { + k: _validation._index_param_value(X, v, indices) + for k, v in fit_params.items() + } return fit_params_validated ############################################################################### # Copied from sklearn to simplify code paths + def empirical_covariance(X, assume_centered=False): """Computes the Maximum likelihood covariance estimator @@ -432,8 +458,9 @@ def empirical_covariance(X, assume_centered=False): X = np.reshape(X, (1, -1)) if X.shape[0] == 1: - warnings.warn("Only one sample available. 
" - "You may want to reshape your data array") + warnings.warn( + "Only one sample available. " "You may want to reshape your data array" + ) if assume_centered: covariance = np.dot(X.T, X) / X.shape[0] @@ -471,6 +498,7 @@ class EmpiricalCovariance(BaseEstimator): (stored only if store_precision is True) """ + def __init__(self, store_precision=True, assume_centered=False): self.store_precision = store_precision self.assume_centered = assume_centered @@ -489,6 +517,7 @@ def _set_covariance(self, covariance): """ from scipy import linalg + # covariance = check_array(covariance) # set covariance self.covariance_ = covariance @@ -508,6 +537,7 @@ def get_precision(self): """ from scipy import linalg + if self.store_precision: precision = self.precision_ else: @@ -535,8 +565,7 @@ def fit(self, X, y=None): self.location_ = np.zeros(X.shape[1]) else: self.location_ = X.mean(0) - covariance = empirical_covariance( - X, assume_centered=self.assume_centered) + covariance = empirical_covariance(X, assume_centered=self.assume_centered) self._set_covariance(covariance) return self @@ -563,15 +592,13 @@ def score(self, X_test, y=None): estimator of its covariance matrix. """ # compute empirical covariance of the test set - test_cov = empirical_covariance( - X_test - self.location_, assume_centered=True) + test_cov = empirical_covariance(X_test - self.location_, assume_centered=True) # compute log likelihood res = log_likelihood(test_cov, self.get_precision()) return res - def error_norm(self, comp_cov, norm='frobenius', scaling=True, - squared=True): + def error_norm(self, comp_cov, norm="frobenius", scaling=True, squared=True): """Computes the Mean Squared Error between two covariance estimators. Parameters @@ -597,16 +624,18 @@ def error_norm(self, comp_cov, norm='frobenius', scaling=True, `self` and `comp_cov` covariance estimators. """ from scipy import linalg + # compute the error error = comp_cov - self.covariance_ # compute the error norm if norm == "frobenius": - squared_norm = np.sum(error ** 2) + squared_norm = np.sum(error**2) elif norm == "spectral": squared_norm = np.amax(linalg.svdvals(np.dot(error.T, error))) else: raise NotImplementedError( - "Only spectral and frobenius norms are implemented") + "Only spectral and frobenius norms are implemented" + ) # optionally scale the error norm if scaling: squared_norm = squared_norm / error.shape[0] @@ -637,8 +666,7 @@ def mahalanobis(self, observations): precision = self.get_precision() # compute mahalanobis distances centered_obs = observations - self.location_ - mahalanobis_dist = np.sum( - np.dot(centered_obs, precision) * centered_obs, 1) + mahalanobis_dist = np.sum(np.dot(centered_obs, precision) * centered_obs, 1) return mahalanobis_dist @@ -663,17 +691,19 @@ def log_likelihood(emp_cov, precision): sample mean of the log-likelihood """ p = precision.shape[0] - log_likelihood_ = - np.sum(emp_cov * precision) + _logdet(precision) + log_likelihood_ = -np.sum(emp_cov * precision) + _logdet(precision) log_likelihood_ -= p * np.log(2 * np.pi) - log_likelihood_ /= 2. 
+ log_likelihood_ /= 2.0 return log_likelihood_ # sklearn uses np.linalg for this, but ours is more robust to zero eigenvalues + def _logdet(A): """Compute the log det of a positive semidefinite matrix.""" from scipy import linalg + vals = linalg.eigvalsh(A) # avoid negative (numerical errors) or zero (semi-definite matrix) values tol = vals.max() * vals.size * np.finfo(np.float64).eps @@ -694,37 +724,37 @@ def _infer_dimension_(spectrum, n_samples, n_features): def _assess_dimension_(spectrum, rank, n_samples, n_features): from scipy.special import gammaln + if rank > len(spectrum): - raise ValueError("The tested rank cannot exceed the rank of the" - " dataset") + raise ValueError("The tested rank cannot exceed the rank of the" " dataset") - pu = -rank * log(2.) + pu = -rank * log(2.0) for i in range(rank): - pu += (gammaln((n_features - i) / 2.) - - log(np.pi) * (n_features - i) / 2.) + pu += gammaln((n_features - i) / 2.0) - log(np.pi) * (n_features - i) / 2.0 pl = np.sum(np.log(spectrum[:rank])) - pl = -pl * n_samples / 2. + pl = -pl * n_samples / 2.0 if rank == n_features: pv = 0 v = 1 else: v = np.sum(spectrum[rank:]) / (n_features - rank) - pv = -np.log(v) * n_samples * (n_features - rank) / 2. + pv = -np.log(v) * n_samples * (n_features - rank) / 2.0 - m = n_features * rank - rank * (rank + 1.) / 2. - pp = log(2. * np.pi) * (m + rank + 1.) / 2. + m = n_features * rank - rank * (rank + 1.0) / 2.0 + pp = log(2.0 * np.pi) * (m + rank + 1.0) / 2.0 - pa = 0. + pa = 0.0 spectrum_ = spectrum.copy() spectrum_[rank:n_features] = v for i in range(rank): for j in range(i + 1, len(spectrum)): - pa += log((spectrum[i] - spectrum[j]) * - (1. / spectrum_[j] - 1. / spectrum_[i])) + log(n_samples) + pa += log( + (spectrum[i] - spectrum[j]) * (1.0 / spectrum_[j] - 1.0 / spectrum_[i]) + ) + log(n_samples) - ll = pu + pl + pv + pp - pa / 2. - rank * log(n_samples) / 2. + ll = pu + pl + pv + pp - pa / 2.0 - rank * log(n_samples) / 2.0 return ll @@ -762,23 +792,30 @@ def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08): """ out = np.cumsum(arr, axis=axis, dtype=np.float64) expected = np.sum(arr, axis=axis, dtype=np.float64) - if not np.all(np.isclose(out.take(-1, axis=axis), expected, rtol=rtol, - atol=atol, equal_nan=True)): - warnings.warn('cumsum was found to be unstable: ' - 'its last element does not correspond to sum', - RuntimeWarning) + if not np.all( + np.isclose( + out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True + ) + ): + warnings.warn( + "cumsum was found to be unstable: " + "its last element does not correspond to sum", + RuntimeWarning, + ) return out ############################################################################### # From nilearn + def _crop_colorbar(cbar, cbar_vmin, cbar_vmax): """ crop a colorbar to show from cbar_vmin to cbar_vmax Used when symmetric_cbar=False is used. 
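# stable_cumsum (above) accumulates in float64 and warns when the last
# element no longer matches the float64 total; dtype and size here are
# arbitrary examples.
import numpy as np
from mne.fixes import stable_cumsum

arr = np.full(10000, 0.01, dtype=np.float32)
out = stable_cumsum(arr, axis=0)  # float64 accumulation, checked against the sum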
""" import matplotlib + if (cbar_vmin is None) and (cbar_vmax is None): return cbar_tick_locs = cbar.locator.locs @@ -786,8 +823,7 @@ def _crop_colorbar(cbar, cbar_vmin, cbar_vmax): cbar_vmax = cbar_tick_locs.max() if cbar_vmin is None: cbar_vmin = cbar_tick_locs.min() - new_tick_locs = np.linspace(cbar_vmin, cbar_vmax, - len(cbar_tick_locs)) + new_tick_locs = np.linspace(cbar_vmin, cbar_vmax, len(cbar_tick_locs)) cbar.ax.set_ylim(cbar_vmin, cbar_vmax) X = cbar._mesh()[0] @@ -797,9 +833,11 @@ def _crop_colorbar(cbar, cbar_vmin, cbar_vmax): ii = [0, 1, N - 2, N - 1, 2 * N - 1, 2 * N - 2, N + 1, N, 0] x = X.T.reshape(-1)[ii] y = Y.T.reshape(-1)[ii] - xy = (np.column_stack([y, x]) - if cbar.orientation == 'horizontal' else - np.column_stack([x, y])) + xy = ( + np.column_stack([y, x]) + if cbar.orientation == "horizontal" + else np.column_stack([x, y]) + ) cbar.outline.set_xy(xy) cbar.set_ticks(new_tick_locs) @@ -812,29 +850,36 @@ def _crop_colorbar(cbar, cbar_vmin, cbar_vmax): # Here we choose different defaults to speed things up by default try: import numba - if _compare_version(numba.__version__, '<', '0.53.1'): + + if _compare_version(numba.__version__, "<", "0.53.1"): raise ImportError prange = numba.prange - def jit(nopython=True, nogil=True, fastmath=True, cache=True, - **kwargs): # noqa - return numba.jit(nopython=nopython, nogil=nogil, fastmath=fastmath, - cache=cache, **kwargs) + + def jit(nopython=True, nogil=True, fastmath=True, cache=True, **kwargs): # noqa + return numba.jit( + nopython=nopython, nogil=nogil, fastmath=fastmath, cache=cache, **kwargs + ) + except Exception: # could be ImportError, SystemError, etc. has_numba = False else: - has_numba = (os.getenv('MNE_USE_NUMBA', 'true').lower() == 'true') + has_numba = os.getenv("MNE_USE_NUMBA", "true").lower() == "true" if not has_numba: + def jit(**kwargs): # noqa def _jit(func): return func + return _jit + prange = range bincount = np.bincount mean = np.mean else: + @jit() def bincount(x, weights, minlength): # noqa: D103 out = np.zeros(minlength) @@ -865,6 +910,7 @@ def mean(array, axis): ############################################################################### # Matplotlib + # workaround: plt.close() doesn't spawn close_event on Agg backend # https://github.com/matplotlib/matplotlib/issues/18609 # scheduled to be fixed by MPL 3.6 @@ -872,13 +918,15 @@ def _close_event(fig): """Force calling of the MPL figure close event.""" from .utils import logger from matplotlib import backend_bases + try: fig.canvas.callbacks.process( - 'close_event', backend_bases.CloseEvent( - name='close_event', canvas=fig.canvas)) - logger.debug(f'Called {fig!r}.canvas.close_event()') + "close_event", + backend_bases.CloseEvent(name="close_event", canvas=fig.canvas), + ) + logger.debug(f"Called {fig!r}.canvas.close_event()") except ValueError: # old mpl with Qt - logger.debug(f'Calling {fig!r}.canvas.close_event() failed') + logger.debug(f"Calling {fig!r}.canvas.close_event() failed") pass # pragma: no cover @@ -891,7 +939,7 @@ def _is_last_row(ax): def _sharex(ax1, ax2): - if hasattr(ax1.axes, 'sharex'): + if hasattr(ax1.axes, "sharex"): ax1.axes.sharex(ax2) else: ax1.get_shared_x_axes().join(ax1, ax2) @@ -900,6 +948,7 @@ def _sharex(ax1, ax2): ############################################################################### # SciPy deprecation of pinv + pinvh rcond (never worked properly anyway) in 1.7 + def pinvh(a, rtol=None): """Compute a pseudo-inverse of a Hermitian matrix.""" s, u = np.linalg.eigh(a) @@ -907,7 +956,7 @@ def pinvh(a, 
rtol=None): if rtol is None: rtol = s.size * np.finfo(s.dtype).eps maxS = np.max(np.abs(s)) - above_cutoff = (abs(s) > maxS * rtol) + above_cutoff = abs(s) > maxS * rtol psigma_diag = 1.0 / s[above_cutoff] u = u[:, above_cutoff] return (u * psigma_diag) @ u.conj().T @@ -929,9 +978,12 @@ def pinv(a, rtol=None): ############################################################################### # h5py uses np.product which is deprecated in NumPy 1.25 + @contextmanager def _numpy_h5py_dep(): # h5io uses np.product with warnings.catch_warnings(record=True): - warnings.filterwarnings('ignore', '`product` is deprecated.*', DeprecationWarning) + warnings.filterwarnings( + "ignore", "`product` is deprecated.*", DeprecationWarning + ) yield diff --git a/mne/forward/__init__.py b/mne/forward/__init__.py index 83788b8f706..c5fbeced9a4 100644 --- a/mne/forward/__init__.py +++ b/mne/forward/__init__.py @@ -1,22 +1,48 @@ """Forward modeling code.""" -from .forward import (Forward, read_forward_solution, write_forward_solution, - is_fixed_orient, _read_forward_meas_info, - _select_orient_forward, - compute_orient_prior, compute_depth_prior, - apply_forward, apply_forward_raw, - restrict_forward_to_stc, restrict_forward_to_label, - average_forward_solutions, _stc_src_sel, - _fill_measurement_info, _apply_forward, - _subject_from_forward, convert_forward_solution, - _merge_fwds, _do_forward_solution) -from ._make_forward import (make_forward_solution, _prepare_for_forward, - _prep_meg_channels, _prep_eeg_channels, - _to_forward_dict, _create_meg_coils, - _read_coil_defs, _transform_orig_meg_coils, - make_forward_dipole, use_coil_def) -from ._compute_forward import (_magnetic_dipole_field_vec, _compute_forwards, - _concatenate_coils) -from ._field_interpolation import (_make_surface_mapping, make_field_map, - _as_meg_type_inst, _map_meg_or_eeg_channels) +from .forward import ( + Forward, + read_forward_solution, + write_forward_solution, + is_fixed_orient, + _read_forward_meas_info, + _select_orient_forward, + compute_orient_prior, + compute_depth_prior, + apply_forward, + apply_forward_raw, + restrict_forward_to_stc, + restrict_forward_to_label, + average_forward_solutions, + _stc_src_sel, + _fill_measurement_info, + _apply_forward, + _subject_from_forward, + convert_forward_solution, + _merge_fwds, + _do_forward_solution, +) +from ._make_forward import ( + make_forward_solution, + _prepare_for_forward, + _prep_meg_channels, + _prep_eeg_channels, + _to_forward_dict, + _create_meg_coils, + _read_coil_defs, + _transform_orig_meg_coils, + make_forward_dipole, + use_coil_def, +) +from ._compute_forward import ( + _magnetic_dipole_field_vec, + _compute_forwards, + _concatenate_coils, +) +from ._field_interpolation import ( + _make_surface_mapping, + make_field_map, + _as_meg_type_inst, + _map_meg_or_eeg_channels, +) from . 
import _lead_dots  # for testing purposes
diff --git a/mne/forward/_compute_forward.py b/mne/forward/_compute_forward.py
index 9b4ee7dba1c..5fe55906220 100644
--- a/mne/forward/_compute_forward.py
+++ b/mne/forward/_compute_forward.py
@@ -30,21 +30,22 @@
 # #############################################################################
 # COIL SPECIFICATION AND FIELD COMPUTATION MATRIX
 
+
 def _dup_coil_set(coils, coord_frame, t):
     """Make a duplicate."""
-    if t is not None and coord_frame != t['from']:
-        raise RuntimeError('transformation frame does not match the coil set')
+    if t is not None and coord_frame != t["from"]:
+        raise RuntimeError("transformation frame does not match the coil set")
     coils = deepcopy(coils)
     if t is not None:
-        coord_frame = t['to']
+        coord_frame = t["to"]
         for coil in coils:
-            for key in ('ex', 'ey', 'ez'):
+            for key in ("ex", "ey", "ez"):
                 if key in coil:
-                    coil[key] = apply_trans(t['trans'], coil[key], False)
-            coil['r0'] = apply_trans(t['trans'], coil['r0'])
-            coil['rmag'] = apply_trans(t['trans'], coil['rmag'])
-            coil['cosmag'] = apply_trans(t['trans'], coil['cosmag'], False)
-            coil['coord_frame'] = t['to']
+                    coil[key] = apply_trans(t["trans"], coil[key], False)
+            coil["r0"] = apply_trans(t["trans"], coil["r0"])
+            coil["rmag"] = apply_trans(t["trans"], coil["rmag"])
+            coil["cosmag"] = apply_trans(t["trans"], coil["cosmag"], False)
+            coil["coord_frame"] = t["to"]
     return coils, coord_frame
 
 
@@ -53,10 +54,9 @@ def _check_coil_frame(coils, coord_frame, bem):
     if coord_frame != FIFF.FIFFV_COORD_MRI:
         if coord_frame == FIFF.FIFFV_COORD_HEAD:
             # Make a transformed duplicate
-            coils, coord_Frame = _dup_coil_set(coils, coord_frame,
-                                               bem['head_mri_t'])
+            coils, coord_frame = _dup_coil_set(coils, coord_frame, bem["head_mri_t"])
         else:
-            raise RuntimeError('Bad coil coordinate frame %s' % coord_frame)
+            raise RuntimeError("Bad coil coordinate frame %s" % coord_frame)
     return coils, coord_frame
 
 
@@ -88,12 +88,17 @@ def _lin_field_coeff(surf, mult, rmags, cosmags, ws, bins, n_jobs):
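    Notes
    -----
    A minimal sketch of the parallel split used below (illustrative only;
    the real call goes through ``parallel_func``)::

        nas = np.array_split
        coeffs = [
            _do_lin_field_coeff(surf["rr"], t, tn, ta, rmags, cosmags, ws, bins)
            for t, tn, ta in zip(
                nas(surf["tris"], n_jobs),
                nas(surf["tri_nn"], n_jobs),
                nas(surf["tri_area"], n_jobs),
            )
        ]
        coeff = mult * np.sum(coeffs, axis=0)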
""" parallel, p_fun, n_jobs = parallel_func( - _do_lin_field_coeff, n_jobs, max_jobs=len(surf['tris'])) + _do_lin_field_coeff, n_jobs, max_jobs=len(surf["tris"]) + ) nas = np.array_split - coeffs = parallel(p_fun(surf['rr'], t, tn, ta, rmags, cosmags, ws, bins) - for t, tn, ta in zip(nas(surf['tris'], n_jobs), - nas(surf['tri_nn'], n_jobs), - nas(surf['tri_area'], n_jobs))) + coeffs = parallel( + p_fun(surf["rr"], t, tn, ta, rmags, cosmags, ws, bins) + for t, tn, ta in zip( + nas(surf["tris"], n_jobs), + nas(surf["tri_nn"], n_jobs), + nas(surf["tri_area"], n_jobs), + ) + ) return mult * np.sum(coeffs, axis=0) @@ -154,22 +159,21 @@ def _do_lin_field_coeff(bem_rr, tris, tn, ta, rmags, cosmags, ws, bins): for ti in range(3): x = np.sum(c[:, ti], axis=-1) x /= den[:, tri[ti]] / tri_area - coeff[:, tri[ti]] += \ - bincount(bins, weights=x, minlength=bins[-1] + 1) + coeff[:, tri[ti]] += bincount(bins, weights=x, minlength=bins[-1] + 1) return coeff def _concatenate_coils(coils): """Concatenate MEG coil parameters.""" - rmags = np.concatenate([coil['rmag'] for coil in coils]) - cosmags = np.concatenate([coil['cosmag'] for coil in coils]) - ws = np.concatenate([coil['w'] for coil in coils]) - n_int = np.array([len(coil['rmag']) for coil in coils]) + rmags = np.concatenate([coil["rmag"] for coil in coils]) + cosmags = np.concatenate([coil["cosmag"] for coil in coils]) + ws = np.concatenate([coil["w"] for coil in coils]) + n_int = np.array([len(coil["rmag"]) for coil in coils]) if n_int[-1] == 0: # We assume each sensor has at least one integration point, # which should be a safe assumption. But let's check it here, since # our code elsewhere relies on bins[-1] + 1 being the number of sensors - raise RuntimeError('not supported') + raise RuntimeError("not supported") bins = np.repeat(np.arange(len(n_int)), n_int) return rmags, cosmags, ws, bins @@ -208,8 +212,8 @@ def _bem_specify_coils(bem, coils, coord_frame, mults, n_jobs): # Process each of the surfaces rmags, cosmags, ws, bins = _triage_coils(coils) del coils - lens = np.cumsum(np.r_[0, [len(s['rr']) for s in bem['surfs']]]) - sol = np.zeros((bins[-1] + 1, bem['solution'].shape[1])) + lens = np.cumsum(np.r_[0, [len(s["rr"]) for s in bem["surfs"]]]) + sol = np.zeros((bins[-1] + 1, bem["solution"].shape[1])) lims = np.concatenate([np.arange(0, sol.shape[0], 100), [sol.shape[0]]]) # Put through the bem (in channel-based chunks to save memory) @@ -217,10 +221,11 @@ def _bem_specify_coils(bem, coils, coord_frame, mults, n_jobs): mask = np.logical_and(bins >= start, bins < stop) r, c, w, b = rmags[mask], cosmags[mask], ws[mask], bins[mask] - start # Compute coeffs for each surface, one at a time - for o1, o2, surf, mult in zip(lens[:-1], lens[1:], - bem['surfs'], bem['field_mult']): + for o1, o2, surf, mult in zip( + lens[:-1], lens[1:], bem["surfs"], bem["field_mult"] + ): coeff = _lin_field_coeff(surf, mult, r, c, w, b, n_jobs) - sol[start:stop] += np.dot(coeff, bem['solution'][o1:o2]) + sol[start:stop] += np.dot(coeff, bem["solution"][o1:o2]) sol *= mults return sol @@ -242,20 +247,22 @@ def _bem_specify_els(bem, els, mults): sol : ndarray, shape (n_EEG_sensors, n_BEM_vertices) EEG solution """ - sol = np.zeros((len(els), bem['solution'].shape[1])) - scalp = bem['surfs'][0] + sol = np.zeros((len(els), bem["solution"].shape[1])) + scalp = bem["surfs"][0] # Operate on all integration points for all electrodes (in MRI coords) - rrs = np.concatenate([apply_trans(bem['head_mri_t']['trans'], el['rmag']) - for el in els], axis=0) - ws = 
np.concatenate([el['w'] for el in els])
+    rrs = np.concatenate(
+        [apply_trans(bem["head_mri_t"]["trans"], el["rmag"]) for el in els], axis=0
+    )
+    ws = np.concatenate([el["w"] for el in els])
     tri_weights, tri_idx = _project_onto_surface(rrs, scalp)
     tri_weights *= ws[:, np.newaxis]
-    weights = np.matmul(tri_weights[:, np.newaxis],
-                        bem['solution'][scalp['tris'][tri_idx]])[:, 0]
+    weights = np.matmul(
+        tri_weights[:, np.newaxis], bem["solution"][scalp["tris"][tri_idx]]
+    )[:, 0]
     # there are way more vertices than electrodes generally, so let's iterate
     # over the electrodes
-    edges = np.concatenate([[0], np.cumsum([len(el['w']) for el in els])])
+    edges = np.concatenate([[0], np.cumsum([len(el["w"]) for el in els])])
     for ii, (start, stop) in enumerate(zip(edges[:-1], edges[1:])):
         sol[ii] = weights[start:stop].sum(0)
     sol *= mults
@@ -302,7 +309,7 @@ def _bem_inf_pots(mri_rr, bem_rr, mri_Q=None):
         this_diff = bem_rr - rr
         diff_norm = np.sum(this_diff * this_diff, axis=1)
         diff_norm *= np.sqrt(diff_norm)
-        diff_norm[diff_norm == 0] = 1.
+        diff_norm[diff_norm == 0] = 1.0
         if mri_Q is not None:
             this_diff = np.dot(this_diff, mri_Q.T)
         this_diff /= diff_norm.reshape(-1, 1)
@@ -310,6 +317,7 @@
     return diff
 
+
 # This function has been refactored to process all points simultaneously
 # def _bem_inf_field(rd, Q, rp, d):
 #     """Infinite-medium magnetic field. See (7) in Mosher, 1999"""
@@ -370,8 +378,7 @@ def _bem_inf_fields(rr, rmag, cosmag):
 
 
 @fill_doc
-def _bem_pot_or_field(rr, mri_rr, mri_Q, coils, solution, bem_rr, n_jobs,
-                      coil_type):
+def _bem_pot_or_field(rr, mri_rr, mri_Q, coils, solution, bem_rr, n_jobs, coil_type):
     """Calculate the magnetic field or electric potential forward solution.
 
     The code is very similar between EEG and MEG potentials, so they are combined.
@@ -404,22 +411,25 @@ def _bem_pot_or_field(rr, mri_rr, mri_Q, coils, solution, bem_rr, n_jobs,
     # Both MEG and EEG have the infinite-medium potentials
     # This could be just vectorized, but eats too much memory, so instead we
     # reduce memory by chunking within _do_inf_pots and parallelize, too:
-    parallel, p_fun, n_jobs = parallel_func(
-        _do_inf_pots, n_jobs, max_jobs=len(rr))
+    parallel, p_fun, n_jobs = parallel_func(_do_inf_pots, n_jobs, max_jobs=len(rr))
     nas = np.array_split
-    B = np.sum(parallel(p_fun(mri_rr, sr.copy(), np.ascontiguousarray(mri_Q),
-                              np.array(sol))  # copy and contig
-                        for sr, sol in zip(nas(bem_rr, n_jobs),
-                                           nas(solution.T, n_jobs))), axis=0)
+    B = np.sum(
+        parallel(
+            p_fun(
+                mri_rr, sr.copy(), np.ascontiguousarray(mri_Q), np.array(sol)
+            )  # copy and contig
+            for sr, sol in zip(nas(bem_rr, n_jobs), nas(solution.T, n_jobs))
+        ),
+        axis=0,
+    )
     # The copy()s above should make it so the whole objects don't need to be
     # pickled...
 
     # Only MEG coils are sensitive to the primary current distribution.
-    if coil_type == 'meg':
+    if coil_type == "meg":
         # Primary current contribution (can be calc.
in coil/dipole coords) parallel, p_fun, n_jobs = parallel_func(_do_prim_curr, n_jobs) - pcc = np.concatenate(parallel(p_fun(r, coils) - for r in nas(rr, n_jobs)), axis=0) + pcc = np.concatenate(parallel(p_fun(r, coils) for r in nas(rr, n_jobs)), axis=0) B += pcc B *= _MAG_FACTOR return B @@ -451,8 +461,9 @@ def _do_prim_curr(rr, coils): pp = _bem_inf_fields(rr[start:stop], rmags, cosmags) pp *= ws pp.shape = (3 * (stop - start), -1) - pc[3 * start:3 * stop] = [bincount(bins, this_pp, bins[-1] + 1) - for this_pp in pp] + pc[3 * start : 3 * stop] = [ + bincount(bins, this_pp, bins[-1] + 1) for this_pp in pp + ] return pc @@ -493,21 +504,21 @@ def _do_inf_pots(mri_rr, bem_rr, mri_Q, sol): # v0 in Hämäläinen et al., 1989 == v_inf in Mosher, et al., 1999 v0s = _bem_inf_pots(mri_rr[start:stop], bem_rr, mri_Q) v0s = v0s.reshape(-1, v0s.shape[2]) - B[3 * start:3 * stop] = np.dot(v0s, sol) + B[3 * start : 3 * stop] = np.dot(v0s, sol) return B # ############################################################################# # SPHERE COMPUTATION -def _sphere_pot_or_field(rr, mri_rr, mri_Q, coils, solution, bem_rr, - n_jobs, coil_type): + +def _sphere_pot_or_field(rr, mri_rr, mri_Q, coils, solution, bem_rr, n_jobs, coil_type): """Do potential or field for spherical model.""" - fun = _eeg_spherepot_coil if coil_type == 'eeg' else _sphere_field - parallel, p_fun, n_jobs = parallel_func( - fun, n_jobs, max_jobs=len(rr)) - B = np.concatenate(parallel(p_fun(r, coils, sphere=solution) - for r in np.array_split(rr, n_jobs))) + fun = _eeg_spherepot_coil if coil_type == "eeg" else _sphere_field + parallel, p_fun, n_jobs = parallel_func(fun, n_jobs, max_jobs=len(rr)) + B = np.concatenate( + parallel(p_fun(r, coils, sphere=solution) for r in np.array_split(rr, n_jobs)) + ) return B @@ -521,7 +532,7 @@ def _sphere_field(rrs, coils, sphere): by Matti Hämäläinen, February 1990 """ rmags, cosmags, ws, bins = _triage_coils(coils) - return _do_sphere_field(rrs, rmags, cosmags, ws, bins, sphere['r0']) + return _do_sphere_field(rrs, rmags, cosmags, ws, bins, sphere["r0"]) @jit() @@ -557,8 +568,9 @@ def _do_sphere_field(rrs, rmags, cosmags, ws, bins, r0): _jit_cross(v1, rr_, cosmags) v2 = np.empty((cosmags.shape[0], 3)) _jit_cross(v2, rr_, this_poss) - xx = ((good * ws).reshape(-1, 1) * - (v1 / F.reshape(-1, 1) + v2 * g.reshape(-1, 1))) + xx = (good * ws).reshape(-1, 1) * ( + v1 / F.reshape(-1, 1) + v2 * g.reshape(-1, 1) + ) for jj in range(3): zz = bincount(bins, xx[:, jj], n_coils) B[3 * ri + jj, :] = zz @@ -573,24 +585,24 @@ def _eeg_spherepot_coil(rrs, coils, sphere): del coils # Shift to the sphere model coordinates - rrs = rrs - sphere['r0'] + rrs = rrs - sphere["r0"] B = np.zeros((3 * len(rrs), n_coils)) for ri, rr in enumerate(rrs): # Only process dipoles inside the innermost sphere - if np.sqrt(np.dot(rr, rr)) >= sphere['layers'][0]['rad']: + if np.sqrt(np.dot(rr, rr)) >= sphere["layers"][0]["rad"]: continue # fwd_eeg_spherepot_vec vval_one = np.zeros((len(rmags), 3)) # Make a weighted sum over the equivalence parameters - for eq in range(sphere['nfit']): + for eq in range(sphere["nfit"]): # Scale the dipole position - rd = sphere['mu'][eq] * rr + rd = sphere["mu"][eq] * rr rd2 = np.sum(rd * rd) rd2_inv = 1.0 / rd2 # Go over all electrodes - this_pos = rmags - sphere['r0'] + this_pos = rmags - sphere["r0"] # Scale location onto the surface of the sphere (not used) # if sphere['scale_pos']: @@ -616,17 +628,19 @@ def _eeg_spherepot_coil(rrs, coils, sphere): c2 = a3 + (a + r) / (r * F) # Mix them together and scale 
by lambda/(rd*rd) - m1 = (c1 - c2 * rrd) + m1 = c1 - c2 * rrd m2 = c2 * rd2 - vval_one += (sphere['lambda'][eq] * rd2_inv * - (m1[:, np.newaxis] * rd + - m2[:, np.newaxis] * this_pos)) + vval_one += ( + sphere["lambda"][eq] + * rd2_inv + * (m1[:, np.newaxis] * rd + m2[:, np.newaxis] * this_pos) + ) # compute total result xx = vval_one * ws[:, np.newaxis] zz = np.array([bincount(bins, x, bins[-1] + 1) for x in xx.T]) - B[3 * ri:3 * ri + 3, :] = zz + B[3 * ri : 3 * ri + 3, :] = zz # finishing by scaling by 1/(4*M_PI) B *= 0.25 / np.pi return B @@ -642,14 +656,14 @@ def _triage_coils(coils): _MIN_DIST_LIMIT = 1e-5 -def _magnetic_dipole_field_vec(rrs, coils, too_close='raise'): +def _magnetic_dipole_field_vec(rrs, coils, too_close="raise"): rmags, cosmags, ws, bins = _triage_coils(coils) fwd, min_dist = _compute_mdfv(rrs, rmags, cosmags, ws, bins, too_close) if min_dist < _MIN_DIST_LIMIT: - msg = 'Coil too close (dist = %g mm)' % (min_dist * 1000,) - if too_close == 'raise': + msg = "Coil too close (dist = %g mm)" % (min_dist * 1000,) + if too_close == "raise": raise RuntimeError(msg) - func = warn if too_close == 'warning' else logger.info + func = warn if too_close == "warning" else logger.info func(msg) return fwd @@ -682,7 +696,7 @@ def _compute_mdfv(rrs, rmags, cosmags, ws, bins, too_close): dist2 = dist2_.reshape(-1, 1) dist = np.sqrt(dist2) min_dist = min(dist.min(), min_dist) - if min_dist < _MIN_DIST_LIMIT and too_close == 'raise': + if min_dist < _MIN_DIST_LIMIT and too_close == "raise": break t_ = np.sum(diff * cosmags, axis=1) t = t_.reshape(-1, 1) @@ -696,6 +710,7 @@ def _compute_mdfv(rrs, rmags, cosmags, ws, bins, too_close): # ############################################################################# # MAIN TRIAGING FUNCTION + @verbose def _prep_field_computation(rr, *, sensors, bem, n_jobs, verbose=None): """Precompute and store some things that are used for both MEG and EEG. 
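
    A rough usage sketch (this is how ``_compute_forwards`` consumes the
    result; it assumes ``rr``, ``sensors``, and ``bem`` are already
    validated)::

        fwd_data = _prep_field_computation(rr, sensors=sensors, bem=bem, n_jobs=1)
        Bs = _compute_forwards_meeg(rr, sensors=sensors, fwd_data=fwd_data, n_jobs=1)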
@@ -717,44 +732,47 @@ def _prep_field_computation(rr, *, sensors, bem, n_jobs, verbose=None):
     %(verbose)s
     """
     bem_rr = mults = mri_Q = head_mri_t = None
-    if not bem['is_sphere']:
-        if bem['bem_method'] != FIFF.FIFFV_BEM_APPROX_LINEAR:
-            raise RuntimeError('only linear collocation supported')
+    if not bem["is_sphere"]:
+        if bem["bem_method"] != FIFF.FIFFV_BEM_APPROX_LINEAR:
+            raise RuntimeError("only linear collocation supported")
         # Store (and apply soon) μ_0/(4π) factor before source computations
-        mults = np.repeat(bem['source_mult'] / (4.0 * np.pi),
-                          [len(s['rr']) for s in bem['surfs']])[np.newaxis, :]
+        mults = np.repeat(
+            bem["source_mult"] / (4.0 * np.pi), [len(s["rr"]) for s in bem["surfs"]]
+        )[np.newaxis, :]
         # Get positions of BEM points for every surface
-        bem_rr = np.concatenate([s['rr'] for s in bem['surfs']])
+        bem_rr = np.concatenate([s["rr"] for s in bem["surfs"]])
         # The dipole location and orientation must be transformed
-        head_mri_t = bem['head_mri_t']
-        mri_Q = bem['head_mri_t']['trans'][:3, :3].T
+        head_mri_t = bem["head_mri_t"]
+        mri_Q = bem["head_mri_t"]["trans"][:3, :3].T
 
     solutions = dict()
     for coil_type in sensors:
-        coils = sensors[coil_type]['defs']
-        if not bem['is_sphere']:
-            if coil_type == 'meg':
+        coils = sensors[coil_type]["defs"]
+        if not bem["is_sphere"]:
+            if coil_type == "meg":
                 # MEG field computation matrices for BEM
-                start = 'Composing the field computation matrix'
-                logger.info('\n' + start + '...')
+                start = "Composing the field computation matrix"
+                logger.info("\n" + start + "...")
                 cf = FIFF.FIFFV_COORD_HEAD
                 # multiply solution by "mults" here for simplicity
                 solution = _bem_specify_coils(bem, coils, cf, mults, n_jobs)
             else:
                 # Compute solution for EEG sensor
-                logger.info('Setting up for EEG...')
+                logger.info("Setting up for EEG...")
                 solution = _bem_specify_els(bem, coils, mults)
         else:
             solution = bem
-            if coil_type == 'eeg':
-                logger.info('Using the equivalent source approach in the '
-                            'homogeneous sphere for EEG')
-            sensors[coil_type]['defs'] = _triage_coils(coils)
+            if coil_type == "eeg":
+                logger.info(
+                    "Using the equivalent source approach in the "
+                    "homogeneous sphere for EEG"
+                )
+            sensors[coil_type]["defs"] = _triage_coils(coils)
         solutions[coil_type] = solution
 
     # Get appropriate forward physics function depending on sphere or BEM model
-    fun = _sphere_pot_or_field if bem['is_sphere'] else _bem_pot_or_field
+    fun = _sphere_pot_or_field if bem["is_sphere"] else _bem_pot_or_field
 
     # Update fwd_data with
     # bem_rr (3D BEM vertex positions)
@@ -764,8 +782,8 @@ def _prep_field_computation(rr, *, sensors, bem, n_jobs, verbose=None):
     # solutions (dict keyed by coil type, 'meg' and/or 'eeg'; each value is an
     #            ndarray, shape (n_sensors, n_BEM_vertices))
     fwd_data = dict(
-        bem_rr=bem_rr, mri_Q=mri_Q, head_mri_t=head_mri_t, fun=fun,
-        solutions=solutions)
+        bem_rr=bem_rr, mri_Q=mri_Q, head_mri_t=head_mri_t, fun=fun, solutions=solutions
+    )
     return fwd_data
 
 
@@ -775,26 +793,34 @@ def _compute_forwards_meeg(rr, *, sensors, fwd_data, n_jobs, silent=False):
     Bs = dict()
     # The dipole location and orientation must be transformed to mri coords
     mri_rr = None
-    if fwd_data['head_mri_t'] is not None:
-        mri_rr = np.ascontiguousarray(
-            apply_trans(fwd_data['head_mri_t']['trans'], rr))
-    mri_Q, bem_rr, fun = fwd_data['mri_Q'], fwd_data['bem_rr'], fwd_data['fun']
-    solutions = fwd_data['solutions']
+    if fwd_data["head_mri_t"] is not None:
+        mri_rr = np.ascontiguousarray(apply_trans(fwd_data["head_mri_t"]["trans"], rr))
+    mri_Q, bem_rr, fun = fwd_data["mri_Q"], fwd_data["bem_rr"], fwd_data["fun"]
+    solutions = fwd_data["solutions"]
     del fwd_data
     for coil_type, sens in sensors.items():
-        coils = sens['defs']
-        compensator = sens.get('compensator', None)
-        post_picks = sens.get('post_picks', None)
+        coils = sens["defs"]
+        compensator = sens.get("compensator", None)
+        post_picks = sens.get("post_picks", None)
         solution = solutions.get(coil_type, None)
 
         # Do the actual forward calculation for a list of MEG/EEG sensors
         if not silent:
-            logger.info('Computing %s at %d source location%s '
-                        '(free orientations)...'
-                        % (coil_type.upper(), len(rr), _pl(rr)))
+            logger.info(
+                "Computing %s at %d source location%s "
+                "(free orientations)..." % (coil_type.upper(), len(rr), _pl(rr))
+            )
         # Calculate forward solution using spherical or BEM model
-        B = fun(rr, mri_rr, mri_Q, coils=coils, solution=solution,
-                bem_rr=bem_rr, n_jobs=n_jobs, coil_type=coil_type)
+        B = fun(
+            rr,
+            mri_rr,
+            mri_Q,
+            coils=coils,
+            solution=solution,
+            bem_rr=bem_rr,
+            n_jobs=n_jobs,
+            coil_type=coil_type,
+        )
 
         # Compensate if needed (only done for MEG systems w/compensation)
         if compensator is not None:
@@ -810,16 +836,16 @@ def _compute_forwards(rr, *, bem, sensors, n_jobs, verbose=None):
     """Compute the MEG and EEG forward solutions."""
     # Split calculation into two steps to save (potentially) a lot of time
    # when doing e.g. dipole fitting
-    solver = bem.get('solver', 'mne')
-    _check_option('solver', solver, ('mne', 'openmeeg'))
-    if bem['is_sphere'] or solver == 'mne':
-        fwd_data = _prep_field_computation(
-            rr, sensors=sensors, bem=bem, n_jobs=n_jobs)
+    solver = bem.get("solver", "mne")
+    _check_option("solver", solver, ("mne", "openmeeg"))
+    if bem["is_sphere"] or solver == "mne":
+        fwd_data = _prep_field_computation(rr, sensors=sensors, bem=bem, n_jobs=n_jobs)
         Bs = _compute_forwards_meeg(
-            rr, sensors=sensors, fwd_data=fwd_data, n_jobs=n_jobs)
+            rr, sensors=sensors, fwd_data=fwd_data, n_jobs=n_jobs
+        )
     else:
         Bs = _compute_forwards_openmeeg(rr, bem=bem, sensors=sensors)
-    n_sensors_want = sum(len(s['ch_names']) for s in sensors.values())
+    n_sensors_want = sum(len(s["ch_names"]) for s in sensors.values())
     n_sensors = sum(B.shape[1] for B in Bs.values())
     n_sources = list(Bs.values())[0].shape[0]
     assert (n_sources, n_sensors) == (len(rr) * 3, n_sensors_want)
@@ -830,30 +856,30 @@ def _compute_forwards_openmeeg(rr, *, bem, sensors):
     """Compute the MEG and EEG forward solutions for OpenMEEG."""
     if len(bem["surfs"]) != 3:
         raise RuntimeError("Only 3-layer BEM is supported for OpenMEEG.")
-    om = _import_openmeeg('compute a forward solution using OpenMEEG')
+    om = _import_openmeeg("compute a forward solution using OpenMEEG")
     hminv = om.SymMatrix(bem["solution"])
-    geom = _make_openmeeg_geometry(bem, invert_transform(bem['head_mri_t']))
+    geom = _make_openmeeg_geometry(bem, invert_transform(bem["head_mri_t"]))
 
     # Make dipoles for all XYZ orientations
     dipoles = np.c_[
         np.kron(rr.T, np.ones(3)[None, :]).T,
-        np.kron(np.ones(len(rr))[:, None],
-                np.eye(3)),
+        np.kron(np.ones(len(rr))[:, None], np.eye(3)),
     ]
     dipoles = np.asfortranarray(dipoles)
     dipoles = om.Matrix(dipoles)
     dsm = om.DipSourceMat(geom, dipoles, "Brain")
     Bs = dict()
-    if 'eeg' in sensors:
-        rmags, _, ws, bins = _concatenate_coils(sensors['eeg']['defs'])
+    if "eeg" in sensors:
+        rmags, _, ws, bins = _concatenate_coils(sensors["eeg"]["defs"])
         rmags = np.asfortranarray(rmags.astype(np.float64))
         eeg_sensors = om.Sensors(om.Matrix(np.asfortranarray(rmags)), geom)
         h2em = om.Head2EEGMat(geom, eeg_sensors)
         eeg_fwd_full = om.GainEEG(hminv, dsm,
h2em).array() - Bs['eeg'] = np.array([bincount(bins, ws * x, bins[-1] + 1) - for x in eeg_fwd_full.T], float) - if 'meg' in sensors: - rmags, cosmags, ws, bins = _concatenate_coils(sensors['meg']['defs']) + Bs["eeg"] = np.array( + [bincount(bins, ws * x, bins[-1] + 1) for x in eeg_fwd_full.T], float + ) + if "meg" in sensors: + rmags, cosmags, ws, bins = _concatenate_coils(sensors["meg"]["defs"]) rmags = np.asfortranarray(rmags.astype(np.float64)) cosmags = np.asfortranarray(cosmags.astype(np.float64)) labels = [str(ii) for ii in range(len(rmags))] @@ -862,13 +888,14 @@ def _compute_forwards_openmeeg(rr, *, bem, sensors): h2mm = om.Head2MEGMat(geom, meg_sensors) ds2mm = om.DipSource2MEGMat(dipoles, meg_sensors) meg_fwd_full = om.GainMEG(hminv, dsm, h2mm, ds2mm).array() - B = np.array([bincount(bins, ws * x, bins[-1] + 1) - for x in meg_fwd_full.T], float) - compensator = sensors['meg'].get('compensator', None) - post_picks = sensors['meg'].get('post_picks', None) + B = np.array( + [bincount(bins, ws * x, bins[-1] + 1) for x in meg_fwd_full.T], float + ) + compensator = sensors["meg"].get("compensator", None) + post_picks = sensors["meg"].get("post_picks", None) if compensator is not None: B = B @ compensator.T if post_picks is not None: B = B[:, post_picks] - Bs['meg'] = B + Bs["meg"] = B return Bs diff --git a/mne/forward/_field_interpolation.py b/mne/forward/_field_interpolation.py index fdc21ab8e9c..acad17a7fca 100644 --- a/mne/forward/_field_interpolation.py +++ b/mne/forward/_field_interpolation.py @@ -18,11 +18,13 @@ from ..surface import get_head_surf, get_meg_helmet_surf from ..transforms import transform_surface_to, _find_trans, _get_trans from ._make_forward import _create_meg_coils, _create_eeg_els, _read_coil_defs -from ._lead_dots import (_do_self_dots, _do_surface_dots, _get_legen_table, - _do_cross_dots) -from ..utils import ( - logger, verbose, _check_option, _reg_pinv, _pl, _check_fname +from ._lead_dots import ( + _do_self_dots, + _do_surface_dots, + _get_legen_table, + _do_cross_dots, ) +from ..utils import logger, verbose, _check_option, _reg_pinv, _pl, _check_fname from ..epochs import EpochsArray, BaseEpochs from ..evoked import Evoked, EvokedArray @@ -30,9 +32,10 @@ def _setup_dots(mode, info, coils, ch_type): """Set up dot products.""" from scipy.interpolate import interp1d + int_rad = 0.06 noise = make_ad_hoc_cov(info, dict(mag=20e-15, grad=5e-13, eeg=1e-6)) - n_coeff, interp = (50, 'nearest') if mode == 'fast' else (100, 'linear') + n_coeff, interp = (50, "nearest") if mode == "fast" else (100, "linear") lut, n_fact = _get_legen_table(ch_type, False, n_coeff, verbose=False) lut_fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, interp, axis=0) return int_rad, noise, lut_fun, n_fact @@ -40,27 +43,27 @@ def _setup_dots(mode, info, coils, ch_type): def _compute_mapping_matrix(fmd, info): """Do the hairy computations.""" - logger.info(' Preparing the mapping matrix...') + logger.info(" Preparing the mapping matrix...") # assemble a projector and apply it to the data - ch_names = fmd['ch_names'] - projs = info.get('projs', list()) + ch_names = fmd["ch_names"] + projs = info.get("projs", list()) proj_op = make_projector(projs, ch_names)[0] - proj_dots = np.dot(proj_op.T, np.dot(fmd['self_dots'], proj_op)) + proj_dots = np.dot(proj_op.T, np.dot(fmd["self_dots"], proj_op)) - noise_cov = fmd['noise'] + noise_cov = fmd["noise"] # Whiten - if not noise_cov['diag']: + if not noise_cov["diag"]: raise NotImplementedError # this shouldn't happen - whitener = np.diag(1.0 / 
np.sqrt(noise_cov['data'].ravel()))
+    whitener = np.diag(1.0 / np.sqrt(noise_cov["data"].ravel()))
     whitened_dots = np.dot(whitener.T, np.dot(proj_dots, whitener))
 
     # SVD is numerically better than the eigenvalue decomposition even if
     # mat is supposed to be symmetric and positive definite
-    if fmd.get('pinv_method', 'tsvd') == 'tsvd':
-        inv, fmd['nest'] = _pinv_trunc(whitened_dots, fmd['miss'])
+    if fmd.get("pinv_method", "tsvd") == "tsvd":
+        inv, fmd["nest"] = _pinv_trunc(whitened_dots, fmd["miss"])
     else:
-        assert fmd['pinv_method'] == 'tikhonov', fmd['pinv_method']
-        inv, fmd['nest'] = _pinv_tikhonov(whitened_dots, fmd['miss'])
+        assert fmd["pinv_method"] == "tikhonov", fmd["pinv_method"]
+        inv, fmd["nest"] = _pinv_tikhonov(whitened_dots, fmd["miss"])
 
     # Sandwich with the whitener
     inv_whitened = np.dot(whitener.T, np.dot(inv, whitener))
@@ -71,13 +74,14 @@ def _compute_mapping_matrix(fmd, info):
 
     # Finally sandwich in the selection matrix
     # This one picks up the correct lead field projection
-    mapping_mat = np.dot(fmd['surface_dots'], inv_whitened_proj)
+    mapping_mat = np.dot(fmd["surface_dots"], inv_whitened_proj)
 
     # Optionally apply the average electrode reference to the final field map
-    if fmd['kind'] == 'eeg' and _has_eeg_average_ref_proj(info):
+    if fmd["kind"] == "eeg" and _has_eeg_average_ref_proj(info):
         logger.info(
-            '    The map has an average electrode reference '
-            f'({mapping_mat.shape[0]} channels)')
+            "    The map has an average electrode reference "
+            f"({mapping_mat.shape[0]} channels)"
+        )
         mapping_mat -= np.mean(mapping_mat, axis=0)
     return mapping_mat
 
 
@@ -85,15 +89,18 @@ def _pinv_trunc(x, miss):
     """Compute pseudoinverse, truncating at most "miss" fraction of varexp."""
     from scipy import linalg
+
     u, s, v = linalg.svd(x, full_matrices=False)
 
     # Eigenvalue truncation
     varexp = np.cumsum(s)
     varexp /= varexp[-1]
     n = np.where(varexp >= (1.0 - miss))[0][0] + 1
-    logger.info('    Truncating at %d/%d components to omit less than %g '
-                '(%0.2g)' % (n, len(s), miss, 1. - varexp[n - 1]))
-    s = 1.
/ s[:n] + logger.info( + " Truncating at %d/%d components to omit less than %g " + "(%0.2g)" % (n, len(s), miss, 1.0 - varexp[n - 1]) + ) + s = 1.0 / s[:n] inv = ((u[:, :n] * s) @ v[:n]).T return inv, n @@ -101,8 +108,10 @@ def _pinv_trunc(x, miss): def _pinv_tikhonov(x, reg): # _reg_pinv requires square Hermitian, which we have here inv, _, n = _reg_pinv(x, reg=reg, rank=None) - logger.info(f' Truncating at {n}/{len(x)} components and regularizing ' - f'with α={reg:0.1e}') + logger.info( + f" Truncating at {n}/{len(x)} components and regularizing " + f"with α={reg:0.1e}" + ) return inv, n @@ -131,58 +140,76 @@ def _map_meg_or_eeg_channels(info_from, info_to, mode, origin, miss=None): """ # no need to apply trans because both from and to coils are in device # coordinates - info_kinds = set(ch['kind'] for ch in info_to['chs']) - info_kinds |= set(ch['kind'] for ch in info_from['chs']) + info_kinds = set(ch["kind"] for ch in info_to["chs"]) + info_kinds |= set(ch["kind"] for ch in info_from["chs"]) if FIFF.FIFFV_REF_MEG_CH in info_kinds: # refs same as MEG info_kinds |= set([FIFF.FIFFV_MEG_CH]) info_kinds -= set([FIFF.FIFFV_REF_MEG_CH]) info_kinds = sorted(info_kinds) # This should be guaranteed by the callers - assert (len(info_kinds) == 1 and info_kinds[0] in ( - FIFF.FIFFV_MEG_CH, FIFF.FIFFV_EEG_CH)) - kind = 'eeg' if info_kinds[0] == FIFF.FIFFV_EEG_CH else 'meg' + assert len(info_kinds) == 1 and info_kinds[0] in ( + FIFF.FIFFV_MEG_CH, + FIFF.FIFFV_EEG_CH, + ) + kind = "eeg" if info_kinds[0] == FIFF.FIFFV_EEG_CH else "meg" # # Step 1. Prepare the coil definitions # - if kind == 'meg': + if kind == "meg": templates = _read_coil_defs(verbose=False) - coils_from = _create_meg_coils(info_from['chs'], 'normal', - info_from['dev_head_t'], templates) - coils_to = _create_meg_coils(info_to['chs'], 'normal', - info_to['dev_head_t'], templates) - pinv_method = 'tsvd' + coils_from = _create_meg_coils( + info_from["chs"], "normal", info_from["dev_head_t"], templates + ) + coils_to = _create_meg_coils( + info_to["chs"], "normal", info_to["dev_head_t"], templates + ) + pinv_method = "tsvd" miss = 1e-4 else: - coils_from = _create_eeg_els(info_from['chs']) - coils_to = _create_eeg_els(info_to['chs']) - pinv_method = 'tikhonov' + coils_from = _create_eeg_els(info_from["chs"]) + coils_to = _create_eeg_els(info_to["chs"]) + pinv_method = "tikhonov" miss = 1e-1 - if _has_eeg_average_ref_proj(info_from) and \ - not _has_eeg_average_ref_proj(info_to): + if _has_eeg_average_ref_proj(info_from) and not _has_eeg_average_ref_proj( + info_to + ): raise RuntimeError( - 'info_to must have an average EEG reference projector if ' - 'info_from has one') + "info_to must have an average EEG reference projector if " + "info_from has one" + ) origin = _check_origin(origin, info_from) # # Step 2. 
Calculate the dot products
    #
-    int_rad, noise, lut_fun, n_fact = _setup_dots(
-        mode, info_from, coils_from, kind)
-    logger.info(f'    Computing dot products for {len(coils_from)} '
-                f'{kind.upper()} channel{_pl(coils_from)}...')
-    self_dots = _do_self_dots(int_rad, False, coils_from, origin, kind,
-                              lut_fun, n_fact, n_jobs=None)
-    logger.info(f'    Computing cross products for {len(coils_from)} → '
-                f'{len(coils_to)} {kind.upper()} channel{_pl(coils_to)}...')
-    cross_dots = _do_cross_dots(int_rad, False, coils_from, coils_to,
-                                origin, kind, lut_fun, n_fact).T
-
-    ch_names = [c['ch_name'] for c in info_from['chs']]
-    fmd = dict(kind=kind, ch_names=ch_names,
-               origin=origin, noise=noise, self_dots=self_dots,
-               surface_dots=cross_dots, int_rad=int_rad, miss=miss,
-               pinv_method=pinv_method)
+    int_rad, noise, lut_fun, n_fact = _setup_dots(mode, info_from, coils_from, kind)
+    logger.info(
+        f"    Computing dot products for {len(coils_from)} "
+        f"{kind.upper()} channel{_pl(coils_from)}..."
+    )
+    self_dots = _do_self_dots(
+        int_rad, False, coils_from, origin, kind, lut_fun, n_fact, n_jobs=None
+    )
+    logger.info(
+        f"    Computing cross products for {len(coils_from)} → "
+        f"{len(coils_to)} {kind.upper()} channel{_pl(coils_to)}..."
+    )
+    cross_dots = _do_cross_dots(
+        int_rad, False, coils_from, coils_to, origin, kind, lut_fun, n_fact
+    ).T
+
+    ch_names = [c["ch_name"] for c in info_from["chs"]]
+    fmd = dict(
+        kind=kind,
+        ch_names=ch_names,
+        origin=origin,
+        noise=noise,
+        self_dots=self_dots,
+        surface_dots=cross_dots,
+        int_rad=int_rad,
+        miss=miss,
+        pinv_method=pinv_method,
+    )
 
     #
     # Step 3. Compute the mapping matrix
@@ -191,7 +218,7 @@ def _map_meg_or_eeg_channels(info_from, info_to, mode, origin, miss=None):
     return mapping
 
 
-def _as_meg_type_inst(inst, ch_type='grad', mode='fast'):
+def _as_meg_type_inst(inst, ch_type="grad", mode="fast"):
     """Compute virtual evoked using interpolated fields in mag/grad channels.
 
     Parameters
@@ -210,30 +237,31 @@ def _as_meg_type_inst(inst, ch_type='grad', mode='fast'):
     inst : instance of mne.EvokedArray or mne.EpochsArray
         The transformed evoked object containing only virtual channels.
     """
-    _check_option('ch_type', ch_type, ['mag', 'grad'])
+    _check_option("ch_type", ch_type, ["mag", "grad"])
 
     # pick the original and destination channels
-    pick_from = pick_types(inst.info, meg=True, eeg=False,
-                           ref_meg=False)
-    pick_to = pick_types(inst.info, meg=ch_type, eeg=False,
-                         ref_meg=False)
+    pick_from = pick_types(inst.info, meg=True, eeg=False, ref_meg=False)
+    pick_to = pick_types(inst.info, meg=ch_type, eeg=False, ref_meg=False)
 
     if len(pick_to) == 0:
-        raise ValueError('No channels matching the destination channel type'
-                         ' found in info. Please pass an evoked containing'
-                         'both the original and destination channels. Only the'
-                         ' locations of the destination channels will be used'
-                         ' for interpolation.')
+        raise ValueError(
+            "No channels matching the destination channel type"
+            " found in info. Please pass an evoked containing"
+            " both the original and destination channels. Only the"
+            " locations of the destination channels will be used"
+            " for interpolation."
+ ) info_from = pick_info(inst.info, pick_from) info_to = pick_info(inst.info, pick_to) # XXX someday we should probably expose the origin mapping = _map_meg_or_eeg_channels( - info_from, info_to, origin=(0., 0., 0.04), mode=mode) + info_from, info_to, origin=(0.0, 0.0, 0.04), mode=mode + ) # compute data by multiplying by the 'gain matrix' from # original sensors to virtual sensors - if hasattr(inst, 'get_data'): + if hasattr(inst, "get_data"): data = inst.get_data() else: data = inst.data @@ -242,8 +270,7 @@ def _as_meg_type_inst(inst, ch_type='grad', mode='fast'): if ndim == 2: data = data[np.newaxis, :, :] - data_ = np.empty((data.shape[0], len(mapping), data.shape[2]), - dtype=data.dtype) + data_ = np.empty((data.shape[0], len(mapping), data.shape[2]), dtype=data.dtype) for d, d_ in zip(data, data_): d_[:] = np.dot(mapping, d[pick_from]) @@ -251,28 +278,41 @@ def _as_meg_type_inst(inst, ch_type='grad', mode='fast'): info = pick_info(inst.info, sel=pick_to, copy=True) # change channel names to emphasize they contain interpolated data - for ch in info['chs']: - ch['ch_name'] += '_v' + for ch in info["chs"]: + ch["ch_name"] += "_v" info._update_redundant() info._check_consistency() if isinstance(inst, Evoked): assert ndim == 2 data_ = data_[0] # undo new axis - inst_ = EvokedArray(data_, info, tmin=inst.times[0], - comment=inst.comment, nave=inst.nave) + inst_ = EvokedArray( + data_, info, tmin=inst.times[0], comment=inst.comment, nave=inst.nave + ) else: assert isinstance(inst, BaseEpochs) - inst_ = EpochsArray(data_, info, tmin=inst.tmin, - events=inst.events, - event_id=inst.event_id, - metadata=inst.metadata) + inst_ = EpochsArray( + data_, + info, + tmin=inst.tmin, + events=inst.events, + event_id=inst.event_id, + metadata=inst.metadata, + ) return inst_ @verbose -def _make_surface_mapping(info, surf, ch_type='meg', trans=None, mode='fast', - n_jobs=None, origin=(0., 0., 0.04), verbose=None): +def _make_surface_mapping( + info, + surf, + ch_type="meg", + trans=None, + mode="fast", + n_jobs=None, + origin=(0.0, 0.0, 0.04), + verbose=None, +): """Re-map M/EEG data to a surface. Parameters @@ -303,88 +343,108 @@ def _make_surface_mapping(info, surf, ch_type='meg', trans=None, mode='fast', A n_vertices x n_sensors array that remaps the MEG or EEG data, as `new_data = np.dot(mapping, data)`. """ - if not all(key in surf for key in ['rr', 'nn']): + if not all(key in surf for key in ["rr", "nn"]): raise KeyError('surf must have both "rr" and "nn"') - if 'coord_frame' not in surf: - raise KeyError('The surface coordinate frame must be specified ' - 'in surf["coord_frame"]') - _check_option('mode', mode, ['accurate', 'fast']) + if "coord_frame" not in surf: + raise KeyError( + "The surface coordinate frame must be specified " 'in surf["coord_frame"]' + ) + _check_option("mode", mode, ["accurate", "fast"]) # deal with coordinate frames here -- always go to "head" (easiest) orig_surf = surf - surf = transform_surface_to(deepcopy(surf), 'head', trans) + surf = transform_surface_to(deepcopy(surf), "head", trans) origin = _check_origin(origin, info) # # Step 1. 
Prepare the coil definitions # Do the dot products, assume surf in head coords # - _check_option('ch_type', ch_type, ['meg', 'eeg']) - if ch_type == 'meg': + _check_option("ch_type", ch_type, ["meg", "eeg"]) + if ch_type == "meg": picks = pick_types(info, meg=True, eeg=False, ref_meg=False) - logger.info('Prepare MEG mapping...') + logger.info("Prepare MEG mapping...") else: picks = pick_types(info, meg=False, eeg=True, ref_meg=False) - logger.info('Prepare EEG mapping...') + logger.info("Prepare EEG mapping...") if len(picks) == 0: - raise RuntimeError('cannot map, no channels found') + raise RuntimeError("cannot map, no channels found") # XXX this code does not do any checking for compensation channels, # but it seems like this must be intentional from the ref_meg=False # (presumably from the C code) - dev_head_t = info['dev_head_t'] + dev_head_t = info["dev_head_t"] info = pick_info(_simplify_info(info), picks) - info['dev_head_t'] = dev_head_t + info["dev_head_t"] = dev_head_t # create coil defs in head coordinates - if ch_type == 'meg': + if ch_type == "meg": # Put them in head coordinates - coils = _create_meg_coils(info['chs'], 'normal', info['dev_head_t']) - type_str = 'coils' + coils = _create_meg_coils(info["chs"], "normal", info["dev_head_t"]) + type_str = "coils" miss = 1e-4 # Smoothing criterion for MEG else: # EEG - coils = _create_eeg_els(info['chs']) - type_str = 'electrodes' + coils = _create_eeg_els(info["chs"]) + type_str = "electrodes" miss = 1e-3 # Smoothing criterion for EEG # # Step 2. Calculate the dot products # int_rad, noise, lut_fun, n_fact = _setup_dots(mode, info, coils, ch_type) - logger.info('Computing dot products for %i %s...' % (len(coils), type_str)) - self_dots = _do_self_dots(int_rad, False, coils, origin, ch_type, - lut_fun, n_fact, n_jobs) - sel = np.arange(len(surf['rr'])) # eventually we should do sub-selection - logger.info('Computing dot products for %i surface locations...' - % len(sel)) - surface_dots = _do_surface_dots(int_rad, False, coils, surf, sel, - origin, ch_type, lut_fun, n_fact, - n_jobs) + logger.info("Computing dot products for %i %s..." % (len(coils), type_str)) + self_dots = _do_self_dots( + int_rad, False, coils, origin, ch_type, lut_fun, n_fact, n_jobs + ) + sel = np.arange(len(surf["rr"])) # eventually we should do sub-selection + logger.info("Computing dot products for %i surface locations..." % len(sel)) + surface_dots = _do_surface_dots( + int_rad, False, coils, surf, sel, origin, ch_type, lut_fun, n_fact, n_jobs + ) # # Step 4. 
Return the result # - fmd = dict(kind=ch_type, surf=surf, ch_names=info['ch_names'], coils=coils, - origin=origin, noise=noise, self_dots=self_dots, - surface_dots=surface_dots, int_rad=int_rad, miss=miss) - logger.info('Field mapping data ready') - - fmd['data'] = _compute_mapping_matrix(fmd, info) + fmd = dict( + kind=ch_type, + surf=surf, + ch_names=info["ch_names"], + coils=coils, + origin=origin, + noise=noise, + self_dots=self_dots, + surface_dots=surface_dots, + int_rad=int_rad, + miss=miss, + ) + logger.info("Field mapping data ready") + + fmd["data"] = _compute_mapping_matrix(fmd, info) # bring the original back, whatever coord frame it was in - fmd['surf'] = orig_surf + fmd["surf"] = orig_surf # Remove some unnecessary fields - del fmd['self_dots'] - del fmd['surface_dots'] - del fmd['int_rad'] - del fmd['miss'] + del fmd["self_dots"] + del fmd["surface_dots"] + del fmd["int_rad"] + del fmd["miss"] return fmd @verbose -def make_field_map(evoked, trans='auto', subject=None, subjects_dir=None, - ch_type=None, mode='fast', meg_surf='helmet', - origin=(0., 0., 0.04), n_jobs=None, *, - head_source=('bem', 'head'), verbose=None): +def make_field_map( + evoked, + trans="auto", + subject=None, + subjects_dir=None, + ch_type=None, + mode="fast", + meg_surf="helmet", + origin=(0.0, 0.0, 0.04), + n_jobs=None, + *, + head_source=("bem", "head"), + verbose=None, +): """Compute surface maps used for field display in 3D. Parameters @@ -433,9 +493,9 @@ def make_field_map(evoked, trans='auto', subject=None, subjects_dir=None, info = evoked.info if ch_type is None: - types = [t for t in ['eeg', 'meg'] if t in evoked] + types = [t for t in ["eeg", "meg"] if t in evoked] else: - _check_option('ch_type', ch_type, ['eeg', 'meg']) + _check_option("ch_type", ch_type, ["eeg", "meg"]) types = [ch_type] if subjects_dir is not None: @@ -446,35 +506,40 @@ def make_field_map(evoked, trans='auto', subject=None, subjects_dir=None, name="subjects_dir", need_dir=True, ) - if isinstance(trans, str) and trans == 'auto': + if isinstance(trans, str) and trans == "auto": # let's try to do this in MRI coordinates so they're easy to plot trans = _find_trans(subject, subjects_dir) - trans, trans_type = _get_trans(trans, fro='head', to='mri') + trans, trans_type = _get_trans(trans, fro="head", to="mri") - if 'eeg' in types and trans_type == 'identity': - logger.info('No trans file available. EEG data ignored.') - types.remove('eeg') + if "eeg" in types and trans_type == "identity": + logger.info("No trans file available. 
EEG data ignored.") + types.remove("eeg") if len(types) == 0: - raise RuntimeError('No data available for mapping.') + raise RuntimeError("No data available for mapping.") - _check_option('meg_surf', meg_surf, ['helmet', 'head']) + _check_option("meg_surf", meg_surf, ["helmet", "head"]) surfs = [] for this_type in types: - if this_type == 'meg' and meg_surf == 'helmet': + if this_type == "meg" and meg_surf == "helmet": surf = get_meg_helmet_surf(info, trans) else: - surf = get_head_surf( - subject, source=head_source, subjects_dir=subjects_dir) + surf = get_head_surf(subject, source=head_source, subjects_dir=subjects_dir) surfs.append(surf) surf_maps = list() for this_type, this_surf in zip(types, surfs): - this_map = _make_surface_mapping(evoked.info, this_surf, this_type, - trans, n_jobs=n_jobs, origin=origin, - mode=mode) + this_map = _make_surface_mapping( + evoked.info, + this_surf, + this_type, + trans, + n_jobs=n_jobs, + origin=origin, + mode=mode, + ) surf_maps.append(this_map) return surf_maps diff --git a/mne/forward/_lead_dots.py b/mne/forward/_lead_dots.py index a97bac9d660..3eda719ac59 100644 --- a/mne/forward/_lead_dots.py +++ b/mne/forward/_lead_dots.py @@ -20,6 +20,7 @@ ############################################################################## # FAST LEGENDRE (DERIVATIVE) POLYNOMIALS USING LOOKUP TABLE + def _next_legen_der(n, x, p0, p01, p0d, p0dd): """Compute the next Legendre polynomial and its derivatives.""" # only good for n > 1 ! @@ -46,50 +47,56 @@ def _get_legen_der(xx, n_coeff=100): p0dds[:2] = [0.0, 0.0] for n in range(2, n_coeff): p0s[n], p0ds[n], p0dds[n] = _next_legen_der( - n, x, p0s[n - 1], p0s[n - 2], p0ds[n - 1], p0dds[n - 1]) + n, x, p0s[n - 1], p0s[n - 2], p0ds[n - 1], p0dds[n - 1] + ) return coeffs @verbose -def _get_legen_table(ch_type, volume_integral=False, n_coeff=100, - n_interp=20000, force_calc=False, verbose=None): +def _get_legen_table( + ch_type, + volume_integral=False, + n_coeff=100, + n_interp=20000, + force_calc=False, + verbose=None, +): """Return a (generated) LUT of Legendre (derivative) polynomial coeffs.""" if n_interp % 2 != 0: - raise RuntimeError('n_interp must be even') - fname = op.join(_get_extra_data_path(), 'tables') + raise RuntimeError("n_interp must be even") + fname = op.join(_get_extra_data_path(), "tables") if not op.isdir(fname): # Updated due to API change (GH 1167) os.makedirs(fname) - if ch_type == 'meg': - fname = op.join(fname, 'legder_%s_%s.bin' % (n_coeff, n_interp)) + if ch_type == "meg": + fname = op.join(fname, "legder_%s_%s.bin" % (n_coeff, n_interp)) leg_fun = _get_legen_der - extra_str = ' derivative' + extra_str = " derivative" lut_shape = (n_interp + 1, n_coeff, 3) else: # 'eeg' - fname = op.join(fname, 'legval_%s_%s.bin' % (n_coeff, n_interp)) + fname = op.join(fname, "legval_%s_%s.bin" % (n_coeff, n_interp)) leg_fun = _get_legen - extra_str = '' + extra_str = "" lut_shape = (n_interp + 1, n_coeff) if not op.isfile(fname) or force_calc: - logger.info('Generating Legendre%s table...' % extra_str) + logger.info("Generating Legendre%s table..." % extra_str) x_interp = np.linspace(-1, 1, n_interp + 1) lut = leg_fun(x_interp, n_coeff).astype(np.float32) if not force_calc: - with open(fname, 'wb') as fid: + with open(fname, "wb") as fid: fid.write(lut.tobytes()) else: - logger.info('Reading Legendre%s table...' % extra_str) - with open(fname, 'rb', buffering=0) as fid: + logger.info("Reading Legendre%s table..." 
% extra_str) + with open(fname, "rb", buffering=0) as fid: lut = np.fromfile(fid, np.float32) lut.shape = lut_shape # we need this for the integration step n_fact = np.arange(1, n_coeff, dtype=float) - if ch_type == 'meg': + if ch_type == "meg": n_facts = list() # multn, then mult, then multn * (n + 1) if volume_integral: - n_facts.append(n_fact / ((2.0 * n_fact + 1.0) * - (2.0 * n_fact + 3.0))) + n_facts.append(n_fact / ((2.0 * n_fact + 1.0) * (2.0 * n_fact + 3.0))) else: n_facts.append(n_fact / (2.0 * n_fact + 1.0)) n_facts.append(n_facts[0] / (n_fact + 1.0)) @@ -167,8 +174,13 @@ def _comp_sums_meg(beta, ctheta, lut_fun, n_fact, volume_integral): bbeta = np.tile(beta[start:stop][np.newaxis], (n_fact.shape[0], 1)) bbeta[0] *= beta[start:stop] np.cumprod(bbeta, axis=0, out=bbeta) # run inplace - np.einsum('ji,jk,ijk->ki', bbeta, n_fact, lut_fun(ctheta[start:stop]), - out=sums[:, start:stop]) + np.einsum( + "ji,jk,ijk->ki", + bbeta, + n_fact, + lut_fun(ctheta[start:stop]), + out=sums[:, start:stop], + ) return sums @@ -179,8 +191,21 @@ def _comp_sums_meg(beta, ctheta, lut_fun, n_fact, volume_integral): _eeg_const = 1.0 / (4.0 * np.pi) -def _fast_sphere_dot_r0(r, rr1_orig, rr2s, lr1, lr2s, cosmags1, cosmags2s, - w1, w2s, volume_integral, lut, n_fact, ch_type): +def _fast_sphere_dot_r0( + r, + rr1_orig, + rr2s, + lr1, + lr2s, + cosmags1, + cosmags2s, + w1, + w2s, + volume_integral, + lut, + n_fact, + ch_type, +): """Lead field dot product computation for M/EEG in the sphere model. Parameters @@ -230,7 +255,7 @@ def _fast_sphere_dot_r0(r, rr1_orig, rr2s, lr1, lr2s, cosmags1, cosmags2s, cosmags2 = np.concatenate(cosmags2s) # outer product, sum over coords - ct = np.einsum('ik,jk->ij', rr1_orig, rr2) + ct = np.einsum("ik,jk->ij", rr1_orig, rr2) np.clip(ct, -1, 1, ct) # expand axes @@ -239,9 +264,10 @@ def _fast_sphere_dot_r0(r, rr1_orig, rr2s, lr1, lr2s, cosmags1, cosmags2s, lr1lr2 = lr1[:, np.newaxis] * lr2[np.newaxis, :] beta = (r * r) / lr1lr2 - if ch_type == 'meg': - sums = _comp_sums_meg(beta.flatten(), ct.flatten(), lut, n_fact, - volume_integral) + if ch_type == "meg": + sums = _comp_sums_meg( + beta.flatten(), ct.flatten(), lut, n_fact, volume_integral + ) sums.shape = (4,) + beta.shape # Accumulate the result, a little bit streamlined version @@ -252,21 +278,23 @@ def _fast_sphere_dot_r0(r, rr1_orig, rr2s, lr1, lr2s, cosmags1, cosmags2s, # n2c1 = np.sum(cosmags2 * rr1, axis=2) # n2c2 = np.sum(cosmags2 * rr2, axis=2) # n1n2 = np.sum(cosmags1 * cosmags2, axis=2) - n1c1 = np.einsum('ik,ijk->ij', cosmags1, rr1) - n1c2 = np.einsum('ik,ijk->ij', cosmags1, rr2) - n2c1 = np.einsum('jk,ijk->ij', cosmags2, rr1) - n2c2 = np.einsum('jk,ijk->ij', cosmags2, rr2) - n1n2 = np.einsum('ik,jk->ij', cosmags1, cosmags2) + n1c1 = np.einsum("ik,ijk->ij", cosmags1, rr1) + n1c2 = np.einsum("ik,ijk->ij", cosmags1, rr2) + n2c1 = np.einsum("jk,ijk->ij", cosmags2, rr1) + n2c2 = np.einsum("jk,ijk->ij", cosmags2, rr2) + n1n2 = np.einsum("ik,jk->ij", cosmags1, cosmags2) part1 = ct * n1c1 * n2c2 part2 = n1c1 * n2c1 + n1c2 * n2c2 - result = (n1c1 * n2c2 * sums[0] + - (2.0 * part1 - part2) * sums[1] + - (n1n2 + part1 - part2) * sums[2] + - (n1c2 - ct * n1c1) * (n2c1 - ct * n2c2) * sums[3]) + result = ( + n1c1 * n2c2 * sums[0] + + (2.0 * part1 - part2) * sums[1] + + (n1n2 + part1 - part2) * sums[2] + + (n1c2 - ct * n1c1) * (n2c1 - ct * n2c2) * sums[3] + ) # Give it a finishing touch! 
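        # lr1lr2 is the outer product of the vector norms that the callers
        # divided out when normalizing rr1/rr2, so each coil pair is rescaled
        # here by _meg_const / (|r1| * |r2|)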
- result *= (_meg_const / lr1lr2) + result *= _meg_const / lr1lr2 if volume_integral: result *= r else: # 'eeg' @@ -281,7 +309,7 @@ def _fast_sphere_dot_r0(r, rr1_orig, rr2s, lr1, lr2s, cosmags1, cosmags2s, if w1 is not None: result *= w1[:, np.newaxis] for ii, w2 in enumerate(w2s): - out[ii] = np.sum(result[:, offset:offset + len(w2)], axis=sum_axis) + out[ii] = np.sum(result[:, offset : offset + len(w2)], axis=sum_axis) offset += len(w2) return out @@ -314,40 +342,52 @@ def _do_self_dots(intrad, volume, coils, r0, ch_type, lut, n_fact, n_jobs): products : array, shape (n_coils, n_coils) The integration products. """ - if ch_type == 'eeg': + if ch_type == "eeg": intrad = intrad * 0.7 # convert to normalized distances from expansion center - rmags = [coil['rmag'] - r0[np.newaxis, :] for coil in coils] + rmags = [coil["rmag"] - r0[np.newaxis, :] for coil in coils] rlens = [np.sqrt(np.sum(r * r, axis=1)) for r in rmags] rmags = [r / rl[:, np.newaxis] for r, rl in zip(rmags, rlens)] - cosmags = [coil['cosmag'] for coil in coils] - ws = [coil['w'] for coil in coils] + cosmags = [coil["cosmag"] for coil in coils] + ws = [coil["w"] for coil in coils] parallel, p_fun, n_jobs = parallel_func(_do_self_dots_subset, n_jobs) - prods = parallel(p_fun(intrad, rmags, rlens, cosmags, - ws, volume, lut, n_fact, ch_type, idx) - for idx in np.array_split(np.arange(len(rmags)), n_jobs)) + prods = parallel( + p_fun(intrad, rmags, rlens, cosmags, ws, volume, lut, n_fact, ch_type, idx) + for idx in np.array_split(np.arange(len(rmags)), n_jobs) + ) products = np.sum(prods, axis=0) return products -def _do_self_dots_subset(intrad, rmags, rlens, cosmags, ws, volume, lut, - n_fact, ch_type, idx): +def _do_self_dots_subset( + intrad, rmags, rlens, cosmags, ws, volume, lut, n_fact, ch_type, idx +): """Parallelize.""" # all possible combinations of two magnetometers products = np.zeros((len(rmags), len(rmags))) for ci1 in idx: ci2 = ci1 + 1 res = _fast_sphere_dot_r0( - intrad, rmags[ci1], rmags[:ci2], rlens[ci1], rlens[:ci2], - cosmags[ci1], cosmags[:ci2], ws[ci1], ws[:ci2], volume, lut, - n_fact, ch_type) + intrad, + rmags[ci1], + rmags[:ci2], + rlens[ci1], + rlens[:ci2], + cosmags[ci1], + cosmags[:ci2], + ws[ci1], + ws[:ci2], + volume, + lut, + n_fact, + ch_type, + ) products[ci1, :ci2] = res products[:ci2, ci1] = res return products -def _do_cross_dots(intrad, volume, coils1, coils2, r0, ch_type, - lut, n_fact): +def _do_cross_dots(intrad, volume, coils1, coils2, r0, ch_type, lut, n_fact): """Compute lead field dot product integrations between two coil sets. The code is a direct translation of MNE-C code found in @@ -378,10 +418,10 @@ def _do_cross_dots(intrad, volume, coils1, coils2, r0, ch_type, products : array, shape (n_coils, n_coils) The integration products. 
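    Notes
    -----
    The coil indices are split across jobs and the per-chunk (symmetric)
    products are summed; schematically (a sketch omitting the
    ``parallel_func`` plumbing)::

        products = np.sum(
            [
                _do_self_dots_subset(
                    intrad, rmags, rlens, cosmags, ws, volume, lut, n_fact,
                    ch_type, idx
                )
                for idx in np.array_split(np.arange(len(rmags)), n_jobs)
            ],
            axis=0,
        )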
""" - if ch_type == 'eeg': + if ch_type == "eeg": intrad = intrad * 0.7 - rmags1 = [coil['rmag'] - r0[np.newaxis, :] for coil in coils1] - rmags2 = [coil['rmag'] - r0[np.newaxis, :] for coil in coils2] + rmags1 = [coil["rmag"] - r0[np.newaxis, :] for coil in coils1] + rmags2 = [coil["rmag"] - r0[np.newaxis, :] for coil in coils2] rlens1 = [np.sqrt(np.sum(r * r, axis=1)) for r in rmags1] rlens2 = [np.sqrt(np.sum(r * r, axis=1)) for r in rmags2] @@ -389,24 +429,37 @@ def _do_cross_dots(intrad, volume, coils1, coils2, r0, ch_type, rmags1 = [r / rl[:, np.newaxis] for r, rl in zip(rmags1, rlens1)] rmags2 = [r / rl[:, np.newaxis] for r, rl in zip(rmags2, rlens2)] - ws1 = [coil['w'] for coil in coils1] - ws2 = [coil['w'] for coil in coils2] + ws1 = [coil["w"] for coil in coils1] + ws2 = [coil["w"] for coil in coils2] - cosmags1 = [coil['cosmag'] for coil in coils1] - cosmags2 = [coil['cosmag'] for coil in coils2] + cosmags1 = [coil["cosmag"] for coil in coils1] + cosmags2 = [coil["cosmag"] for coil in coils2] products = np.zeros((len(rmags1), len(rmags2))) for ci1 in range(len(coils1)): res = _fast_sphere_dot_r0( - intrad, rmags1[ci1], rmags2, rlens1[ci1], rlens2, cosmags1[ci1], - cosmags2, ws1[ci1], ws2, volume, lut, n_fact, ch_type) + intrad, + rmags1[ci1], + rmags2, + rlens1[ci1], + rlens2, + cosmags1[ci1], + cosmags2, + ws1[ci1], + ws2, + volume, + lut, + n_fact, + ch_type, + ) products[ci1, :] = res return products @fill_doc -def _do_surface_dots(intrad, volume, coils, surf, sel, r0, ch_type, - lut, n_fact, n_jobs): +def _do_surface_dots( + intrad, volume, coils, surf, sel, r0, ch_type, lut, n_fact, n_jobs +): """Compute the map construction products. Parameters @@ -438,15 +491,15 @@ def _do_surface_dots(intrad, volume, coils, surf, sel, r0, ch_type, The integration products. 
""" # convert to normalized distances from expansion center - rmags = [coil['rmag'] - r0[np.newaxis, :] for coil in coils] + rmags = [coil["rmag"] - r0[np.newaxis, :] for coil in coils] rlens = [np.sqrt(np.sum(r * r, axis=1)) for r in rmags] rmags = [r / rl[:, np.newaxis] for r, rl in zip(rmags, rlens)] - cosmags = [coil['cosmag'] for coil in coils] - ws = [coil['w'] for coil in coils] + cosmags = [coil["cosmag"] for coil in coils] + ws = [coil["w"] for coil in coils] rref = None refl = None # virt_ref = False - if ch_type == 'eeg': + if ch_type == "eeg": intrad = intrad * 0.7 # The virtual ref code is untested and unused, so it is # commented out for now @@ -455,24 +508,54 @@ def _do_surface_dots(intrad, volume, coils, surf, sel, r0, ch_type, # refl = np.sqrt(np.sum(rref * rref, axis=1)) # rref /= refl[:, np.newaxis] - rsurf = surf['rr'][sel] - r0[np.newaxis, :] + rsurf = surf["rr"][sel] - r0[np.newaxis, :] lsurf = np.sqrt(np.sum(rsurf * rsurf, axis=1)) rsurf /= lsurf[:, np.newaxis] - this_nn = surf['nn'][sel] + this_nn = surf["nn"][sel] # loop over the coils parallel, p_fun, n_jobs = parallel_func(_do_surface_dots_subset, n_jobs) - prods = parallel(p_fun(intrad, rsurf, rmags, rref, refl, lsurf, rlens, - this_nn, cosmags, ws, volume, lut, n_fact, ch_type, - idx) - for idx in np.array_split(np.arange(len(rmags)), n_jobs)) + prods = parallel( + p_fun( + intrad, + rsurf, + rmags, + rref, + refl, + lsurf, + rlens, + this_nn, + cosmags, + ws, + volume, + lut, + n_fact, + ch_type, + idx, + ) + for idx in np.array_split(np.arange(len(rmags)), n_jobs) + ) products = np.sum(prods, axis=0) return products -def _do_surface_dots_subset(intrad, rsurf, rmags, rref, refl, lsurf, rlens, - this_nn, cosmags, ws, volume, lut, n_fact, ch_type, - idx): +def _do_surface_dots_subset( + intrad, + rsurf, + rmags, + rref, + refl, + lsurf, + rlens, + this_nn, + cosmags, + ws, + volume, + lut, + n_fact, + ch_type, + idx, +): """Parallelize. Parameters @@ -507,8 +590,20 @@ def _do_surface_dots_subset(intrad, rsurf, rmags, rref, refl, lsurf, rlens, The integration products. 
""" products = _fast_sphere_dot_r0( - intrad, rsurf, rmags, lsurf, rlens, this_nn, cosmags, None, ws, - volume, lut, n_fact, ch_type).T + intrad, + rsurf, + rmags, + lsurf, + rlens, + this_nn, + cosmags, + None, + ws, + volume, + lut, + n_fact, + ch_type, + ).T if rref is not None: raise NotImplementedError # we don't ever use this, isn't tested # vres = _fast_sphere_dot_r0( diff --git a/mne/forward/_make_forward.py b/mne/forward/_make_forward.py index 34be8d023cd..44783393ab8 100644 --- a/mne/forward/_make_forward.py +++ b/mne/forward/_make_forward.py @@ -21,23 +21,35 @@ from ..io.compensator import get_current_comp, make_compensator from ..io.pick import _has_kit_refs, pick_types, pick_info from ..io.constants import FIFF, FWD -from ..transforms import (_ensure_trans, transform_surface_to, apply_trans, - _get_trans, _print_coord_trans, _coord_frame_name, - Transform, invert_transform) +from ..transforms import ( + _ensure_trans, + transform_surface_to, + apply_trans, + _get_trans, + _print_coord_trans, + _coord_frame_name, + Transform, + invert_transform, +) from ..utils import logger, verbose, warn, _pl, _validate_type, _check_fname -from ..source_space import (_ensure_src, _filter_source_spaces, - _make_discrete_source_space, _complete_vol_src) +from ..source_space import ( + _ensure_src, + _filter_source_spaces, + _make_discrete_source_space, + _complete_vol_src, +) from ..source_estimate import VolSourceEstimate from ..surface import _normalize_vectors, _CheckInside from ..bem import read_bem_solution, _bem_find_surface, ConductorModel -from .forward import (Forward, _merge_fwds, convert_forward_solution, - _FWD_ORDER) +from .forward import Forward, _merge_fwds, convert_forward_solution, _FWD_ORDER -_accuracy_dict = dict(point=FWD.COIL_ACCURACY_POINT, - normal=FWD.COIL_ACCURACY_NORMAL, - accurate=FWD.COIL_ACCURACY_ACCURATE) +_accuracy_dict = dict( + point=FWD.COIL_ACCURACY_POINT, + normal=FWD.COIL_ACCURACY_NORMAL, + accurate=FWD.COIL_ACCURACY_ACCURATE, +) _extra_coil_def_fname = None @@ -63,11 +75,11 @@ def _read_coil_defs(verbose=None): The global variable "_extra_coil_def_fname" can be used to prepend additional definitions. These are never added to the registry. 
""" - coil_dir = op.join(op.split(__file__)[0], '..', 'data') + coil_dir = op.join(op.split(__file__)[0], "..", "data") coils = list() if _extra_coil_def_fname is not None: coils += _read_coil_def_file(_extra_coil_def_fname, use_registry=False) - coils += _read_coil_def_file(op.join(coil_dir, 'coil_def.dat')) + coils += _read_coil_def_file(op.join(coil_dir, "coil_def.dat")) return coils @@ -81,23 +93,28 @@ def _read_coil_def_file(fname, use_registry=True): if not use_registry or fname not in _coil_registry: big_val = 0.5 coils = list() - with open(fname, 'r') as fid: + with open(fname, "r") as fid: lines = fid.readlines() lines = lines[::-1] while len(lines) > 0: line = lines.pop().strip() - if line[0] == '#' and len(line) > 0: + if line[0] == "#" and len(line) > 0: continue desc_start = line.find('"') desc_end = len(line) - 1 assert line.strip()[desc_end] == '"' desc = line[desc_start:desc_end] - vals = np.fromstring(line[:desc_start].strip(), - dtype=float, sep=' ') + vals = np.fromstring(line[:desc_start].strip(), dtype=float, sep=" ") assert len(vals) == 6 npts = int(vals[3]) - coil = dict(coil_type=vals[1], coil_class=vals[0], desc=desc, - accuracy=vals[2], size=vals[4], base=vals[5]) + coil = dict( + coil_type=vals[1], + coil_class=vals[0], + desc=desc, + accuracy=vals[2], + size=vals[4], + base=vals[5], + ) # get parameters of each component rmag = list() cosmag = list() @@ -105,13 +122,13 @@ def _read_coil_def_file(fname, use_registry=True): for p in range(npts): # get next non-comment line line = lines.pop() - while line[0] == '#': + while line[0] == "#": line = lines.pop() - vals = np.fromstring(line, sep=' ') + vals = np.fromstring(line, sep=" ") if len(vals) != 7: raise RuntimeError( - f'Could not interpret line {p + 1} as 7 points:\n' - f'{line}') + f"Could not interpret line {p + 1} as 7 points:\n" f"{line}" + ) # Read and verify data for each integration point w.append(vals[0]) rmag.append(vals[[1, 2, 3]]) @@ -119,11 +136,11 @@ def _read_coil_def_file(fname, use_registry=True): w = np.array(w) rmag = np.array(rmag) cosmag = np.array(cosmag) - size = np.sqrt(np.sum(cosmag ** 2, axis=1)) - if np.any(np.sqrt(np.sum(rmag ** 2, axis=1)) > big_val): - raise RuntimeError('Unreasonable integration point') + size = np.sqrt(np.sum(cosmag**2, axis=1)) + if np.any(np.sqrt(np.sum(rmag**2, axis=1)) > big_val): + raise RuntimeError("Unreasonable integration point") if np.any(size <= 0): - raise RuntimeError('Unreasonable normal') + raise RuntimeError("Unreasonable normal") cosmag /= size[:, np.newaxis] coil.update(dict(w=w, cosmag=cosmag, rmag=rmag)) coils.append(coil) @@ -131,70 +148,92 @@ def _read_coil_def_file(fname, use_registry=True): _coil_registry[fname] = coils if use_registry: coils = deepcopy(_coil_registry[fname]) - logger.info('%d coil definition%s read', len(coils), _pl(coils)) + logger.info("%d coil definition%s read", len(coils), _pl(coils)) return coils def _create_meg_coil(coilset, ch, acc, do_es): """Create a coil definition using templates, transform if necessary.""" # Also change the coordinate frame if so desired - if ch['kind'] not in [FIFF.FIFFV_MEG_CH, FIFF.FIFFV_REF_MEG_CH]: - raise RuntimeError('%s is not a MEG channel' % ch['ch_name']) + if ch["kind"] not in [FIFF.FIFFV_MEG_CH, FIFF.FIFFV_REF_MEG_CH]: + raise RuntimeError("%s is not a MEG channel" % ch["ch_name"]) # Simple linear search from the coil definitions for coil in coilset: - if coil['coil_type'] == (ch['coil_type'] & 0xFFFF) and \ - coil['accuracy'] == acc: + if coil["coil_type"] == (ch["coil_type"] & 
0xFFFF) and coil["accuracy"] == acc: break else: - raise RuntimeError('Desired coil definition not found ' - '(type = %d acc = %d)' % (ch['coil_type'], acc)) + raise RuntimeError( + "Desired coil definition not found " + "(type = %d acc = %d)" % (ch["coil_type"], acc) + ) # Apply a coordinate transformation if so desired - coil_trans = _loc_to_coil_trans(ch['loc']) + coil_trans = _loc_to_coil_trans(ch["loc"]) # Create the result - res = dict(chname=ch['ch_name'], coil_class=coil['coil_class'], - accuracy=coil['accuracy'], base=coil['base'], size=coil['size'], - type=ch['coil_type'], w=coil['w'], desc=coil['desc'], - coord_frame=FIFF.FIFFV_COORD_DEVICE, rmag_orig=coil['rmag'], - cosmag_orig=coil['cosmag'], coil_trans_orig=coil_trans, - r0=coil_trans[:3, 3], - rmag=apply_trans(coil_trans, coil['rmag']), - cosmag=apply_trans(coil_trans, coil['cosmag'], False)) + res = dict( + chname=ch["ch_name"], + coil_class=coil["coil_class"], + accuracy=coil["accuracy"], + base=coil["base"], + size=coil["size"], + type=ch["coil_type"], + w=coil["w"], + desc=coil["desc"], + coord_frame=FIFF.FIFFV_COORD_DEVICE, + rmag_orig=coil["rmag"], + cosmag_orig=coil["cosmag"], + coil_trans_orig=coil_trans, + r0=coil_trans[:3, 3], + rmag=apply_trans(coil_trans, coil["rmag"]), + cosmag=apply_trans(coil_trans, coil["cosmag"], False), + ) if do_es: - r0_exey = (np.dot(coil['rmag'][:, :2], coil_trans[:3, :2].T) + - coil_trans[:3, 3]) - res.update(ex=coil_trans[:3, 0], ey=coil_trans[:3, 1], - ez=coil_trans[:3, 2], r0_exey=r0_exey) + r0_exey = np.dot(coil["rmag"][:, :2], coil_trans[:3, :2].T) + coil_trans[:3, 3] + res.update( + ex=coil_trans[:3, 0], + ey=coil_trans[:3, 1], + ez=coil_trans[:3, 2], + r0_exey=r0_exey, + ) return res def _create_eeg_el(ch, t=None): """Create an electrode definition, transform coords if necessary.""" - if ch['kind'] != FIFF.FIFFV_EEG_CH: - raise RuntimeError('%s is not an EEG channel. Cannot create an ' - 'electrode definition.' % ch['ch_name']) + if ch["kind"] != FIFF.FIFFV_EEG_CH: + raise RuntimeError( + "%s is not an EEG channel. Cannot create an " + "electrode definition." 
% ch["ch_name"] + ) if t is None: - t = Transform('head', 'head') # identity, no change - if t.from_str != 'head': - raise RuntimeError('Inappropriate coordinate transformation') + t = Transform("head", "head") # identity, no change + if t.from_str != "head": + raise RuntimeError("Inappropriate coordinate transformation") - r0ex = _loc_to_eeg_loc(ch['loc']) + r0ex = _loc_to_eeg_loc(ch["loc"]) if r0ex.shape[1] == 1: # no reference - w = np.array([1.]) + w = np.array([1.0]) else: # has reference - w = np.array([1., -1.]) + w = np.array([1.0, -1.0]) # Optional coordinate transformation - r0ex = apply_trans(t['trans'], r0ex.T) + r0ex = apply_trans(t["trans"], r0ex.T) # The electrode location cosmag = r0ex.copy() _normalize_vectors(cosmag) - res = dict(chname=ch['ch_name'], coil_class=FWD.COILC_EEG, w=w, - accuracy=_accuracy_dict['normal'], type=ch['coil_type'], - coord_frame=t['to'], rmag=r0ex, cosmag=cosmag) + res = dict( + chname=ch["ch_name"], + coil_class=FWD.COILC_EEG, + w=w, + accuracy=_accuracy_dict["normal"], + type=ch["coil_type"], + coord_frame=t["to"], + rmag=r0ex, + cosmag=cosmag, + ) return res @@ -212,16 +251,24 @@ def _transform_orig_meg_coils(coils, t, do_es=True): if t is None: return for coil in coils: - coil_trans = np.dot(t['trans'], coil['coil_trans_orig']) + coil_trans = np.dot(t["trans"], coil["coil_trans_orig"]) coil.update( - coord_frame=t['to'], r0=coil_trans[:3, 3], - rmag=apply_trans(coil_trans, coil['rmag_orig']), - cosmag=apply_trans(coil_trans, coil['cosmag_orig'], False)) + coord_frame=t["to"], + r0=coil_trans[:3, 3], + rmag=apply_trans(coil_trans, coil["rmag_orig"]), + cosmag=apply_trans(coil_trans, coil["cosmag_orig"], False), + ) if do_es: - r0_exey = (np.dot(coil['rmag_orig'][:, :2], - coil_trans[:3, :2].T) + coil_trans[:3, 3]) - coil.update(ex=coil_trans[:3, 0], ey=coil_trans[:3, 1], - ez=coil_trans[:3, 2], r0_exey=r0_exey) + r0_exey = ( + np.dot(coil["rmag_orig"][:, :2], coil_trans[:3, :2].T) + + coil_trans[:3, 3] + ) + coil.update( + ex=coil_trans[:3, 0], + ey=coil_trans[:3, 1], + ez=coil_trans[:3, 2], + r0_exey=r0_exey, + ) def _create_eeg_els(chs): @@ -230,47 +277,58 @@ def _create_eeg_els(chs): @verbose -def _setup_bem(bem, bem_extra, neeg, mri_head_t, allow_none=False, - verbose=None): +def _setup_bem(bem, bem_extra, neeg, mri_head_t, allow_none=False, verbose=None): """Set up a BEM for forward computation, making a copy and modifying.""" if allow_none and bem is None: return None - logger.info('') - _validate_type(bem, ('path-like', ConductorModel), bem) + logger.info("") + _validate_type(bem, ("path-like", ConductorModel), bem) if not isinstance(bem, ConductorModel): - logger.info('Setting up the BEM model using %s...\n' % bem_extra) + logger.info("Setting up the BEM model using %s...\n" % bem_extra) bem = read_bem_solution(bem) else: bem = bem.copy() - if bem['is_sphere']: - logger.info('Using the sphere model.\n') - if len(bem['layers']) == 0 and neeg > 0: - raise RuntimeError('Spherical model has zero shells, cannot use ' - 'with EEG data') - if bem['coord_frame'] != FIFF.FIFFV_COORD_HEAD: - raise RuntimeError('Spherical model is not in head coordinates') + if bem["is_sphere"]: + logger.info("Using the sphere model.\n") + if len(bem["layers"]) == 0 and neeg > 0: + raise RuntimeError( + "Spherical model has zero shells, cannot use " "with EEG data" + ) + if bem["coord_frame"] != FIFF.FIFFV_COORD_HEAD: + raise RuntimeError("Spherical model is not in head coordinates") else: - if bem['surfs'][0]['coord_frame'] != FIFF.FIFFV_COORD_MRI: + if 
bem["surfs"][0]["coord_frame"] != FIFF.FIFFV_COORD_MRI: + raise RuntimeError( + "BEM is in %s coordinates, should be in MRI" + % (_coord_frame_name(bem["surfs"][0]["coord_frame"]),) + ) + if neeg > 0 and len(bem["surfs"]) == 1: raise RuntimeError( - 'BEM is in %s coordinates, should be in MRI' - % (_coord_frame_name(bem['surfs'][0]['coord_frame']),)) - if neeg > 0 and len(bem['surfs']) == 1: - raise RuntimeError('Cannot use a homogeneous (1-layer BEM) model ' - 'for EEG forward calculations, consider ' - 'using a 3-layer BEM instead') - logger.info('Employing the head->MRI coordinate transform with the ' - 'BEM model.') + "Cannot use a homogeneous (1-layer BEM) model " + "for EEG forward calculations, consider " + "using a 3-layer BEM instead" + ) + logger.info( + "Employing the head->MRI coordinate transform with the " "BEM model." + ) # fwd_bem_set_head_mri_t: Set the coordinate transformation - bem['head_mri_t'] = _ensure_trans(mri_head_t, 'head', 'mri') - logger.info('BEM model %s is now set up' % op.split(bem_extra)[1]) - logger.info('') + bem["head_mri_t"] = _ensure_trans(mri_head_t, "head", "mri") + logger.info("BEM model %s is now set up" % op.split(bem_extra)[1]) + logger.info("") return bem @verbose -def _prep_meg_channels(info, accuracy='accurate', exclude=(), *, - ignore_ref=False, head_frame=True, do_es=False, - verbose=None): +def _prep_meg_channels( + info, + accuracy="accurate", + exclude=(), + *, + ignore_ref=False, + head_frame=True, + do_es=False, + verbose=None, +): """Prepare MEG coil definitions for forward calculation.""" # Find MEG channels ref_meg = True if not ignore_ref else False @@ -278,7 +336,7 @@ def _prep_meg_channels(info, accuracy='accurate', exclude=(), *, # Make sure MEG coils exist if len(picks) <= 0: - raise RuntimeError('Could not find any MEG channels') + raise RuntimeError("Could not find any MEG channels") info_meg = pick_info(info, picks) del picks @@ -287,95 +345,110 @@ def _prep_meg_channels(info, accuracy='accurate', exclude=(), *, # Get MEG compensation channels compensator = post_picks = None - ch_names = info_meg['ch_names'] + ch_names = info_meg["ch_names"] if not ignore_ref: ref_picks = pick_types(info, meg=False, ref_meg=True, exclude=exclude) ncomp = len(ref_picks) - if (ncomp > 0): - logger.info(f'Read {ncomp} MEG compensation channels from info') + if ncomp > 0: + logger.info(f"Read {ncomp} MEG compensation channels from info") # We need to check to make sure these are NOT KIT refs if _has_kit_refs(info, ref_picks): raise NotImplementedError( - 'Cannot create forward solution with KIT reference ' + "Cannot create forward solution with KIT reference " 'channels. Consider using "ignore_ref=True" in ' - 'calculation') - logger.info( - f'{len(info["comps"])} compensation data sets in info') + "calculation" + ) + logger.info(f'{len(info["comps"])} compensation data sets in info') # Compose a compensation data set if necessary # adapted from mne_make_ctf_comp() from mne_ctf_comp.c - logger.info('Setting up compensation data...') + logger.info("Setting up compensation data...") comp_num = get_current_comp(info) if comp_num is None or comp_num == 0: - logger.info(' No compensation set. Nothing more to do.') + logger.info(" No compensation set. 
Nothing more to do.") else: compensator = make_compensator( - info_meg, 0, comp_num, exclude_comp_chs=False) - logger.info( - f' Desired compensation data ({comp_num}) found.') - logger.info(' All compensation channels found.') - logger.info(' Preselector created.') - logger.info(' Compensation data matrix created.') - logger.info(' Postselector created.') - post_picks = pick_types( - info_meg, meg=True, ref_meg=False, exclude=exclude) + info_meg, 0, comp_num, exclude_comp_chs=False + ) + logger.info(f" Desired compensation data ({comp_num}) found.") + logger.info(" All compensation channels found.") + logger.info(" Preselector created.") + logger.info(" Compensation data matrix created.") + logger.info(" Postselector created.") + post_picks = pick_types(info_meg, meg=True, ref_meg=False, exclude=exclude) ch_names = [ch_names[pick] for pick in post_picks] # Create coil descriptions with transformation to head or device frame templates = _read_coil_defs() if head_frame: - _print_coord_trans(info['dev_head_t']) - transform = info['dev_head_t'] + _print_coord_trans(info["dev_head_t"]) + transform = info["dev_head_t"] else: transform = None megcoils = _create_meg_coils( - info_meg['chs'], accuracy, transform, templates, do_es=do_es) + info_meg["chs"], accuracy, transform, templates, do_es=do_es + ) # Check that coordinate frame is correct and log it if head_frame: - assert megcoils[0]['coord_frame'] == FIFF.FIFFV_COORD_HEAD - logger.info('MEG coil definitions created in head coordinates.') + assert megcoils[0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD + logger.info("MEG coil definitions created in head coordinates.") else: - assert megcoils[0]['coord_frame'] == FIFF.FIFFV_COORD_DEVICE - logger.info('MEG coil definitions created in device coordinate.') + assert megcoils[0]["coord_frame"] == FIFF.FIFFV_COORD_DEVICE + logger.info("MEG coil definitions created in device coordinate.") return dict( - defs=megcoils, ch_names=ch_names, compensator=compensator, - info=info_meg, post_picks=post_picks) + defs=megcoils, + ch_names=ch_names, + compensator=compensator, + info=info_meg, + post_picks=post_picks, + ) @verbose def _prep_eeg_channels(info, exclude=(), verbose=None): """Prepare EEG electrode definitions for forward calculation.""" - info_extra = 'info' + info_extra = "info" # Find EEG electrodes - picks = pick_types(info, meg=False, eeg=True, ref_meg=False, - exclude=exclude) + picks = pick_types(info, meg=False, eeg=True, ref_meg=False, exclude=exclude) # Make sure EEG electrodes exist neeg = len(picks) if neeg <= 0: - raise RuntimeError('Could not find any EEG channels') + raise RuntimeError("Could not find any EEG channels") # Get channel info and names for EEG channels - eegchs = pick_info(info, picks)['chs'] - eegnames = [info['ch_names'][p] for p in picks] - logger.info('Read %3d EEG channels from %s' % (len(picks), info_extra)) + eegchs = pick_info(info, picks)["chs"] + eegnames = [info["ch_names"][p] for p in picks] + logger.info("Read %3d EEG channels from %s" % (len(picks), info_extra)) # Create EEG electrode descriptions eegels = _create_eeg_els(eegchs) - logger.info('Head coordinate coil definitions created.') + logger.info("Head coordinate coil definitions created.") return dict(defs=eegels, ch_names=eegnames) @verbose -def _prepare_for_forward(src, mri_head_t, info, bem, mindist, n_jobs, - bem_extra='', trans='', info_extra='', - meg=True, eeg=True, ignore_ref=False, - allow_bem_none=False, verbose=None): +def _prepare_for_forward( + src, + mri_head_t, + info, + bem, + mindist, + 
n_jobs, + bem_extra="", + trans="", + info_extra="", + meg=True, + eeg=True, + ignore_ref=False, + allow_bem_none=False, + verbose=None, +): """Prepare for forward computation. The sensors dict contains keys for each sensor type, e.g. 'meg', 'eeg'. @@ -389,116 +462,157 @@ def _prepare_for_forward(src, mri_head_t, info, bem, mindist, n_jobs, compensator """ # Read the source locations - logger.info('') + logger.info("") # let's make a copy in case we modify something src = _ensure_src(src).copy() - nsource = sum(s['nuse'] for s in src) + nsource = sum(s["nuse"] for s in src) if nsource == 0: - raise RuntimeError('No sources are active in these source spaces. ' - '"do_all" option should be used.') - logger.info('Read %d source spaces a total of %d active source locations' - % (len(src), nsource)) + raise RuntimeError( + "No sources are active in these source spaces. " + '"do_all" option should be used.' + ) + logger.info( + "Read %d source spaces a total of %d active source locations" + % (len(src), nsource) + ) # Delete some keys to clean up the source space: - for key in ['working_dir', 'command_line']: + for key in ["working_dir", "command_line"]: if key in src.info: del src.info[key] # Read the MRI -> head coordinate transformation - logger.info('') + logger.info("") _print_coord_trans(mri_head_t) # make a new dict with the relevant information - arg_list = [info_extra, trans, src, bem_extra, meg, eeg, mindist, - n_jobs, verbose] - cmd = 'make_forward_solution(%s)' % (', '.join([str(a) for a in arg_list])) + arg_list = [info_extra, trans, src, bem_extra, meg, eeg, mindist, n_jobs, verbose] + cmd = "make_forward_solution(%s)" % (", ".join([str(a) for a in arg_list])) mri_id = dict(machid=np.zeros(2, np.int32), version=0, secs=0, usecs=0) info_trans = str(trans) if isinstance(trans, Path) else trans - info = Info(chs=info['chs'], comps=info['comps'], - dev_head_t=info['dev_head_t'], mri_file=info_trans, - mri_id=mri_id, - meas_file=info_extra, meas_id=None, working_dir=os.getcwd(), - command_line=cmd, bads=info['bads'], mri_head_t=mri_head_t) + info = Info( + chs=info["chs"], + comps=info["comps"], + dev_head_t=info["dev_head_t"], + mri_file=info_trans, + mri_id=mri_id, + meas_file=info_extra, + meas_id=None, + working_dir=os.getcwd(), + command_line=cmd, + bads=info["bads"], + mri_head_t=mri_head_t, + ) info._update_redundant() info._check_consistency() - logger.info('') + logger.info("") sensors = dict() if meg and len(pick_types(info, meg=True, ref_meg=False, exclude=[])) > 0: - sensors['meg'] = _prep_meg_channels(info, ignore_ref=ignore_ref) + sensors["meg"] = _prep_meg_channels(info, ignore_ref=ignore_ref) if eeg and len(pick_types(info, eeg=True, exclude=[])) > 0: - sensors['eeg'] = _prep_eeg_channels(info) + sensors["eeg"] = _prep_eeg_channels(info) # Check that some channels were found if len(sensors) == 0: - raise RuntimeError('No MEG or EEG channels found.') + raise RuntimeError("No MEG or EEG channels found.") # pick out final info - info = pick_info(info, pick_types(info, meg=meg, eeg=eeg, ref_meg=False, - exclude=[])) + info = pick_info( + info, pick_types(info, meg=meg, eeg=eeg, ref_meg=False, exclude=[]) + ) # Transform the source spaces into the appropriate coordinates # (will either be HEAD or MRI) for s in src: - transform_surface_to(s, 'head', mri_head_t) - logger.info('Source spaces are now in %s coordinates.' - % _coord_frame_name(s['coord_frame'])) + transform_surface_to(s, "head", mri_head_t) + logger.info( + "Source spaces are now in %s coordinates." 
% _coord_frame_name(s["coord_frame"]) + ) # Prepare the BEM model - eegnames = sensors.get('eeg', dict()).get('ch_names', []) - bem = _setup_bem(bem, bem_extra, len(eegnames), mri_head_t, - allow_none=allow_bem_none) + eegnames = sensors.get("eeg", dict()).get("ch_names", []) + bem = _setup_bem( + bem, bem_extra, len(eegnames), mri_head_t, allow_none=allow_bem_none + ) del eegnames # Circumvent numerical problems by excluding points too close to the skull, # and check that sensors are not inside any BEM surface if bem is not None: - if not bem['is_sphere']: - check_surface = 'inner skull surface' - inner_skull = _bem_find_surface(bem, 'inner_skull') + if not bem["is_sphere"]: + check_surface = "inner skull surface" + inner_skull = _bem_find_surface(bem, "inner_skull") check_inside = _filter_source_spaces( - inner_skull, mindist, mri_head_t, src, n_jobs) - logger.info('') - if len(bem['surfs']) == 3: - check_surface = 'scalp surface' - check_inside = _CheckInside(_bem_find_surface(bem, 'head')) + inner_skull, mindist, mri_head_t, src, n_jobs + ) + logger.info("") + if len(bem["surfs"]) == 3: + check_surface = "scalp surface" + check_inside = _CheckInside(_bem_find_surface(bem, "head")) else: - check_surface = 'outermost sphere shell' - if len(bem['layers']) == 0: + check_surface = "outermost sphere shell" + if len(bem["layers"]) == 0: + def check_inside(x): return np.zeros(len(x), bool) + else: + def check_inside(x): - return (np.linalg.norm(x - bem['r0'], axis=1) < - bem['layers'][-1]['rad']) - if 'meg' in sensors: + return ( + np.linalg.norm(x - bem["r0"], axis=1) < bem["layers"][-1]["rad"] + ) + + if "meg" in sensors: meg_loc = apply_trans( invert_transform(mri_head_t), - np.array([coil['r0'] for coil in sensors['meg']['defs']])) + np.array([coil["r0"] for coil in sensors["meg"]["defs"]]), + ) n_inside = check_inside(meg_loc).sum() if n_inside: raise RuntimeError( - f'Found {n_inside} MEG sensor{_pl(n_inside)} inside the ' - f'{check_surface}, perhaps coordinate frames and/or ' - 'coregistration must be incorrect') + f"Found {n_inside} MEG sensor{_pl(n_inside)} inside the " + f"{check_surface}, perhaps coordinate frames and/or " + "coregistration must be incorrect" + ) - rr = np.concatenate([s['rr'][s['vertno']] for s in src]) + rr = np.concatenate([s["rr"][s["vertno"]] for s in src]) if len(rr) < 1: - raise RuntimeError('No points left in source space after excluding ' - 'points close to inner skull.') + raise RuntimeError( + "No points left in source space after excluding " + "points close to inner skull." + ) # deal with free orientations: source_nn = np.tile(np.eye(3), (len(rr), 1)) - update_kwargs = dict(nchan=len(info['ch_names']), nsource=len(rr), - info=info, src=src, source_nn=source_nn, - source_rr=rr, surf_ori=False, mri_head_t=mri_head_t) + update_kwargs = dict( + nchan=len(info["ch_names"]), + nsource=len(rr), + info=info, + src=src, + source_nn=source_nn, + source_rr=rr, + surf_ori=False, + mri_head_t=mri_head_t, + ) return sensors, rr, info, update_kwargs, bem @verbose -def make_forward_solution(info, trans, src, bem, meg=True, eeg=True, *, - mindist=0.0, ignore_ref=False, n_jobs=None, - verbose=None): +def make_forward_solution( + info, + trans, + src, + bem, + meg=True, + eeg=True, + *, + mindist=0.0, + ignore_ref=False, + n_jobs=None, + verbose=None, +): """Calculate a forward solution for a subject. 
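
    As a minimal usage sketch, assuming the standard MNE ``sample`` dataset
    layout (``info``, ``trans``, ``src``, and ``bem`` may all be given as
    paths)::

        import mne
        from mne.datasets import sample

        data_path = sample.data_path()
        meg_path = data_path / "MEG" / "sample"
        bem_path = data_path / "subjects" / "sample" / "bem"
        fwd = mne.make_forward_solution(
            meg_path / "sample_audvis_raw.fif",
            trans=meg_path / "sample_audvis_raw-trans.fif",
            src=bem_path / "sample-oct-6-src.fif",
            bem=bem_path / "sample-5120-5120-5120-bem-sol.fif",
            meg=True,
            eeg=True,
            mindist=5.0,  # exclude sources within 5 mm of the inner skull
        )
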
Parameters @@ -561,61 +675,72 @@ def make_forward_solution(info, trans, src, bem, meg=True, eeg=True, *, # (could also be HEAD to MRI) mri_head_t, trans = _get_trans(trans) if isinstance(bem, ConductorModel): - bem_extra = 'instance of ConductorModel' + bem_extra = "instance of ConductorModel" else: bem_extra = bem - _validate_type(info, ('path-like', Info), 'info') + _validate_type(info, ("path-like", Info), "info") if not isinstance(info, Info): info_extra = op.split(info)[1] - info = _check_fname(info, must_exist=True, overwrite='read', - name='info') + info = _check_fname(info, must_exist=True, overwrite="read", name="info") info = read_info(info, verbose=False) else: - info_extra = 'instance of Info' + info_extra = "instance of Info" # Report the setup - logger.info('Source space : %s' % src) - logger.info('MRI -> head transform : %s' % trans) - logger.info('Measurement data : %s' % info_extra) - if isinstance(bem, ConductorModel) and bem['is_sphere']: - logger.info('Sphere model : origin at %s mm' - % (bem['r0'],)) - logger.info('Standard field computations') + logger.info("Source space : %s" % src) + logger.info("MRI -> head transform : %s" % trans) + logger.info("Measurement data : %s" % info_extra) + if isinstance(bem, ConductorModel) and bem["is_sphere"]: + logger.info("Sphere model : origin at %s mm" % (bem["r0"],)) + logger.info("Standard field computations") else: - logger.info('Conductor model : %s' % bem_extra) - logger.info('Accurate field computations') - logger.info('Do computations in %s coordinates', - _coord_frame_name(FIFF.FIFFV_COORD_HEAD)) - logger.info('Free source orientations') + logger.info("Conductor model : %s" % bem_extra) + logger.info("Accurate field computations") + logger.info( + "Do computations in %s coordinates", _coord_frame_name(FIFF.FIFFV_COORD_HEAD) + ) + logger.info("Free source orientations") # Create MEG coils and EEG electrodes in the head coordinate frame sensors, rr, info, update_kwargs, bem = _prepare_for_forward( - src, mri_head_t, info, bem, mindist, n_jobs, bem_extra, trans, - info_extra, meg, eeg, ignore_ref) - del (src, mri_head_t, trans, info_extra, bem_extra, mindist, - meg, eeg, ignore_ref) + src, + mri_head_t, + info, + bem, + mindist, + n_jobs, + bem_extra, + trans, + info_extra, + meg, + eeg, + ignore_ref, + ) + del (src, mri_head_t, trans, info_extra, bem_extra, mindist, meg, eeg, ignore_ref) # Time to do the heavy lifting: MEG first, then EEG fwds = _compute_forwards(rr, bem=bem, sensors=sensors, n_jobs=n_jobs) # merge forwards - fwds = {key: _to_forward_dict(fwds[key], sensors[key]['ch_names']) - for key in _FWD_ORDER if key in fwds} + fwds = { + key: _to_forward_dict(fwds[key], sensors[key]["ch_names"]) + for key in _FWD_ORDER + if key in fwds + } fwd = _merge_fwds(fwds, verbose=False) del fwds - logger.info('') + logger.info("") # Don't transform the source spaces back into MRI coordinates (which is # done in the C code) because mne-python assumes forward solution source # spaces are in head coords. fwd.update(**update_kwargs) - logger.info('Finished.') + logger.info("Finished.") return fwd @verbose -def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, - verbose=None): +def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, verbose=None): """Convert dipole object to source estimate and calculate forward operator. 
The instance of Dipole is converted to a discrete source space, @@ -662,6 +787,7 @@ def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, """ if isinstance(dipole, list): from ..dipole import _concatenate_dipoles # To avoid circular import + dipole = _concatenate_dipoles(dipole) # Make copies to avoid mangling original dipole @@ -674,31 +800,29 @@ def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, # NB information about dipole orientation enters here, then no more sources = dict(rr=pos, nn=ori) # Dipole objects must be in the head frame - src = _complete_vol_src( - [_make_discrete_source_space(sources, coord_frame='head')]) + src = _complete_vol_src([_make_discrete_source_space(sources, coord_frame="head")]) # Forward operator created for channels in info (use pick_info to restrict) # Use defaults for most params, including min_dist - fwd = make_forward_solution(info, trans, src, bem, n_jobs=n_jobs, - verbose=verbose) + fwd = make_forward_solution(info, trans, src, bem, n_jobs=n_jobs, verbose=verbose) # Convert from free orientations to fixed (in-place) - convert_forward_solution(fwd, surf_ori=False, force_fixed=True, - copy=False, use_cps=False, verbose=None) + convert_forward_solution( + fwd, surf_ori=False, force_fixed=True, copy=False, use_cps=False, verbose=None + ) # Check for omissions due to proximity to inner skull in # make_forward_solution, which will result in an exception - if fwd['src'][0]['nuse'] != len(pos): - inuse = fwd['src'][0]['inuse'].astype(bool) - head = ('The following dipoles are outside the inner skull boundary') - msg = len(head) * '#' + '\n' + head + '\n' - for (t, pos) in zip(times[np.logical_not(inuse)], - pos[np.logical_not(inuse)]): - msg += ' t={:.0f} ms, pos=({:.0f}, {:.0f}, {:.0f}) mm\n'.\ - format(t * 1000., pos[0] * 1000., - pos[1] * 1000., pos[2] * 1000.) - msg += len(head) * '#' + if fwd["src"][0]["nuse"] != len(pos): + inuse = fwd["src"][0]["inuse"].astype(bool) + head = "The following dipoles are outside the inner skull boundary" + msg = len(head) * "#" + "\n" + head + "\n" + for t, pos in zip(times[np.logical_not(inuse)], pos[np.logical_not(inuse)]): + msg += " t={:.0f} ms, pos=({:.0f}, {:.0f}, {:.0f}) mm\n".format( + t * 1000.0, pos[0] * 1000.0, pos[1] * 1000.0, pos[2] * 1000.0 + ) + msg += len(head) * "#" logger.error(msg) - raise ValueError('One or more dipoles outside the inner skull.') + raise ValueError("One or more dipoles outside the inner skull.") # multiple dipoles (rr and nn) per time instant allowed # uneven sampling in time returns list @@ -706,8 +830,10 @@ def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, if len(timepoints) > 1: tdiff = np.diff(timepoints) if not np.allclose(tdiff, tdiff[0]): - warn('Unique time points of dipoles unevenly spaced: returned ' - 'stc will be a list, one for each time point.') + warn( + "Unique time points of dipoles unevenly spaced: returned " + "stc will be a list, one for each time point." 
+ ) tstep = -1.0 else: tstep = tdiff[0] @@ -722,39 +848,64 @@ def make_forward_dipole(dipole, bem, info, trans=None, n_jobs=None, *, row = 0 for tpind, tp in enumerate(timepoints): amp = amplitude[np.in1d(times, tp)] - data[row:row + len(amp), tpind] = amp + data[row : row + len(amp), tpind] = amp row += len(amp) if tstep > 0: - stc = VolSourceEstimate(data, vertices=[fwd['src'][0]['vertno']], - tmin=timepoints[0], - tstep=tstep, subject=None) + stc = VolSourceEstimate( + data, + vertices=[fwd["src"][0]["vertno"]], + tmin=timepoints[0], + tstep=tstep, + subject=None, + ) else: # Must return a list of stc, one for each time point stc = [] for col, tp in enumerate(timepoints): - stc += [VolSourceEstimate(data[:, col][:, np.newaxis], - vertices=[fwd['src'][0]['vertno']], - tmin=tp, tstep=0.001, subject=None)] + stc += [ + VolSourceEstimate( + data[:, col][:, np.newaxis], + vertices=[fwd["src"][0]["vertno"]], + tmin=tp, + tstep=0.001, + subject=None, + ) + ] return fwd, stc -def _to_forward_dict(fwd, names, fwd_grad=None, - coord_frame=FIFF.FIFFV_COORD_HEAD, - source_ori=FIFF.FIFFV_MNE_FREE_ORI): +def _to_forward_dict( + fwd, + names, + fwd_grad=None, + coord_frame=FIFF.FIFFV_COORD_HEAD, + source_ori=FIFF.FIFFV_MNE_FREE_ORI, +): """Convert forward solution matrices to dicts.""" assert names is not None - sol = dict(data=fwd.T, nrow=fwd.shape[1], ncol=fwd.shape[0], - row_names=names, col_names=[]) - fwd = Forward(sol=sol, source_ori=source_ori, nsource=sol['ncol'], - coord_frame=coord_frame, sol_grad=None, - nchan=sol['nrow'], _orig_source_ori=source_ori, - _orig_sol=sol['data'].copy(), _orig_sol_grad=None) + sol = dict( + data=fwd.T, nrow=fwd.shape[1], ncol=fwd.shape[0], row_names=names, col_names=[] + ) + fwd = Forward( + sol=sol, + source_ori=source_ori, + nsource=sol["ncol"], + coord_frame=coord_frame, + sol_grad=None, + nchan=sol["nrow"], + _orig_source_ori=source_ori, + _orig_sol=sol["data"].copy(), + _orig_sol_grad=None, + ) if fwd_grad is not None: - sol_grad = dict(data=fwd_grad.T, nrow=fwd_grad.shape[1], - ncol=fwd_grad.shape[0], row_names=names, - col_names=[]) - fwd.update(dict(sol_grad=sol_grad), - _orig_sol_grad=sol_grad['data'].copy()) + sol_grad = dict( + data=fwd_grad.T, + nrow=fwd_grad.shape[1], + ncol=fwd_grad.shape[0], + row_names=names, + col_names=[], + ) + fwd.update(dict(sol_grad=sol_grad), _orig_sol_grad=sol_grad["data"].copy()) return fwd diff --git a/mne/forward/forward.py b/mne/forward/forward.py index 17ed07f8ac4..2f1e1c0b89d 100644 --- a/mne/forward/forward.py +++ b/mne/forward/forward.py @@ -25,30 +25,57 @@ from ..io.open import fiff_open from ..io.tree import dir_tree_find from ..io.tag import find_tag, read_tag -from ..io.matrix import (_read_named_matrix, _transpose_named_matrix, - write_named_matrix) -from ..io.meas_info import (_read_bad_channels, write_info, _write_ch_infos, - _read_extended_ch_info, _make_ch_names_mapping, - _write_bad_channels) -from ..io.pick import (pick_channels_forward, pick_info, pick_channels, - pick_types) -from ..io.write import (write_int, start_block, end_block, write_coord_trans, - write_string, start_and_end_file, write_id) +from ..io.matrix import _read_named_matrix, _transpose_named_matrix, write_named_matrix +from ..io.meas_info import ( + _read_bad_channels, + write_info, + _write_ch_infos, + _read_extended_ch_info, + _make_ch_names_mapping, + _write_bad_channels, +) +from ..io.pick import pick_channels_forward, pick_info, pick_channels, pick_types +from ..io.write import ( + write_int, + start_block, + end_block, + 
write_coord_trans, + write_string, + start_and_end_file, + write_id, +) from ..io.base import BaseRaw from ..evoked import Evoked, EvokedArray from ..epochs import BaseEpochs -from ..source_space import (_read_source_spaces_from_tree, - find_source_space_hemi, _set_source_space_vertices, - _write_source_spaces_to_fid, _get_src_nn, - _src_kind_dict) +from ..source_space import ( + _read_source_spaces_from_tree, + find_source_space_hemi, + _set_source_space_vertices, + _write_source_spaces_to_fid, + _get_src_nn, + _src_kind_dict, +) from ..source_estimate import _BaseVectorSourceEstimate, _BaseSourceEstimate from ..surface import _normal_orth -from ..transforms import (transform_surface_to, invert_transform, - write_trans) -from ..utils import (_check_fname, get_subjects_dir, has_mne_c, warn, - run_subprocess, check_fname, logger, verbose, fill_doc, - _validate_type, _check_compensation_grade, _check_option, - _check_stc_units, _stamp_to_dt, _on_missing, repr_html) +from ..transforms import transform_surface_to, invert_transform, write_trans +from ..utils import ( + _check_fname, + get_subjects_dir, + has_mne_c, + warn, + run_subprocess, + check_fname, + logger, + verbose, + fill_doc, + _validate_type, + _check_compensation_grade, + _check_option, + _check_stc_units, + _stamp_to_dt, + _on_missing, + repr_html, +) from ..label import Label @@ -133,69 +160,73 @@ def copy(self): return Forward(deepcopy(self)) def _get_src_type_and_ori_for_repr(self): - src_types = np.array([src['type'] for src in self['src']]) - - if (src_types == 'surf').all(): - src_type = 'Surface with %d vertices' % self['nsource'] - elif (src_types == 'vol').all(): - src_type = 'Volume with %d grid points' % self['nsource'] - elif (src_types == 'discrete').all(): - src_type = 'Discrete with %d dipoles' % self['nsource'] + src_types = np.array([src["type"] for src in self["src"]]) + + if (src_types == "surf").all(): + src_type = "Surface with %d vertices" % self["nsource"] + elif (src_types == "vol").all(): + src_type = "Volume with %d grid points" % self["nsource"] + elif (src_types == "discrete").all(): + src_type = "Discrete with %d dipoles" % self["nsource"] else: - count_string = '' - if (src_types == 'surf').any(): - count_string += '%d surface, ' % (src_types == 'surf').sum() - if (src_types == 'vol').any(): - count_string += '%d volume, ' % (src_types == 'vol').sum() - if (src_types == 'discrete').any(): - count_string += '%d discrete, ' \ - % (src_types == 'discrete').sum() - count_string = count_string.rstrip(', ') - src_type = ('Mixed (%s) with %d vertices' - % (count_string, self['nsource'])) - - if self['source_ori'] == FIFF.FIFFV_MNE_UNKNOWN_ORI: - src_ori = 'Unknown' - elif self['source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI: - src_ori = 'Fixed' - elif self['source_ori'] == FIFF.FIFFV_MNE_FREE_ORI: - src_ori = 'Free' + count_string = "" + if (src_types == "surf").any(): + count_string += "%d surface, " % (src_types == "surf").sum() + if (src_types == "vol").any(): + count_string += "%d volume, " % (src_types == "vol").sum() + if (src_types == "discrete").any(): + count_string += "%d discrete, " % (src_types == "discrete").sum() + count_string = count_string.rstrip(", ") + src_type = "Mixed (%s) with %d vertices" % (count_string, self["nsource"]) + + if self["source_ori"] == FIFF.FIFFV_MNE_UNKNOWN_ORI: + src_ori = "Unknown" + elif self["source_ori"] == FIFF.FIFFV_MNE_FIXED_ORI: + src_ori = "Fixed" + elif self["source_ori"] == FIFF.FIFFV_MNE_FREE_ORI: + src_ori = "Free" return src_type, src_ori def __repr__(self): 
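        # A hypothetical illustration of the summary string assembled below
        # (the exact fields depend on the solution), e.g.:
        #   <Forward | MEG channels: 306 | EEG channels: 60
        #    | Surface with 7498 vertices | Free>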
"""Summarize forward info instead of printing all.""" - entr = ' 0: - raise ValueError('Width of matrix must be a multiple of n') + raise ValueError("Width of matrix must be a multiple of n") tmp = np.arange(ma * bdn, dtype=np.int64).reshape(bdn, ma) tmp = np.tile(tmp, (1, n)) @@ -279,7 +311,7 @@ def _get_tag_int(fid, node, name, id_): tag = find_tag(fid, node, id_) if tag is None: fid.close() - raise ValueError(name + ' tag not found') + raise ValueError(name + " tag not found") return int(tag.data.item()) @@ -290,42 +322,44 @@ def _read_one(fid, node): return None one = Forward() - one['source_ori'] = _get_tag_int(fid, node, 'Source orientation', - FIFF.FIFF_MNE_SOURCE_ORIENTATION) - one['coord_frame'] = _get_tag_int(fid, node, 'Coordinate frame', - FIFF.FIFF_MNE_COORD_FRAME) - one['nsource'] = _get_tag_int(fid, node, 'Number of sources', - FIFF.FIFF_MNE_SOURCE_SPACE_NPOINTS) - one['nchan'] = _get_tag_int(fid, node, 'Number of channels', - FIFF.FIFF_NCHAN) + one["source_ori"] = _get_tag_int( + fid, node, "Source orientation", FIFF.FIFF_MNE_SOURCE_ORIENTATION + ) + one["coord_frame"] = _get_tag_int( + fid, node, "Coordinate frame", FIFF.FIFF_MNE_COORD_FRAME + ) + one["nsource"] = _get_tag_int( + fid, node, "Number of sources", FIFF.FIFF_MNE_SOURCE_SPACE_NPOINTS + ) + one["nchan"] = _get_tag_int(fid, node, "Number of channels", FIFF.FIFF_NCHAN) try: - one['sol'] = _read_named_matrix(fid, node, - FIFF.FIFF_MNE_FORWARD_SOLUTION, - transpose=True) - one['_orig_sol'] = one['sol']['data'].copy() + one["sol"] = _read_named_matrix( + fid, node, FIFF.FIFF_MNE_FORWARD_SOLUTION, transpose=True + ) + one["_orig_sol"] = one["sol"]["data"].copy() except Exception: - logger.error('Forward solution data not found') + logger.error("Forward solution data not found") raise try: fwd_type = FIFF.FIFF_MNE_FORWARD_SOLUTION_GRAD - one['sol_grad'] = _read_named_matrix(fid, node, fwd_type, - transpose=True) - one['_orig_sol_grad'] = one['sol_grad']['data'].copy() + one["sol_grad"] = _read_named_matrix(fid, node, fwd_type, transpose=True) + one["_orig_sol_grad"] = one["sol_grad"]["data"].copy() except Exception: - one['sol_grad'] = None + one["sol_grad"] = None - if one['sol']['data'].shape[0] != one['nchan'] or \ - (one['sol']['data'].shape[1] != one['nsource'] and - one['sol']['data'].shape[1] != 3 * one['nsource']): - raise ValueError('Forward solution matrix has wrong dimensions') + if one["sol"]["data"].shape[0] != one["nchan"] or ( + one["sol"]["data"].shape[1] != one["nsource"] + and one["sol"]["data"].shape[1] != 3 * one["nsource"] + ): + raise ValueError("Forward solution matrix has wrong dimensions") - if one['sol_grad'] is not None: - if one['sol_grad']['data'].shape[0] != one['nchan'] or \ - (one['sol_grad']['data'].shape[1] != 3 * one['nsource'] and - one['sol_grad']['data'].shape[1] != 3 * 3 * one['nsource']): - raise ValueError('Forward solution gradient matrix has ' - 'wrong dimensions') + if one["sol_grad"] is not None: + if one["sol_grad"]["data"].shape[0] != one["nchan"] or ( + one["sol_grad"]["data"].shape[1] != 3 * one["nsource"] + and one["sol_grad"]["data"].shape[1] != 3 * 3 * one["nsource"] + ): + raise ValueError("Forward solution gradient matrix has " "wrong dimensions") return one @@ -352,30 +386,30 @@ def _read_forward_meas_info(tree, fid): # Information from the MRI file parent_mri = dir_tree_find(tree, FIFF.FIFFB_MNE_PARENT_MRI_FILE) if len(parent_mri) == 0: - raise ValueError('No parent MEG information found in operator') + raise ValueError("No parent MEG information found in 
operator") parent_mri = parent_mri[0] tag = find_tag(fid, parent_mri, FIFF.FIFF_MNE_FILE_NAME) - info['mri_file'] = tag.data if tag is not None else None + info["mri_file"] = tag.data if tag is not None else None tag = find_tag(fid, parent_mri, FIFF.FIFF_PARENT_FILE_ID) - info['mri_id'] = tag.data if tag is not None else None + info["mri_id"] = tag.data if tag is not None else None # Information from the MEG file parent_meg = dir_tree_find(tree, FIFF.FIFFB_MNE_PARENT_MEAS_FILE) if len(parent_meg) == 0: - raise ValueError('No parent MEG information found in operator') + raise ValueError("No parent MEG information found in operator") parent_meg = parent_meg[0] tag = find_tag(fid, parent_meg, FIFF.FIFF_MNE_FILE_NAME) - info['meas_file'] = tag.data if tag is not None else None + info["meas_file"] = tag.data if tag is not None else None tag = find_tag(fid, parent_meg, FIFF.FIFF_PARENT_FILE_ID) - info['meas_id'] = tag.data if tag is not None else None + info["meas_id"] = tag.data if tag is not None else None # Add channel information - info['chs'] = chs = list() - for k in range(parent_meg['nent']): - kind = parent_meg['directory'][k].kind - pos = parent_meg['directory'][k].pos + info["chs"] = chs = list() + for k in range(parent_meg["nent"]): + kind = parent_meg["directory"][k].kind + pos = parent_meg["directory"][k].pos if kind == FIFF.FIFF_CH_INFO: tag = read_tag(fid, pos) chs.append(tag.data) @@ -389,51 +423,50 @@ def _read_forward_meas_info(tree, fid): coord_device = FIFF.FIFFV_COORD_DEVICE coord_ctf_head = FIFF.FIFFV_MNE_COORD_CTF_HEAD if tag is None: - raise ValueError('MRI/head coordinate transformation not found') + raise ValueError("MRI/head coordinate transformation not found") cand = tag.data - if cand['from'] == coord_mri and cand['to'] == coord_head: - info['mri_head_t'] = cand + if cand["from"] == coord_mri and cand["to"] == coord_head: + info["mri_head_t"] = cand else: - raise ValueError('MRI/head coordinate transformation not found') + raise ValueError("MRI/head coordinate transformation not found") # Get the MEG device <-> head coordinate transformation tag = find_tag(fid, parent_meg, FIFF.FIFF_COORD_TRANS) if tag is None: - raise ValueError('MEG/head coordinate transformation not found') + raise ValueError("MEG/head coordinate transformation not found") cand = tag.data - if cand['from'] == coord_device and cand['to'] == coord_head: - info['dev_head_t'] = cand - elif cand['from'] == coord_ctf_head and cand['to'] == coord_head: - info['ctf_head_t'] = cand + if cand["from"] == coord_device and cand["to"] == coord_head: + info["dev_head_t"] = cand + elif cand["from"] == coord_ctf_head and cand["to"] == coord_head: + info["ctf_head_t"] = cand else: - raise ValueError('MEG/head coordinate transformation not found') + raise ValueError("MEG/head coordinate transformation not found") - info['bads'] = _read_bad_channels( - fid, parent_meg, ch_names_mapping=ch_names_mapping) + info["bads"] = _read_bad_channels( + fid, parent_meg, ch_names_mapping=ch_names_mapping + ) # clean up our bad list, old versions could have non-existent bads - info['bads'] = [bad for bad in info['bads'] if bad in info['ch_names']] + info["bads"] = [bad for bad in info["bads"] if bad in info["ch_names"]] # Check if a custom reference has been applied tag = find_tag(fid, parent_mri, FIFF.FIFF_MNE_CUSTOM_REF) if tag is None: tag = find_tag(fid, parent_mri, 236) # Constant 236 used before v0.11 - info['custom_ref_applied'] = ( - int(tag.data.item()) if tag is not None else False - ) + info["custom_ref_applied"] = 
int(tag.data.item()) if tag is not None else False info._unlocked = False return info def _subject_from_forward(forward): """Get subject id from inverse operator.""" - return forward['src']._subject + return forward["src"]._subject # This sets the forward solution order (and gives human-readable names) _FWD_ORDER = dict( - meg='MEG', - eeg='EEG', + meg="MEG", + eeg="EEG", ) @@ -455,28 +488,30 @@ def _merge_fwds(fwds, *, verbose=None): b = fwds[key] a_kind, b_kind = _FWD_ORDER[first_key], _FWD_ORDER[key] combined.append(b_kind) - if (a['sol']['data'].shape[1] != b['sol']['data'].shape[1] or - a['source_ori'] != b['source_ori'] or - a['nsource'] != b['nsource'] or - a['coord_frame'] != b['coord_frame']): + if ( + a["sol"]["data"].shape[1] != b["sol"]["data"].shape[1] + or a["source_ori"] != b["source_ori"] + or a["nsource"] != b["nsource"] + or a["coord_frame"] != b["coord_frame"] + ): raise ValueError( - f'The {a_kind} and {b_kind} forward solutions do not match') - for k in ('sol', 'sol_grad'): + f"The {a_kind} and {b_kind} forward solutions do not match" + ) + for k in ("sol", "sol_grad"): if a[k] is None: continue - a[k]['data'] = np.r_[a[k]['data'], b[k]['data']] - a[f'_orig_{k}'] = np.r_[a[f'_orig_{k}'], b[f'_orig_{k}']] - a[k]['nrow'] = a[k]['nrow'] + b[k]['nrow'] - a[k]['row_names'] = a[k]['row_names'] + b[k]['row_names'] - a['nchan'] = a['nchan'] + b['nchan'] + a[k]["data"] = np.r_[a[k]["data"], b[k]["data"]] + a[f"_orig_{k}"] = np.r_[a[f"_orig_{k}"], b[f"_orig_{k}"]] + a[k]["nrow"] = a[k]["nrow"] + b[k]["nrow"] + a[k]["row_names"] = a[k]["row_names"] + b[k]["row_names"] + a["nchan"] = a["nchan"] + b["nchan"] if len(fwds) > 1: logger.info(f' Forward solutions combined: {", ".join(combined)}') return fwd @verbose -def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, - verbose=None): +def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, verbose=None): """Read a forward solution a.k.a. lead field. Parameters @@ -515,27 +550,28 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, surface-based, fixed orientation cannot be reverted after loading the forward solution with :func:`read_forward_solution`. """ - check_fname(fname, 'forward', ('-fwd.fif', '-fwd.fif.gz', - '_fwd.fif', '_fwd.fif.gz')) - fname = _check_fname(fname=fname, must_exist=True, overwrite='read') + check_fname( + fname, "forward", ("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz") + ) + fname = _check_fname(fname=fname, must_exist=True, overwrite="read") # Open the file, create directory - logger.info('Reading forward solution from %s...' % fname) + logger.info("Reading forward solution from %s..." 
% fname) f, tree, _ = fiff_open(fname) with f as fid: # Find all forward solutions fwds = dir_tree_find(tree, FIFF.FIFFB_MNE_FORWARD_SOLUTION) if len(fwds) == 0: - raise ValueError('No forward solutions in %s' % fname) + raise ValueError("No forward solutions in %s" % fname) # Parent MRI data parent_mri = dir_tree_find(tree, FIFF.FIFFB_MNE_PARENT_MRI_FILE) if len(parent_mri) == 0: - raise ValueError('No parent MRI information in %s' % fname) + raise ValueError("No parent MRI information in %s" % fname) parent_mri = parent_mri[0] src = _read_source_spaces_from_tree(fid, tree, patch_stats=False) for s in src: - s['id'] = find_source_space_hemi(s) + s["id"] = find_source_space_hemi(s) fwd = None @@ -545,8 +581,9 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, for k in range(len(fwds)): tag = find_tag(fid, fwds[k], FIFF.FIFF_MNE_INCLUDED_METHODS) if tag is None: - raise ValueError('Methods not listed for one of the forward ' - 'solutions') + raise ValueError( + "Methods not listed for one of the forward " "solutions" + ) if tag.data == FIFF.FIFFV_MNE_MEG: megnode = fwds[k] @@ -556,26 +593,30 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, fwds = dict() megfwd = _read_one(fid, megnode) if megfwd is not None: - fwds['meg'] = megfwd + fwds["meg"] = megfwd if is_fixed_orient(megfwd): - ori = 'fixed' + ori = "fixed" else: - ori = 'free' - logger.info(' Read MEG forward solution (%d sources, ' - '%d channels, %s orientations)' - % (megfwd['nsource'], megfwd['nchan'], ori)) + ori = "free" + logger.info( + " Read MEG forward solution (%d sources, " + "%d channels, %s orientations)" + % (megfwd["nsource"], megfwd["nchan"], ori) + ) del megfwd eegfwd = _read_one(fid, eegnode) if eegfwd is not None: - fwds['eeg'] = eegfwd + fwds["eeg"] = eegfwd if is_fixed_orient(eegfwd): - ori = 'fixed' + ori = "fixed" else: - ori = 'free' - logger.info(' Read EEG forward solution (%d sources, ' - '%d channels, %s orientations)' - % (eegfwd['nsource'], eegfwd['nchan'], ori)) + ori = "free" + logger.info( + " Read EEG forward solution (%d sources, " + "%d channels, %s orientations)" + % (eegfwd["nsource"], eegfwd["nchan"], ori) + ) del eegfwd fwd = _merge_fwds(fwds) @@ -584,22 +625,25 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, # Get the MRI <-> head coordinate transformation tag = find_tag(fid, parent_mri, FIFF.FIFF_COORD_TRANS) if tag is None: - raise ValueError('MRI/head coordinate transformation not found') + raise ValueError("MRI/head coordinate transformation not found") mri_head_t = tag.data - if (mri_head_t['from'] != FIFF.FIFFV_COORD_MRI or - mri_head_t['to'] != FIFF.FIFFV_COORD_HEAD): + if ( + mri_head_t["from"] != FIFF.FIFFV_COORD_MRI + or mri_head_t["to"] != FIFF.FIFFV_COORD_HEAD + ): mri_head_t = invert_transform(mri_head_t) - if (mri_head_t['from'] != FIFF.FIFFV_COORD_MRI or - mri_head_t['to'] != FIFF.FIFFV_COORD_HEAD): + if ( + mri_head_t["from"] != FIFF.FIFFV_COORD_MRI + or mri_head_t["to"] != FIFF.FIFFV_COORD_HEAD + ): fid.close() - raise ValueError('MRI/head coordinate transformation not ' - 'found') - fwd['mri_head_t'] = mri_head_t + raise ValueError("MRI/head coordinate transformation not " "found") + fwd["mri_head_t"] = mri_head_t # # get parent MEG info # - fwd['info'] = _read_forward_meas_info(tree, fid) + fwd["info"] = _read_forward_meas_info(tree, fid) # MNE environment parent_env = dir_tree_find(tree, FIFF.FIFFB_MNE_ENV) @@ -607,20 +651,22 @@ def read_forward_solution(fname, include=(), exclude=(), *, 
ordered=None, parent_env = parent_env[0] tag = find_tag(fid, parent_env, FIFF.FIFF_MNE_ENV_WORKING_DIR) if tag is not None: - with fwd['info']._unlock(): - fwd['info']['working_dir'] = tag.data + with fwd["info"]._unlock(): + fwd["info"]["working_dir"] = tag.data tag = find_tag(fid, parent_env, FIFF.FIFF_MNE_ENV_COMMAND_LINE) if tag is not None: - with fwd['info']._unlock(): - fwd['info']['command_line'] = tag.data + with fwd["info"]._unlock(): + fwd["info"]["command_line"] = tag.data # Transform the source spaces to the correct coordinate frame # if necessary # Make sure forward solution is in either the MRI or HEAD coordinate frame - if fwd['coord_frame'] not in (FIFF.FIFFV_COORD_MRI, FIFF.FIFFV_COORD_HEAD): - raise ValueError('Only forward solutions computed in MRI or head ' - 'coordinates are acceptable') + if fwd["coord_frame"] not in (FIFF.FIFFV_COORD_MRI, FIFF.FIFFV_COORD_HEAD): + raise ValueError( + "Only forward solutions computed in MRI or head " + "coordinates are acceptable" + ) # Transform each source space to the HEAD or MRI coordinate frame, # depending on the coordinate frame of the forward solution @@ -629,45 +675,47 @@ def read_forward_solution(fname, include=(), exclude=(), *, ordered=None, nuse = 0 for s in src: try: - s = transform_surface_to(s, fwd['coord_frame'], mri_head_t) + s = transform_surface_to(s, fwd["coord_frame"], mri_head_t) except Exception as inst: - raise ValueError('Could not transform source space (%s)' % inst) + raise ValueError("Could not transform source space (%s)" % inst) - nuse += s['nuse'] + nuse += s["nuse"] # Make sure the number of sources match after transformation - if nuse != fwd['nsource']: - raise ValueError('Source spaces do not match the forward solution.') + if nuse != fwd["nsource"]: + raise ValueError("Source spaces do not match the forward solution.") - logger.info(' Source spaces transformed to the forward solution ' - 'coordinate frame') - fwd['src'] = src + logger.info( + " Source spaces transformed to the forward solution " "coordinate frame" + ) + fwd["src"] = src # Handle the source locations and orientations - fwd['source_rr'] = np.concatenate([ss['rr'][ss['vertno'], :] - for ss in src], axis=0) + fwd["source_rr"] = np.concatenate([ss["rr"][ss["vertno"], :] for ss in src], axis=0) # Store original source orientations - fwd['_orig_source_ori'] = fwd['source_ori'] + fwd["_orig_source_ori"] = fwd["source_ori"] # Deal with include and exclude pick_channels_forward(fwd, include=include, exclude=exclude, copy=False) if is_fixed_orient(fwd, orig=True): - fwd['source_nn'] = np.concatenate([_src['nn'][_src['vertno'], :] - for _src in fwd['src']], axis=0) - fwd['source_ori'] = FIFF.FIFFV_MNE_FIXED_ORI - fwd['surf_ori'] = True + fwd["source_nn"] = np.concatenate( + [_src["nn"][_src["vertno"], :] for _src in fwd["src"]], axis=0 + ) + fwd["source_ori"] = FIFF.FIFFV_MNE_FIXED_ORI + fwd["surf_ori"] = True else: - fwd['source_nn'] = np.kron(np.ones((fwd['nsource'], 1)), np.eye(3)) - fwd['source_ori'] = FIFF.FIFFV_MNE_FREE_ORI - fwd['surf_ori'] = False + fwd["source_nn"] = np.kron(np.ones((fwd["nsource"], 1)), np.eye(3)) + fwd["source_ori"] = FIFF.FIFFV_MNE_FREE_ORI + fwd["surf_ori"] = False return Forward(fwd) @verbose -def convert_forward_solution(fwd, surf_ori=False, force_fixed=False, - copy=True, use_cps=True, *, verbose=None): +def convert_forward_solution( + fwd, surf_ori=False, force_fixed=False, copy=True, use_cps=True, *, verbose=None +): """Convert forward solution between different source orientations. 
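
    As a small sketch of typical conversions, assuming ``fwd`` was obtained
    with :func:`read_forward_solution` (with ``surf_ori=True`` the last
    component of each source triplet is aligned with the surface normal)::

        fwd_surf = mne.convert_forward_solution(fwd, surf_ori=True)
        fwd_fixed = mne.convert_forward_solution(
            fwd, surf_ori=True, force_fixed=True, use_cps=True
        )
        assert fwd_fixed["sol"]["ncol"] == fwd_fixed["nsource"]
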
Parameters @@ -690,28 +738,34 @@ def convert_forward_solution(fwd, surf_ori=False, force_fixed=False, The modified forward solution. """ from scipy import sparse + fwd = fwd.copy() if copy else fwd if force_fixed is True: surf_ori = True - if any([src['type'] == 'vol' for src in fwd['src']]) and force_fixed: + if any([src["type"] == "vol" for src in fwd["src"]]) and force_fixed: raise ValueError( - 'Forward operator was generated with sources from a ' - 'volume source space. Conversion to fixed orientation is not ' - 'possible. Consider using a discrete source space if you have ' - 'meaningful normal orientations.') + "Forward operator was generated with sources from a " + "volume source space. Conversion to fixed orientation is not " + "possible. Consider using a discrete source space if you have " + "meaningful normal orientations." + ) if surf_ori and use_cps: - if any(s.get('patch_inds') is not None for s in fwd['src']): - logger.info(' Average patch normals will be employed in ' - 'the rotation to the local surface coordinates..' - '..') + if any(s.get("patch_inds") is not None for s in fwd["src"]): + logger.info( + " Average patch normals will be employed in " + "the rotation to the local surface coordinates.." + ".." + ) else: use_cps = False - logger.info(' No patch info available. The standard source ' - 'space normals will be employed in the rotation ' - 'to the local surface coordinates....') + logger.info( + " No patch info available. The standard source " + "space normals will be employed in the rotation " + "to the local surface coordinates...." + ) # We need to change these entries (only): # 1. source_nn @@ -723,78 +777,79 @@ def convert_forward_solution(fwd, surf_ori=False, force_fixed=False, if is_fixed_orient(fwd, orig=True) or (force_fixed and not use_cps): # Fixed - fwd['source_nn'] = np.concatenate([_get_src_nn(s, use_cps) - for s in fwd['src']], axis=0) + fwd["source_nn"] = np.concatenate( + [_get_src_nn(s, use_cps) for s in fwd["src"]], axis=0 + ) if not is_fixed_orient(fwd, orig=True): - logger.info(' Changing to fixed-orientation forward ' - 'solution with surface-based source orientations...') - fix_rot = _block_diag(fwd['source_nn'].T, 1) + logger.info( + " Changing to fixed-orientation forward " + "solution with surface-based source orientations..." 
+ ) + fix_rot = _block_diag(fwd["source_nn"].T, 1) # newer versions of numpy require explicit casting here, so *= no # longer works - fwd['sol']['data'] = (fwd['_orig_sol'] * - fix_rot).astype('float32') - fwd['sol']['ncol'] = fwd['nsource'] - if fwd['sol_grad'] is not None: + fwd["sol"]["data"] = (fwd["_orig_sol"] * fix_rot).astype("float32") + fwd["sol"]["ncol"] = fwd["nsource"] + if fwd["sol_grad"] is not None: x = sparse.block_diag([fix_rot] * 3) - fwd['sol_grad']['data'] = fwd['_orig_sol_grad'] * x # dot prod - fwd['sol_grad']['ncol'] = 3 * fwd['nsource'] - fwd['source_ori'] = FIFF.FIFFV_MNE_FIXED_ORI - fwd['surf_ori'] = True + fwd["sol_grad"]["data"] = fwd["_orig_sol_grad"] * x # dot prod + fwd["sol_grad"]["ncol"] = 3 * fwd["nsource"] + fwd["source_ori"] = FIFF.FIFFV_MNE_FIXED_ORI + fwd["surf_ori"] = True elif surf_ori: # Free, surf-oriented # Rotate the local source coordinate systems - fwd['source_nn'] = np.kron(np.ones((fwd['nsource'], 1)), np.eye(3)) - logger.info(' Converting to surface-based source orientations...') + fwd["source_nn"] = np.kron(np.ones((fwd["nsource"], 1)), np.eye(3)) + logger.info(" Converting to surface-based source orientations...") # Actually determine the source orientations pp = 0 - for s in fwd['src']: - if s['type'] in ['surf', 'discrete']: + for s in fwd["src"]: + if s["type"] in ["surf", "discrete"]: nn = _get_src_nn(s, use_cps) - stop = pp + 3 * s['nuse'] - fwd['source_nn'][pp:stop] = _normal_orth(nn).reshape(-1, 3) + stop = pp + 3 * s["nuse"] + fwd["source_nn"][pp:stop] = _normal_orth(nn).reshape(-1, 3) pp = stop del nn else: - pp += 3 * s['nuse'] + pp += 3 * s["nuse"] # Rotate the solution components as well if force_fixed: - fwd['source_nn'] = fwd['source_nn'][2::3, :] - fix_rot = _block_diag(fwd['source_nn'].T, 1) + fwd["source_nn"] = fwd["source_nn"][2::3, :] + fix_rot = _block_diag(fwd["source_nn"].T, 1) # newer versions of numpy require explicit casting here, so *= no # longer works - fwd['sol']['data'] = (fwd['_orig_sol'] * - fix_rot).astype('float32') - fwd['sol']['ncol'] = fwd['nsource'] - if fwd['sol_grad'] is not None: + fwd["sol"]["data"] = (fwd["_orig_sol"] * fix_rot).astype("float32") + fwd["sol"]["ncol"] = fwd["nsource"] + if fwd["sol_grad"] is not None: x = sparse.block_diag([fix_rot] * 3) - fwd['sol_grad']['data'] = fwd['_orig_sol_grad'] * x # dot prod - fwd['sol_grad']['ncol'] = 3 * fwd['nsource'] - fwd['source_ori'] = FIFF.FIFFV_MNE_FIXED_ORI - fwd['surf_ori'] = True + fwd["sol_grad"]["data"] = fwd["_orig_sol_grad"] * x # dot prod + fwd["sol_grad"]["ncol"] = 3 * fwd["nsource"] + fwd["source_ori"] = FIFF.FIFFV_MNE_FIXED_ORI + fwd["surf_ori"] = True else: - surf_rot = _block_diag(fwd['source_nn'].T, 3) - fwd['sol']['data'] = fwd['_orig_sol'] * surf_rot - fwd['sol']['ncol'] = 3 * fwd['nsource'] - if fwd['sol_grad'] is not None: + surf_rot = _block_diag(fwd["source_nn"].T, 3) + fwd["sol"]["data"] = fwd["_orig_sol"] * surf_rot + fwd["sol"]["ncol"] = 3 * fwd["nsource"] + if fwd["sol_grad"] is not None: x = sparse.block_diag([surf_rot] * 3) - fwd['sol_grad']['data'] = fwd['_orig_sol_grad'] * x # dot prod - fwd['sol_grad']['ncol'] = 9 * fwd['nsource'] - fwd['source_ori'] = FIFF.FIFFV_MNE_FREE_ORI - fwd['surf_ori'] = True + fwd["sol_grad"]["data"] = fwd["_orig_sol_grad"] * x # dot prod + fwd["sol_grad"]["ncol"] = 9 * fwd["nsource"] + fwd["source_ori"] = FIFF.FIFFV_MNE_FREE_ORI + fwd["surf_ori"] = True else: # Free, cartesian - logger.info(' Cartesian source orientations...') - fwd['source_nn'] = np.tile(np.eye(3), (fwd['nsource'], 1)) - 
fwd['sol']['data'] = fwd['_orig_sol'].copy() - fwd['sol']['ncol'] = 3 * fwd['nsource'] - if fwd['sol_grad'] is not None: - fwd['sol_grad']['data'] = fwd['_orig_sol_grad'].copy() - fwd['sol_grad']['ncol'] = 9 * fwd['nsource'] - fwd['source_ori'] = FIFF.FIFFV_MNE_FREE_ORI - fwd['surf_ori'] = False - - logger.info(' [done]') + logger.info(" Cartesian source orientations...") + fwd["source_nn"] = np.tile(np.eye(3), (fwd["nsource"], 1)) + fwd["sol"]["data"] = fwd["_orig_sol"].copy() + fwd["sol"]["ncol"] = 3 * fwd["nsource"] + if fwd["sol_grad"] is not None: + fwd["sol_grad"]["data"] = fwd["_orig_sol_grad"].copy() + fwd["sol_grad"]["ncol"] = 9 * fwd["nsource"] + fwd["source_ori"] = FIFF.FIFFV_MNE_FREE_ORI + fwd["surf_ori"] = False + + logger.info(" [done]") return fwd @@ -832,8 +887,9 @@ def write_forward_solution(fname, fwd, overwrite=False, verbose=None): surface-based, fixed orientation cannot be reverted after loading the forward solution with :func:`read_forward_solution`. """ - check_fname(fname, 'forward', ('-fwd.fif', '-fwd.fif.gz', - '_fwd.fif', '_fwd.fif.gz')) + check_fname( + fname, "forward", ("-fwd.fif", "-fwd.fif.gz", "_fwd.fif", "_fwd.fif.gz") + ) # check for file existence and expand `~` if present fname = _check_fname(fname, overwrite) @@ -849,10 +905,10 @@ def _write_forward_solution(fid, fwd): # start_block(fid, FIFF.FIFFB_MNE_ENV) write_id(fid, FIFF.FIFF_BLOCK_ID) - data = fwd['info'].get('working_dir', None) + data = fwd["info"].get("working_dir", None) if data is not None: write_string(fid, FIFF.FIFF_MNE_ENV_WORKING_DIR, data) - data = fwd['info'].get('command_line', None) + data = fwd["info"].get("command_line", None) if data is not None: write_string(fid, FIFF.FIFF_MNE_ENV_COMMAND_LINE, data) end_block(fid, FIFF.FIFFB_MNE_ENV) @@ -861,118 +917,138 @@ def _write_forward_solution(fid, fwd): # Information from the MRI file # start_block(fid, FIFF.FIFFB_MNE_PARENT_MRI_FILE) - write_string(fid, FIFF.FIFF_MNE_FILE_NAME, fwd['info']['mri_file']) - if fwd['info']['mri_id'] is not None: - write_id(fid, FIFF.FIFF_PARENT_FILE_ID, fwd['info']['mri_id']) + write_string(fid, FIFF.FIFF_MNE_FILE_NAME, fwd["info"]["mri_file"]) + if fwd["info"]["mri_id"] is not None: + write_id(fid, FIFF.FIFF_PARENT_FILE_ID, fwd["info"]["mri_id"]) # store the MRI to HEAD transform in MRI file - write_coord_trans(fid, fwd['info']['mri_head_t']) + write_coord_trans(fid, fwd["info"]["mri_head_t"]) end_block(fid, FIFF.FIFFB_MNE_PARENT_MRI_FILE) # write measurement info - write_forward_meas_info(fid, fwd['info']) + write_forward_meas_info(fid, fwd["info"]) # invert our original source space transform src = list() - for s in fwd['src']: + for s in fwd["src"]: s = deepcopy(s) try: # returns source space to original coordinate frame # usually MRI - s = transform_surface_to(s, fwd['mri_head_t']['from'], - fwd['mri_head_t']) + s = transform_surface_to(s, fwd["mri_head_t"]["from"], fwd["mri_head_t"]) except Exception as inst: - raise ValueError('Could not transform source space (%s)' % inst) + raise ValueError("Could not transform source space (%s)" % inst) src.append(s) # # Write the source spaces (again) # _write_source_spaces_to_fid(fid, src) - n_vert = sum([ss['nuse'] for ss in src]) - if fwd['_orig_source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI: + n_vert = sum([ss["nuse"] for ss in src]) + if fwd["_orig_source_ori"] == FIFF.FIFFV_MNE_FIXED_ORI: n_col = n_vert else: n_col = 3 * n_vert # Undo transformations - sol = fwd['_orig_sol'].copy() - if fwd['sol_grad'] is not None: - sol_grad = fwd['_orig_sol_grad'].copy() + sol 
= fwd["_orig_sol"].copy() + if fwd["sol_grad"] is not None: + sol_grad = fwd["_orig_sol_grad"].copy() else: sol_grad = None - if fwd['surf_ori'] is True: - if fwd['_orig_source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI: - warn('The forward solution, which is stored on disk now, is based ' - 'on a forward solution with fixed orientation. Please note ' - 'that the transformation to surface-based, fixed orientation ' - 'cannot be reverted after loading the forward solution with ' - 'read_forward_solution.', RuntimeWarning) + if fwd["surf_ori"] is True: + if fwd["_orig_source_ori"] == FIFF.FIFFV_MNE_FIXED_ORI: + warn( + "The forward solution, which is stored on disk now, is based " + "on a forward solution with fixed orientation. Please note " + "that the transformation to surface-based, fixed orientation " + "cannot be reverted after loading the forward solution with " + "read_forward_solution.", + RuntimeWarning, + ) else: - warn('This forward solution is based on a forward solution with ' - 'free orientation. The original forward solution is stored ' - 'on disk in X/Y/Z RAS coordinates. Any transformation ' - '(surface orientation or fixed orientation) will be ' - 'reverted. To reapply any transformation to the forward ' - 'operator please apply convert_forward_solution after ' - 'reading the forward solution with read_forward_solution.', - RuntimeWarning) + warn( + "This forward solution is based on a forward solution with " + "free orientation. The original forward solution is stored " + "on disk in X/Y/Z RAS coordinates. Any transformation " + "(surface orientation or fixed orientation) will be " + "reverted. To reapply any transformation to the forward " + "operator please apply convert_forward_solution after " + "reading the forward solution with read_forward_solution.", + RuntimeWarning, + ) # # MEG forward solution # - picks_meg = pick_types(fwd['info'], meg=True, eeg=False, ref_meg=False, - exclude=[]) - picks_eeg = pick_types(fwd['info'], meg=False, eeg=True, ref_meg=False, - exclude=[]) + picks_meg = pick_types(fwd["info"], meg=True, eeg=False, ref_meg=False, exclude=[]) + picks_eeg = pick_types(fwd["info"], meg=False, eeg=True, ref_meg=False, exclude=[]) n_meg = len(picks_meg) n_eeg = len(picks_eeg) - row_names_meg = [fwd['sol']['row_names'][p] for p in picks_meg] - row_names_eeg = [fwd['sol']['row_names'][p] for p in picks_eeg] + row_names_meg = [fwd["sol"]["row_names"][p] for p in picks_meg] + row_names_eeg = [fwd["sol"]["row_names"][p] for p in picks_eeg] if n_meg > 0: - meg_solution = dict(data=sol[picks_meg], nrow=n_meg, ncol=n_col, - row_names=row_names_meg, col_names=[]) + meg_solution = dict( + data=sol[picks_meg], + nrow=n_meg, + ncol=n_col, + row_names=row_names_meg, + col_names=[], + ) _transpose_named_matrix(meg_solution) start_block(fid, FIFF.FIFFB_MNE_FORWARD_SOLUTION) write_int(fid, FIFF.FIFF_MNE_INCLUDED_METHODS, FIFF.FIFFV_MNE_MEG) - write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, fwd['coord_frame']) - write_int(fid, FIFF.FIFF_MNE_SOURCE_ORIENTATION, - fwd['_orig_source_ori']) + write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, fwd["coord_frame"]) + write_int(fid, FIFF.FIFF_MNE_SOURCE_ORIENTATION, fwd["_orig_source_ori"]) write_int(fid, FIFF.FIFF_MNE_SOURCE_SPACE_NPOINTS, n_vert) write_int(fid, FIFF.FIFF_NCHAN, n_meg) write_named_matrix(fid, FIFF.FIFF_MNE_FORWARD_SOLUTION, meg_solution) if sol_grad is not None: - meg_solution_grad = dict(data=sol_grad[picks_meg], - nrow=n_meg, ncol=n_col * 3, - row_names=row_names_meg, col_names=[]) + meg_solution_grad = dict( + 
data=sol_grad[picks_meg], + nrow=n_meg, + ncol=n_col * 3, + row_names=row_names_meg, + col_names=[], + ) _transpose_named_matrix(meg_solution_grad) - write_named_matrix(fid, FIFF.FIFF_MNE_FORWARD_SOLUTION_GRAD, - meg_solution_grad) + write_named_matrix( + fid, FIFF.FIFF_MNE_FORWARD_SOLUTION_GRAD, meg_solution_grad + ) end_block(fid, FIFF.FIFFB_MNE_FORWARD_SOLUTION) # # EEG forward solution # if n_eeg > 0: - eeg_solution = dict(data=sol[picks_eeg], nrow=n_eeg, ncol=n_col, - row_names=row_names_eeg, col_names=[]) + eeg_solution = dict( + data=sol[picks_eeg], + nrow=n_eeg, + ncol=n_col, + row_names=row_names_eeg, + col_names=[], + ) _transpose_named_matrix(eeg_solution) start_block(fid, FIFF.FIFFB_MNE_FORWARD_SOLUTION) write_int(fid, FIFF.FIFF_MNE_INCLUDED_METHODS, FIFF.FIFFV_MNE_EEG) - write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, fwd['coord_frame']) - write_int(fid, FIFF.FIFF_MNE_SOURCE_ORIENTATION, - fwd['_orig_source_ori']) + write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, fwd["coord_frame"]) + write_int(fid, FIFF.FIFF_MNE_SOURCE_ORIENTATION, fwd["_orig_source_ori"]) write_int(fid, FIFF.FIFF_NCHAN, n_eeg) write_int(fid, FIFF.FIFF_MNE_SOURCE_SPACE_NPOINTS, n_vert) write_named_matrix(fid, FIFF.FIFF_MNE_FORWARD_SOLUTION, eeg_solution) if sol_grad is not None: - eeg_solution_grad = dict(data=sol_grad[picks_eeg], - nrow=n_eeg, ncol=n_col * 3, - row_names=row_names_eeg, col_names=[]) + eeg_solution_grad = dict( + data=sol_grad[picks_eeg], + nrow=n_eeg, + ncol=n_col * 3, + row_names=row_names_eeg, + col_names=[], + ) _transpose_named_matrix(eeg_solution_grad) - write_named_matrix(fid, FIFF.FIFF_MNE_FORWARD_SOLUTION_GRAD, - eeg_solution_grad) + write_named_matrix( + fid, FIFF.FIFF_MNE_FORWARD_SOLUTION_GRAD, eeg_solution_grad + ) end_block(fid, FIFF.FIFFB_MNE_FORWARD_SOLUTION) end_block(fid, FIFF.FIFFB_MNE) @@ -995,9 +1071,9 @@ def is_fixed_orient(forward, orig=False): Whether or not it is fixed orientation. 
""" if orig: # if we want to know about the original version - fixed_ori = (forward['_orig_source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI) + fixed_ori = forward["_orig_source_ori"] == FIFF.FIFFV_MNE_FIXED_ORI else: # most of the time we want to know about the current version - fixed_ori = (forward['source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI) + fixed_ori = forward["source_ori"] == FIFF.FIFFV_MNE_FIXED_ORI return fixed_ori @@ -1016,25 +1092,25 @@ def write_forward_meas_info(fid, info): # Information from the MEG file # start_block(fid, FIFF.FIFFB_MNE_PARENT_MEAS_FILE) - write_string(fid, FIFF.FIFF_MNE_FILE_NAME, info['meas_file']) - if info['meas_id'] is not None: - write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info['meas_id']) + write_string(fid, FIFF.FIFF_MNE_FILE_NAME, info["meas_file"]) + if info["meas_id"] is not None: + write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info["meas_id"]) # get transformation from CTF and DEVICE to HEAD coordinate frame - meg_head_t = info.get('dev_head_t', info.get('ctf_head_t')) + meg_head_t = info.get("dev_head_t", info.get("ctf_head_t")) if meg_head_t is None: fid.close() - raise ValueError('Head<-->sensor transform not found') + raise ValueError("Head<-->sensor transform not found") write_coord_trans(fid, meg_head_t) ch_names_mapping = dict() - if 'chs' in info: + if "chs" in info: # Channel information - ch_names_mapping = _make_ch_names_mapping(info['chs']) - write_int(fid, FIFF.FIFF_NCHAN, len(info['chs'])) - _write_ch_infos(fid, info['chs'], False, ch_names_mapping) - if 'bads' in info and len(info['bads']) > 0: + ch_names_mapping = _make_ch_names_mapping(info["chs"]) + write_int(fid, FIFF.FIFF_NCHAN, len(info["chs"])) + _write_ch_infos(fid, info["chs"], False, ch_names_mapping) + if "bads" in info and len(info["bads"]) > 0: # Bad channels - _write_bad_channels(fid, info['bads'], ch_names_mapping) + _write_bad_channels(fid, info["bads"], ch_names_mapping) end_block(fid, FIFF.FIFFB_MNE_PARENT_MEAS_FILE) @@ -1042,82 +1118,89 @@ def write_forward_meas_info(fid, info): def _select_orient_forward(forward, info, noise_cov=None, copy=True): """Prepare forward solution for inverse solvers.""" # fwd['sol']['row_names'] may be different order from fwd['info']['chs'] - fwd_sol_ch_names = forward['sol']['row_names'] + fwd_sol_ch_names = forward["sol"]["row_names"] all_ch_names = set(fwd_sol_ch_names) - all_bads = set(info['bads']) + all_bads = set(info["bads"]) if noise_cov is not None: - all_ch_names &= set(noise_cov['names']) - all_bads |= set(noise_cov['bads']) + all_ch_names &= set(noise_cov["names"]) + all_bads |= set(noise_cov["bads"]) else: - noise_cov = dict(bads=info['bads']) - ch_names = [c['ch_name'] for c in info['chs'] - if c['ch_name'] not in all_bads and - c['ch_name'] in all_ch_names] - - if not len(info['bads']) == len(noise_cov['bads']) or \ - not all(b in noise_cov['bads'] for b in info['bads']): - logger.info('info["bads"] and noise_cov["bads"] do not match, ' - 'excluding bad channels from both') + noise_cov = dict(bads=info["bads"]) + ch_names = [ + c["ch_name"] + for c in info["chs"] + if c["ch_name"] not in all_bads and c["ch_name"] in all_ch_names + ] + + if not len(info["bads"]) == len(noise_cov["bads"]) or not all( + b in noise_cov["bads"] for b in info["bads"] + ): + logger.info( + 'info["bads"] and noise_cov["bads"] do not match, ' + "excluding bad channels from both" + ) # check the compensation grade - _check_compensation_grade(forward['info'], info, 'forward') + _check_compensation_grade(forward["info"], info, "forward") n_chan = len(ch_names) 
logger.info("Computing inverse operator with %d channels." % n_chan) - forward = pick_channels_forward(forward, ch_names, ordered=True, - copy=copy) - info_idx = [info['ch_names'].index(name) for name in ch_names] + forward = pick_channels_forward(forward, ch_names, ordered=True, copy=copy) + info_idx = [info["ch_names"].index(name) for name in ch_names] info_picked = pick_info(info, info_idx) - forward['info']._check_consistency() + forward["info"]._check_consistency() info_picked._check_consistency() return forward, info_picked -def _triage_loose(src, loose, fixed='auto'): - _validate_type(loose, (str, dict, 'numeric'), 'loose') - _validate_type(fixed, (str, bool), 'fixed') +def _triage_loose(src, loose, fixed="auto"): + _validate_type(loose, (str, dict, "numeric"), "loose") + _validate_type(fixed, (str, bool), "fixed") orig_loose = loose if isinstance(loose, str): - _check_option('loose', loose, ('auto',)) + _check_option("loose", loose, ("auto",)) if fixed is True: - loose = 0. + loose = 0.0 else: # False or auto - loose = 0.2 if src.kind == 'surface' else 1. - src_types = set(_src_kind_dict[s['type']] for s in src) + loose = 0.2 if src.kind == "surface" else 1.0 + src_types = set(_src_kind_dict[s["type"]] for s in src) if not isinstance(loose, dict): loose = float(loose) loose = {key: loose for key in src_types} loose_keys = set(loose.keys()) if loose_keys != src_types: raise ValueError( - f'loose, if dict, must have keys {sorted(src_types)} to match the ' - f'source space, got {sorted(loose_keys)}') + f"loose, if dict, must have keys {sorted(src_types)} to match the " + f"source space, got {sorted(loose_keys)}" + ) # if fixed is auto it can be ignored, if it's False it can be ignored, # only really need to care about fixed=True if fixed is True: - if not all(v == 0. for v in loose.values()): + if not all(v == 0.0 for v in loose.values()): raise ValueError( 'When using fixed=True, loose must be 0. or "auto", ' - f'got {orig_loose}') + f"got {orig_loose}" + ) elif fixed is False: - if any(v == 0. for v in loose.values()): + if any(v == 0.0 for v in loose.values()): raise ValueError( - 'If loose==0., then fixed must be True or "auto", got False') + 'If loose==0., then fixed must be True or "auto", got False' + ) del fixed for key, this_loose in loose.items(): - if key not in ('surface', 'discrete') and this_loose != 1: + if key not in ("surface", "discrete") and this_loose != 1: raise ValueError( 'loose parameter has to be 1 or "auto" for non-surface/' - f'discrete source spaces, got loose["{key}"] = {this_loose}') + f'discrete source spaces, got loose["{key}"] = {this_loose}' + ) if not 0 <= this_loose <= 1: - raise ValueError( - f'loose ({key}) must be between 0 and 1, got {this_loose}') + raise ValueError(f"loose ({key}) must be between 0 and 1, got {this_loose}") return loose @verbose -def compute_orient_prior(forward, loose='auto', verbose=None): +def compute_orient_prior(forward, loose="auto", verbose=None): """Compute orientation prior. Parameters @@ -1136,40 +1219,46 @@ def compute_orient_prior(forward, loose='auto', verbose=None): -------- compute_depth_prior """ - _validate_type(forward, Forward, 'forward') - n_sources = forward['sol']['data'].shape[1] + _validate_type(forward, Forward, "forward") + n_sources = forward["sol"]["data"].shape[1] - loose = _triage_loose(forward['src'], loose) + loose = _triage_loose(forward["src"], loose) orient_prior = np.ones(n_sources, dtype=np.float64) if is_fixed_orient(forward): - if any(v > 0. 
for v in loose.values()): - raise ValueError('loose must be 0. with forward operator ' - 'with fixed orientation, got %s' % (loose,)) + if any(v > 0.0 for v in loose.values()): + raise ValueError( + "loose must be 0. with forward operator " + "with fixed orientation, got %s" % (loose,) + ) return orient_prior - if all(v == 1. for v in loose.values()): + if all(v == 1.0 for v in loose.values()): return orient_prior # We actually need non-unity prior, compute it for each source space # separately - if not forward['surf_ori']: - raise ValueError('Forward operator is not oriented in surface ' - 'coordinates. loose parameter should be 1. ' - 'not %s.' % (loose,)) + if not forward["surf_ori"]: + raise ValueError( + "Forward operator is not oriented in surface " + "coordinates. loose parameter should be 1. " + "not %s." % (loose,) + ) start = 0 logged = dict() - for s in forward['src']: - this_type = _src_kind_dict[s['type']] + for s in forward["src"]: + this_type = _src_kind_dict[s["type"]] use_loose = loose[this_type] if not logged.get(this_type): - if use_loose == 1.: - name = 'free' + if use_loose == 1.0: + name = "free" else: - name = 'fixed' if use_loose == 0. else 'loose' - logger.info(f'Applying {name.ljust(5)} dipole orientations to ' - f'{this_type.ljust(7)} source spaces: {use_loose}') + name = "fixed" if use_loose == 0.0 else "loose" + logger.info( + f"Applying {name.ljust(5)} dipole orientations to " + f"{this_type.ljust(7)} source spaces: {use_loose}" + ) logged[this_type] = True - stop = start + 3 * s['nuse'] + stop = start + 3 * s["nuse"] orient_prior[start:stop:3] *= use_loose - orient_prior[start + 1:stop:3] *= use_loose + orient_prior[start + 1 : stop : 3] *= use_loose start = stop return orient_prior @@ -1177,27 +1266,38 @@ def compute_orient_prior(forward, loose='auto', verbose=None): def _restrict_gain_matrix(G, info): """Restrict gain matrix entries for optimal depth weighting.""" # Figure out which ones have been used - if len(info['chs']) != G.shape[0]: - raise ValueError('G.shape[0] (%d) and length of info["chs"] (%d) ' - 'do not match' % (G.shape[0], len(info['chs']))) + if len(info["chs"]) != G.shape[0]: + raise ValueError( + 'G.shape[0] (%d) and length of info["chs"] (%d) ' + "do not match" % (G.shape[0], len(info["chs"])) + ) for meg, eeg, kind in ( - ('grad', False, 'planar'), - ('mag', False, 'magnetometer or axial gradiometer'), - (False, True, 'EEG')): + ("grad", False, "planar"), + ("mag", False, "magnetometer or axial gradiometer"), + (False, True, "EEG"), + ): sel = pick_types(info, meg=meg, eeg=eeg, ref_meg=False, exclude=[]) if len(sel) > 0: - logger.info(' %d %s channels' % (len(sel), kind)) + logger.info(" %d %s channels" % (len(sel), kind)) break else: - warn('Could not find MEG or EEG channels to limit depth channels') + warn("Could not find MEG or EEG channels to limit depth channels") sel = slice(None) return G[sel] @verbose -def compute_depth_prior(forward, info, exp=0.8, limit=10.0, - limit_depth_chs=False, combine_xyz='spectral', - noise_cov=None, rank=None, verbose=None): +def compute_depth_prior( + forward, + info, + exp=0.8, + limit=10.0, + limit_depth_chs=False, + combine_xyz="spectral", + noise_cov=None, + rank=None, + verbose=None, +): """Compute depth prior for depth weighting. Parameters @@ -1278,39 +1378,44 @@ def compute_depth_prior(forward, info, exp=0.8, limit=10.0, SI units (such as EEG being orders of magnitude larger than MEG). 
""" from ..cov import Covariance, compute_whitener - _validate_type(forward, Forward, 'forward') - patch_areas = forward.get('patch_areas', None) + + _validate_type(forward, Forward, "forward") + patch_areas = forward.get("patch_areas", None) is_fixed_ori = is_fixed_orient(forward) - G = forward['sol']['data'] - logger.info('Creating the depth weighting matrix...') - _validate_type(noise_cov, (Covariance, None), 'noise_cov', - 'Covariance or None') - _validate_type(limit_depth_chs, (str, bool), 'limit_depth_chs') + G = forward["sol"]["data"] + logger.info("Creating the depth weighting matrix...") + _validate_type(noise_cov, (Covariance, None), "noise_cov", "Covariance or None") + _validate_type(limit_depth_chs, (str, bool), "limit_depth_chs") if isinstance(limit_depth_chs, str): - if limit_depth_chs != 'whiten': - raise ValueError('limit_depth_chs, if str, must be "whiten", got ' - '%s' % (limit_depth_chs,)) + if limit_depth_chs != "whiten": + raise ValueError( + 'limit_depth_chs, if str, must be "whiten", got ' + "%s" % (limit_depth_chs,) + ) if not isinstance(noise_cov, Covariance): - raise ValueError('With limit_depth_chs="whiten", noise_cov must be' - ' a Covariance, got %s' % (type(noise_cov),)) + raise ValueError( + 'With limit_depth_chs="whiten", noise_cov must be' + " a Covariance, got %s" % (type(noise_cov),) + ) if combine_xyz is not False: # private / expert option - _check_option('combine_xyz', combine_xyz, ('fro', 'spectral')) + _check_option("combine_xyz", combine_xyz, ("fro", "spectral")) # If possible, pick best depth-weighting channels if limit_depth_chs is True: G = _restrict_gain_matrix(G, info) - elif limit_depth_chs == 'whiten': - whitener, _ = compute_whitener(noise_cov, info, pca=True, rank=rank, - verbose=False) + elif limit_depth_chs == "whiten": + whitener, _ = compute_whitener( + noise_cov, info, pca=True, rank=rank, verbose=False + ) G = np.dot(whitener, G) # Compute the gain matrix - if is_fixed_ori or combine_xyz in ('fro', False): - d = np.sum(G ** 2, axis=0) + if is_fixed_ori or combine_xyz in ("fro", False): + d = np.sum(G**2, axis=0) if not (is_fixed_ori or combine_xyz is False): d = d.reshape(-1, 3).sum(axis=1) # Spherical leadfield can be zero at the center - d[d == 0.] = np.min(d[d != 0.]) + d[d == 0.0] = np.min(d[d != 0.0]) else: # 'spectral' # n_pos = G.shape[1] // 3 # The following is equivalent to this, but 4-10x faster @@ -1320,22 +1425,22 @@ def compute_depth_prior(forward, info, exp=0.8, limit=10.0, # x = np.dot(Gk.T, Gk) # d[k] = linalg.svdvals(x)[0] G.shape = (G.shape[0], -1, 3) - d = np.linalg.norm(np.einsum('svj,svk->vjk', G, G), # vector dot prods - ord=2, axis=(1, 2)) # ord=2 spectral (largest s.v.) + d = np.linalg.norm( + np.einsum("svj,svk->vjk", G, G), ord=2, axis=(1, 2) # vector dot prods + ) # ord=2 spectral (largest s.v.) 
G.shape = (G.shape[0], -1) # XXX Currently the fwd solns never have "patch_areas" defined if patch_areas is not None: if not is_fixed_ori and combine_xyz is False: patch_areas = np.repeat(patch_areas, 3) - d /= patch_areas ** 2 - logger.info(' Patch areas taken into account in the depth ' - 'weighting') + d /= patch_areas**2 + logger.info(" Patch areas taken into account in the depth " "weighting") w = 1.0 / d if limit is not None: ws = np.sort(w) - weight_limit = limit ** 2 + weight_limit = limit**2 if limit_depth_chs is False: # match old mne-python behavior # we used to do ind = np.argmin(ws), but this is 0 by sort above @@ -1350,13 +1455,13 @@ def compute_depth_prior(forward, info, exp=0.8, limit=10.0, limit = ws[ind] n_limit = ind - logger.info(' limit = %d/%d = %f' - % (n_limit + 1, len(d), - np.sqrt(limit / ws[0]))) + logger.info( + " limit = %d/%d = %f" % (n_limit + 1, len(d), np.sqrt(limit / ws[0])) + ) scale = 1.0 / limit - logger.info(' scale = %g exp = %g' % (scale, exp)) + logger.info(" scale = %g exp = %g" % (scale, exp)) w = np.minimum(w / limit, 1) - depth_prior = w ** exp + depth_prior = w**exp if not (is_fixed_ori or combine_xyz is False): depth_prior = np.repeat(depth_prior, 3) @@ -1364,8 +1469,9 @@ def compute_depth_prior(forward, info, exp=0.8, limit=10.0, return depth_prior -def _stc_src_sel(src, stc, on_missing='raise', - extra=', likely due to forward calculations'): +def _stc_src_sel( + src, stc, on_missing="raise", extra=", likely due to forward calculations" +): """Select the vertex indices of a source space using a source estimate.""" if isinstance(stc, list): vertices = stc @@ -1374,14 +1480,16 @@ def _stc_src_sel(src, stc, on_missing='raise', vertices = stc.vertices del stc if not len(src) == len(vertices): - raise RuntimeError('Mismatch between number of source spaces (%s) and ' - 'STC vertices (%s)' % (len(src), len(vertices))) + raise RuntimeError( + "Mismatch between number of source spaces (%s) and " + "STC vertices (%s)" % (len(src), len(vertices)) + ) src_sels, stc_sels, out_vertices = [], [], [] src_offset = stc_offset = 0 for s, v in zip(src, vertices): - joint_sel = np.intersect1d(s['vertno'], v) - src_sels.append(np.searchsorted(s['vertno'], joint_sel) + src_offset) - src_offset += len(s['vertno']) + joint_sel = np.intersect1d(s["vertno"], v) + src_sels.append(np.searchsorted(s["vertno"], joint_sel) + src_offset) + src_offset += len(s["vertno"]) idx = np.searchsorted(v, joint_sel) stc_sels.append(idx + stc_offset) stc_offset += len(v) @@ -1393,20 +1501,21 @@ def _stc_src_sel(src, stc, on_missing='raise', n_stc = sum(len(v) for v in vertices) n_joint = len(src_sel) if n_joint != n_stc: - msg = ('Only %i of %i SourceEstimate %s found in ' - 'source space%s' - % (n_joint, n_stc, 'vertex' if n_stc == 1 else 'vertices', - extra)) + msg = "Only %i of %i SourceEstimate %s found in " "source space%s" % ( + n_joint, + n_stc, + "vertex" if n_stc == 1 else "vertices", + extra, + ) _on_missing(on_missing, msg) return src_sel, stc_sel, out_vertices def _fill_measurement_info(info, fwd, sfreq, data): """Fill the measurement info of a Raw or Evoked object.""" - sel = pick_channels( - info['ch_names'], fwd['sol']['row_names'], ordered=False) + sel = pick_channels(info["ch_names"], fwd["sol"]["row_names"], ordered=False) info = pick_info(info, sel) - info['bads'] = [] + info["bads"] = [] now = time() sec = np.floor(now) @@ -1414,41 +1523,49 @@ def _fill_measurement_info(info, fwd, sfreq, data): # this is probably correct based on what's done in meas_info.py... 
with info._unlock(check_after=True): - info.update(meas_id=fwd['info']['meas_id'], file_id=info['meas_id'], - meas_date=_stamp_to_dt((int(sec), int(usec))), - highpass=0., lowpass=sfreq / 2., sfreq=sfreq, projs=[]) + info.update( + meas_id=fwd["info"]["meas_id"], + file_id=info["meas_id"], + meas_date=_stamp_to_dt((int(sec), int(usec))), + highpass=0.0, + lowpass=sfreq / 2.0, + sfreq=sfreq, + projs=[], + ) # reorder data (which is in fwd order) to match that of info - order = [fwd['sol']['row_names'].index(name) for name in info['ch_names']] + order = [fwd["sol"]["row_names"].index(name) for name in info["ch_names"]] data = data[order] return info, data @verbose -def _apply_forward(fwd, stc, start=None, stop=None, on_missing='raise', - use_cps=True, verbose=None): +def _apply_forward( + fwd, stc, start=None, stop=None, on_missing="raise", use_cps=True, verbose=None +): """Apply forward model and return data, times, ch_names.""" - _validate_type(stc, _BaseSourceEstimate, 'stc', 'SourceEstimate') - _validate_type(fwd, Forward, 'fwd') + _validate_type(stc, _BaseSourceEstimate, "stc", "SourceEstimate") + _validate_type(fwd, Forward, "fwd") if isinstance(stc, _BaseVectorSourceEstimate): vector = True fwd = convert_forward_solution(fwd, force_fixed=False, surf_ori=False) else: vector = False if not is_fixed_orient(fwd): - fwd = convert_forward_solution(fwd, force_fixed=True, - use_cps=use_cps) + fwd = convert_forward_solution(fwd, force_fixed=True, use_cps=use_cps) if np.all(stc.data > 0): - warn('Source estimate only contains currents with positive values. ' - 'Use pick_ori="normal" when computing the inverse to compute ' - 'currents not current magnitudes.') + warn( + "Source estimate only contains currents with positive values. " + 'Use pick_ori="normal" when computing the inverse to compute ' + "currents not current magnitudes." + ) _check_stc_units(stc) - src_sel, stc_sel, _ = _stc_src_sel(fwd['src'], stc, on_missing=on_missing) - gain = fwd['sol']['data'] + src_sel, stc_sel, _ = _stc_src_sel(fwd["src"], stc, on_missing=on_missing) + gain = fwd["sol"]["data"] stc_sel = slice(None) if len(stc_sel) == len(stc.data) else stc_sel times = stc.times[start:stop].copy() stc_data = stc.data[stc_sel, ..., start:stop].reshape(-1, len(times)) @@ -1458,15 +1575,23 @@ def _apply_forward(fwd, stc, start=None, stop=None, on_missing='raise', gain = gain[:, src_sel].reshape(len(gain), -1) # save some memory if possible - logger.info('Projecting source estimate to sensor space...') + logger.info("Projecting source estimate to sensor space...") data = np.dot(gain, stc_data) - logger.info('[done]') + logger.info("[done]") return data, times @verbose -def apply_forward(fwd, stc, info, start=None, stop=None, use_cps=True, - on_missing='raise', verbose=None): +def apply_forward( + fwd, + stc, + info, + start=None, + stop=None, + use_cps=True, + on_missing="raise", + verbose=None, +): """Project source space currents to sensor space using a forward operator. The sensor space data is computed for all channels present in fwd. Use @@ -1507,19 +1632,22 @@ def apply_forward(fwd, stc, info, start=None, stop=None, use_cps=True, -------- apply_forward_raw: Compute sensor space data and return a Raw object. 
""" - _validate_type(info, Info, 'info') - _validate_type(fwd, Forward, 'forward') + _validate_type(info, Info, "info") + _validate_type(fwd, Forward, "forward") info._check_consistency() # make sure evoked_template contains all channels in fwd - for ch_name in fwd['sol']['row_names']: - if ch_name not in info['ch_names']: - raise ValueError('Channel %s of forward operator not present in ' - 'evoked_template.' % ch_name) + for ch_name in fwd["sol"]["row_names"]: + if ch_name not in info["ch_names"]: + raise ValueError( + "Channel %s of forward operator not present in " + "evoked_template." % ch_name + ) # project the source estimate to the sensor space - data, times = _apply_forward(fwd, stc, start, stop, on_missing=on_missing, - use_cps=use_cps) + data, times = _apply_forward( + fwd, stc, start, stop, on_missing=on_missing, use_cps=use_cps + ) # fill the measurement info sfreq = float(1.0 / stc.tstep) @@ -1534,8 +1662,16 @@ def apply_forward(fwd, stc, info, start=None, stop=None, use_cps=True, @verbose -def apply_forward_raw(fwd, stc, info, start=None, stop=None, - on_missing='raise', use_cps=True, verbose=None): +def apply_forward_raw( + fwd, + stc, + info, + start=None, + stop=None, + on_missing="raise", + use_cps=True, + verbose=None, +): """Project source space currents to sensor space using a forward operator. The sensor space data is computed for all channels present in fwd. Use @@ -1577,19 +1713,21 @@ def apply_forward_raw(fwd, stc, info, start=None, stop=None, apply_forward: Compute sensor space data and return an Evoked object. """ # make sure info contains all channels in fwd - for ch_name in fwd['sol']['row_names']: - if ch_name not in info['ch_names']: - raise ValueError('Channel %s of forward operator not present in ' - 'info.' % ch_name) + for ch_name in fwd["sol"]["row_names"]: + if ch_name not in info["ch_names"]: + raise ValueError( + "Channel %s of forward operator not present in " "info." % ch_name + ) # project the source estimate to the sensor space - data, times = _apply_forward(fwd, stc, start, stop, on_missing=on_missing, - use_cps=use_cps) + data, times = _apply_forward( + fwd, stc, start, stop, on_missing=on_missing, use_cps=use_cps + ) sfreq = 1.0 / stc.tstep info, data = _fill_measurement_info(info, fwd, sfreq, data) with info._unlock(): - info['projs'] = [] + info["projs"] = [] # store sensor data in Raw object using the info raw = RawArray(data, info, first_samp=int(np.round(times[0] * sfreq))) raw._projector = None @@ -1597,7 +1735,7 @@ def apply_forward_raw(fwd, stc, info, start=None, stop=None, @fill_doc -def restrict_forward_to_stc(fwd, stc, on_missing='ignore'): +def restrict_forward_to_stc(fwd, stc, on_missing="ignore"): """Restrict forward operator to active sources in a source estimate. 
Parameters @@ -1620,9 +1758,9 @@ def restrict_forward_to_stc(fwd, stc, on_missing='ignore'): -------- restrict_forward_to_label """ - _validate_type(on_missing, str, 'on_missing') - _check_option('on_missing', on_missing, ('ignore', 'warn', 'raise')) - src_sel, _, vertices = _stc_src_sel(fwd['src'], stc, on_missing=on_missing) + _validate_type(on_missing, str, "on_missing") + _check_option("on_missing", on_missing, ("ignore", "warn", "raise")) + src_sel, _, vertices = _stc_src_sel(fwd["src"], stc, on_missing=on_missing) del stc return _restrict_forward_to_src_sel(fwd, src_sel) @@ -1630,46 +1768,47 @@ def restrict_forward_to_stc(fwd, stc, on_missing='ignore'): def _restrict_forward_to_src_sel(fwd, src_sel): fwd_out = deepcopy(fwd) # figure out the vertno we are keeping - idx_sel = np.concatenate([[[si] * len(s['vertno']), s['vertno']] - for si, s in enumerate(fwd['src'])], axis=-1) + idx_sel = np.concatenate( + [[[si] * len(s["vertno"]), s["vertno"]] for si, s in enumerate(fwd["src"])], + axis=-1, + ) assert idx_sel.ndim == 2 and idx_sel.shape[0] == 2 - assert idx_sel.shape[1] == fwd['nsource'] + assert idx_sel.shape[1] == fwd["nsource"] idx_sel = idx_sel[:, src_sel] - fwd_out['source_rr'] = fwd['source_rr'][src_sel] - fwd_out['nsource'] = len(src_sel) + fwd_out["source_rr"] = fwd["source_rr"][src_sel] + fwd_out["nsource"] = len(src_sel) if is_fixed_orient(fwd): idx = src_sel - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (3 * src_sel[:, None] + np.arange(3)).ravel() else: idx = (3 * src_sel[:, None] + np.arange(3)).ravel() - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (9 * src_sel[:, None] + np.arange(9)).ravel() - fwd_out['source_nn'] = fwd['source_nn'][idx] - fwd_out['sol']['data'] = fwd['sol']['data'][:, idx] - if fwd['sol_grad'] is not None: - fwd_out['sol_grad']['data'] = fwd['sol_grad']['data'][:, idx_grad] - fwd_out['sol']['ncol'] = len(idx) + fwd_out["source_nn"] = fwd["source_nn"][idx] + fwd_out["sol"]["data"] = fwd["sol"]["data"][:, idx] + if fwd["sol_grad"] is not None: + fwd_out["sol_grad"]["data"] = fwd["sol_grad"]["data"][:, idx_grad] + fwd_out["sol"]["ncol"] = len(idx) if is_fixed_orient(fwd, orig=True): idx = src_sel - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (3 * src_sel[:, None] + np.arange(3)).ravel() else: idx = (3 * src_sel[:, None] + np.arange(3)).ravel() - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (9 * src_sel[:, None] + np.arange(9)).ravel() - fwd_out['_orig_sol'] = fwd['_orig_sol'][:, idx] - if fwd['sol_grad'] is not None: - fwd_out['_orig_sol_grad'] = fwd['_orig_sol_grad'][:, idx_grad] + fwd_out["_orig_sol"] = fwd["_orig_sol"][:, idx] + if fwd["sol_grad"] is not None: + fwd_out["_orig_sol_grad"] = fwd["_orig_sol_grad"][:, idx_grad] - vertices = [idx_sel[1][idx_sel[0] == si] - for si in range(len(fwd_out['src']))] - _set_source_space_vertices(fwd_out['src'], vertices) + vertices = [idx_sel[1][idx_sel[0] == si] for si in range(len(fwd_out["src"]))] + _set_source_space_vertices(fwd_out["src"], vertices) return fwd_out @@ -1701,92 +1840,106 @@ def restrict_forward_to_label(fwd, labels): # Get vertices separately of each hemisphere from all label for label in labels: _validate_type(label, Label, "label", "Label or list") - i = 0 if label.hemi == 'lh' else 1 + i = 0 if label.hemi == "lh" else 1 vertices[i] = np.append(vertices[i], label.vertices) # Remove duplicates and sort vertices = [np.unique(vert_hemi) for vert_hemi in vertices] 
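
# A hedged usage sketch for `restrict_forward_to_label`; the forward file and
# the two label files are placeholders, not data referenced by this patch.
import mne
from mne.forward import restrict_forward_to_label

fwd = mne.read_forward_solution("sample-fwd.fif")  # hypothetical file
labels = [
    mne.read_label("lh.BA1.label"),  # hypothetical left-hemisphere label
    mne.read_label("rh.BA45.label"),  # hypothetical right-hemisphere label
]
fwd_small = restrict_forward_to_label(fwd, labels)
# Only sources inside the labels remain, so the gain matrix loses columns:
assert fwd_small["nsource"] <= fwd["nsource"]
assert fwd_small["sol"]["data"].shape[1] == fwd_small["sol"]["ncol"]
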
fwd_out = deepcopy(fwd) - fwd_out['source_rr'] = np.zeros((0, 3)) - fwd_out['nsource'] = 0 - fwd_out['source_nn'] = np.zeros((0, 3)) - fwd_out['sol']['data'] = np.zeros((fwd['sol']['data'].shape[0], 0)) - fwd_out['_orig_sol'] = np.zeros((fwd['_orig_sol'].shape[0], 0)) - if fwd['sol_grad'] is not None: - fwd_out['sol_grad']['data'] = np.zeros( - (fwd['sol_grad']['data'].shape[0], 0)) - fwd_out['_orig_sol_grad'] = np.zeros( - (fwd['_orig_sol_grad'].shape[0], 0)) - fwd_out['sol']['ncol'] = 0 - nuse_lh = fwd['src'][0]['nuse'] + fwd_out["source_rr"] = np.zeros((0, 3)) + fwd_out["nsource"] = 0 + fwd_out["source_nn"] = np.zeros((0, 3)) + fwd_out["sol"]["data"] = np.zeros((fwd["sol"]["data"].shape[0], 0)) + fwd_out["_orig_sol"] = np.zeros((fwd["_orig_sol"].shape[0], 0)) + if fwd["sol_grad"] is not None: + fwd_out["sol_grad"]["data"] = np.zeros((fwd["sol_grad"]["data"].shape[0], 0)) + fwd_out["_orig_sol_grad"] = np.zeros((fwd["_orig_sol_grad"].shape[0], 0)) + fwd_out["sol"]["ncol"] = 0 + nuse_lh = fwd["src"][0]["nuse"] for i in range(2): - fwd_out['src'][i]['vertno'] = np.array([], int) - fwd_out['src'][i]['nuse'] = 0 - fwd_out['src'][i]['inuse'] = fwd['src'][i]['inuse'].copy() - fwd_out['src'][i]['inuse'].fill(0) - fwd_out['src'][i]['use_tris'] = np.array([[]], int) - fwd_out['src'][i]['nuse_tri'] = np.array([0]) + fwd_out["src"][i]["vertno"] = np.array([], int) + fwd_out["src"][i]["nuse"] = 0 + fwd_out["src"][i]["inuse"] = fwd["src"][i]["inuse"].copy() + fwd_out["src"][i]["inuse"].fill(0) + fwd_out["src"][i]["use_tris"] = np.array([[]], int) + fwd_out["src"][i]["nuse_tri"] = np.array([0]) # src_sel is idx to cols in fwd that are in any label per hemi - src_sel = np.intersect1d(fwd['src'][i]['vertno'], vertices[i]) - src_sel = np.searchsorted(fwd['src'][i]['vertno'], src_sel) + src_sel = np.intersect1d(fwd["src"][i]["vertno"], vertices[i]) + src_sel = np.searchsorted(fwd["src"][i]["vertno"], src_sel) # Reconstruct each src - vertno = fwd['src'][i]['vertno'][src_sel] - fwd_out['src'][i]['inuse'][vertno] = 1 - fwd_out['src'][i]['nuse'] += len(vertno) - fwd_out['src'][i]['vertno'] = np.where(fwd_out['src'][i]['inuse'])[0] + vertno = fwd["src"][i]["vertno"][src_sel] + fwd_out["src"][i]["inuse"][vertno] = 1 + fwd_out["src"][i]["nuse"] += len(vertno) + fwd_out["src"][i]["vertno"] = np.where(fwd_out["src"][i]["inuse"])[0] # Reconstruct part of fwd that is not sol data src_sel += i * nuse_lh # Add column shift to right hemi - fwd_out['source_rr'] = np.vstack([fwd_out['source_rr'], - fwd['source_rr'][src_sel]]) - fwd_out['nsource'] += len(src_sel) + fwd_out["source_rr"] = np.vstack( + [fwd_out["source_rr"], fwd["source_rr"][src_sel]] + ) + fwd_out["nsource"] += len(src_sel) if is_fixed_orient(fwd): idx = src_sel - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (3 * src_sel[:, None] + np.arange(3)).ravel() else: idx = (3 * src_sel[:, None] + np.arange(3)).ravel() - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (9 * src_sel[:, None] + np.arange(9)).ravel() - fwd_out['source_nn'] = np.vstack( - [fwd_out['source_nn'], fwd['source_nn'][idx]]) - fwd_out['sol']['data'] = np.hstack( - [fwd_out['sol']['data'], fwd['sol']['data'][:, idx]]) - if fwd['sol_grad'] is not None: - fwd_out['sol_grad']['data'] = np.hstack( - [fwd_out['sol_grad']['data'], - fwd['sol_rad']['data'][:, idx_grad]]) - fwd_out['sol']['ncol'] += len(idx) + fwd_out["source_nn"] = np.vstack([fwd_out["source_nn"], fwd["source_nn"][idx]]) + fwd_out["sol"]["data"] = np.hstack( + 
[fwd_out["sol"]["data"], fwd["sol"]["data"][:, idx]] + ) + if fwd["sol_grad"] is not None: + fwd_out["sol_grad"]["data"] = np.hstack( + [fwd_out["sol_grad"]["data"], fwd["sol_rad"]["data"][:, idx_grad]] + ) + fwd_out["sol"]["ncol"] += len(idx) if is_fixed_orient(fwd, orig=True): idx = src_sel - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (3 * src_sel[:, None] + np.arange(3)).ravel() else: idx = (3 * src_sel[:, None] + np.arange(3)).ravel() - if fwd['sol_grad'] is not None: + if fwd["sol_grad"] is not None: idx_grad = (9 * src_sel[:, None] + np.arange(9)).ravel() - fwd_out['_orig_sol'] = np.hstack( - [fwd_out['_orig_sol'], fwd['_orig_sol'][:, idx]]) - if fwd['sol_grad'] is not None: - fwd_out['_orig_sol_grad'] = np.hstack( - [fwd_out['_orig_sol_grad'], - fwd['_orig_sol_grad'][:, idx_grad]]) + fwd_out["_orig_sol"] = np.hstack( + [fwd_out["_orig_sol"], fwd["_orig_sol"][:, idx]] + ) + if fwd["sol_grad"] is not None: + fwd_out["_orig_sol_grad"] = np.hstack( + [fwd_out["_orig_sol_grad"], fwd["_orig_sol_grad"][:, idx_grad]] + ) return fwd_out -def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, - mindist=None, bem=None, mri=None, trans=None, - eeg=True, meg=True, fixed=False, grad=False, - mricoord=False, overwrite=False, subjects_dir=None, - verbose=None): +def _do_forward_solution( + subject, + meas, + fname=None, + src=None, + spacing=None, + mindist=None, + bem=None, + mri=None, + trans=None, + eeg=True, + meg=True, + fixed=False, + grad=False, + mricoord=False, + overwrite=False, + subjects_dir=None, + verbose=None, +): """Calculate a forward solution for a subject using MNE-C routines. This is kept around for testing purposes. @@ -1852,7 +2005,7 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, The generated forward solution. 
""" if not has_mne_c(): - raise RuntimeError('mne command line tools could not be found') + raise RuntimeError("mne command line tools could not be found") # check for file existence temp_dir = Path(tempfile.mkdtemp()) @@ -1862,9 +2015,9 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, _validate_type(subject, "str", "subject") # check for meas to exist as string, or try to make evoked - _validate_type(meas, ('path-like', BaseRaw, BaseEpochs, Evoked), 'meas') + _validate_type(meas, ("path-like", BaseRaw, BaseEpochs, Evoked), "meas") if isinstance(meas, (BaseRaw, BaseEpochs, Evoked)): - meas_file = op.join(temp_dir, 'info.fif') + meas_file = op.join(temp_dir, "info.fif") write_info(meas_file, meas.info) meas = meas_file else: @@ -1872,11 +2025,11 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, # deal with trans/mri if mri is not None and trans is not None: - raise ValueError('trans and mri cannot both be specified') + raise ValueError("trans and mri cannot both be specified") if mri is None and trans is None: # MNE allows this to default to a trans/mri in the subject's dir, # but let's be safe here and force the user to pass us a trans/mri - raise ValueError('Either trans or mri must be specified') + raise ValueError("Either trans or mri must be specified") if trans is not None: if isinstance(trans, dict): @@ -1885,8 +2038,10 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, try: write_trans(trans, trans_data) except Exception: - raise OSError('trans was a dict, but could not be ' - 'written to disk as a transform file') + raise OSError( + "trans was a dict, but could not be " + "written to disk as a transform file" + ) elif isinstance(trans, (str, Path, PathLike)): _check_fname(trans, "read", must_exist=True, name="trans") trans = Path(trans) @@ -1899,8 +2054,10 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, try: write_trans(mri, mri_data) except Exception: - raise OSError('mri was a dict, but could not be ' - 'written to disk as a transform file') + raise OSError( + "mri was a dict, but could not be " + "written to disk as a transform file" + ) elif isinstance(mri, (str, Path, PathLike)): _check_fname(mri, "read", must_exist=True, name="mri") mri = Path(mri) @@ -1909,37 +2066,45 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, # deal with meg/eeg if not meg and not eeg: - raise ValueError('meg or eeg (or both) must be True') + raise ValueError("meg or eeg (or both) must be True") if not fname.suffix == ".fif": - raise ValueError('Forward name does not end with .fif') + raise ValueError("Forward name does not end with .fif") path = fname.parent.absolute() fname = fname.name # deal with mindist if mindist is not None: if isinstance(mindist, str): - if not mindist.lower() == 'all': + if not mindist.lower() == "all": raise ValueError('mindist, if string, must be "all"') - mindist = ['--all'] + mindist = ["--all"] else: - mindist = ['--mindist', '%g' % mindist] + mindist = ["--mindist", "%g" % mindist] # src, spacing, bem - for element, name, kind in zip((src, spacing, bem), - ("src", "spacing", "bem"), - ('path-like', 'str', 'path-like')): + for element, name, kind in zip( + (src, spacing, bem), + ("src", "spacing", "bem"), + ("path-like", "str", "path-like"), + ): if element is not None: _validate_type(element, kind, name, "%s or None" % kind) # put together the actual call - cmd = ['mne_do_forward_solution', - '--subject', subject, - 
'--meas', meas, - '--fwd', fname, - '--destdir', str(path)] + cmd = [ + "mne_do_forward_solution", + "--subject", + subject, + "--meas", + meas, + "--fwd", + fname, + "--destdir", + str(path), + ] if src is not None: - cmd += ['--src', src] + cmd += ["--src", src] if spacing is not None: if spacing.isdigit(): pass # spacing in mm @@ -1948,36 +2113,38 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None, match = re.match(r"(oct|ico)-?(\d+)$", spacing) if match is None: raise ValueError("Invalid spacing parameter: %r" % spacing) - spacing = '-'.join(match.groups()) - cmd += ['--spacing', spacing] + spacing = "-".join(match.groups()) + cmd += ["--spacing", spacing] if mindist is not None: cmd += mindist if bem is not None: - cmd += ['--bem', bem] + cmd += ["--bem", bem] if mri is not None: - cmd += ['--mri', '%s' % str(mri.absolute())] + cmd += ["--mri", "%s" % str(mri.absolute())] if trans is not None: - cmd += ['--trans', '%s' % str(trans.absolute())] + cmd += ["--trans", "%s" % str(trans.absolute())] if not meg: - cmd.append('--eegonly') + cmd.append("--eegonly") if not eeg: - cmd.append('--megonly') + cmd.append("--megonly") if fixed: - cmd.append('--fixed') + cmd.append("--fixed") if grad: - cmd.append('--grad') + cmd.append("--grad") if mricoord: - cmd.append('--mricoord') + cmd.append("--mricoord") if overwrite: - cmd.append('--overwrite') + cmd.append("--overwrite") env = os.environ.copy() subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=True)) - env['SUBJECTS_DIR'] = subjects_dir + env["SUBJECTS_DIR"] = subjects_dir try: - logger.info('Running forward solution generation command with ' - 'subjects_dir %s' % subjects_dir) + logger.info( + "Running forward solution generation command with " + "subjects_dir %s" % subjects_dir + ) run_subprocess(cmd, env=env) except Exception: raise @@ -2011,19 +2178,19 @@ def average_forward_solutions(fwds, weights=None, verbose=None): # check for fwds being a list _validate_type(fwds, list, "fwds") if not len(fwds) > 0: - raise ValueError('fwds must not be empty') + raise ValueError("fwds must not be empty") # check weights if weights is None: weights = np.ones(len(fwds)) weights = np.asanyarray(weights) # in case it's a list, convert it if not np.all(weights >= 0): - raise ValueError('weights must be non-negative') + raise ValueError("weights must be non-negative") if not len(weights) == len(fwds): - raise ValueError('weights must be None or the same length as fwds') + raise ValueError("weights must be None or the same length as fwds") w_sum = np.sum(weights) if not w_sum > 0: - raise ValueError('weights cannot all be zero') + raise ValueError("weights cannot all be zero") weights /= w_sum # check our forward solutions @@ -2031,32 +2198,49 @@ def average_forward_solutions(fwds, weights=None, verbose=None): # check to make sure it's a forward solution _validate_type(fwd, dict, "each entry in fwds", "dict") # check to make sure the dict is actually a fwd - check_keys = ['info', 'sol_grad', 'nchan', 'src', 'source_nn', 'sol', - 'source_rr', 'source_ori', 'surf_ori', 'coord_frame', - 'mri_head_t', 'nsource'] + check_keys = [ + "info", + "sol_grad", + "nchan", + "src", + "source_nn", + "sol", + "source_rr", + "source_ori", + "surf_ori", + "coord_frame", + "mri_head_t", + "nsource", + ] if not all(key in fwd for key in check_keys): - raise KeyError('forward solution dict does not have all standard ' - 'entries, cannot compute average.') + raise KeyError( + "forward solution dict does not have all standard " + "entries, 
cannot compute average." + ) # check forward solution compatibility - if any(fwd['sol'][k] != fwds[0]['sol'][k] - for fwd in fwds[1:] for k in ['nrow', 'ncol']): - raise ValueError('Forward solutions have incompatible dimensions') - if any(fwd[k] != fwds[0][k] for fwd in fwds[1:] - for k in ['source_ori', 'surf_ori', 'coord_frame']): - raise ValueError('Forward solutions have incompatible orientations') + if any( + fwd["sol"][k] != fwds[0]["sol"][k] for fwd in fwds[1:] for k in ["nrow", "ncol"] + ): + raise ValueError("Forward solutions have incompatible dimensions") + if any( + fwd[k] != fwds[0][k] + for fwd in fwds[1:] + for k in ["source_ori", "surf_ori", "coord_frame"] + ): + raise ValueError("Forward solutions have incompatible orientations") # actually average them (solutions and gradients) fwd_ave = deepcopy(fwds[0]) - fwd_ave['sol']['data'] *= weights[0] - fwd_ave['_orig_sol'] *= weights[0] + fwd_ave["sol"]["data"] *= weights[0] + fwd_ave["_orig_sol"] *= weights[0] for fwd, w in zip(fwds[1:], weights[1:]): - fwd_ave['sol']['data'] += w * fwd['sol']['data'] - fwd_ave['_orig_sol'] += w * fwd['_orig_sol'] - if fwd_ave['sol_grad'] is not None: - fwd_ave['sol_grad']['data'] *= weights[0] - fwd_ave['_orig_sol_grad'] *= weights[0] + fwd_ave["sol"]["data"] += w * fwd["sol"]["data"] + fwd_ave["_orig_sol"] += w * fwd["_orig_sol"] + if fwd_ave["sol_grad"] is not None: + fwd_ave["sol_grad"]["data"] *= weights[0] + fwd_ave["_orig_sol_grad"] *= weights[0] for fwd, w in zip(fwds[1:], weights[1:]): - fwd_ave['sol_grad']['data'] += w * fwd['sol_grad']['data'] - fwd_ave['_orig_sol_grad'] += w * fwd['_orig_sol_grad'] + fwd_ave["sol_grad"]["data"] += w * fwd["sol_grad"]["data"] + fwd_ave["_orig_sol_grad"] += w * fwd["_orig_sol_grad"] return fwd_ave diff --git a/mne/forward/tests/test_field_interpolation.py b/mne/forward/tests/test_field_interpolation.py index 9adf3915870..036a4a58af9 100644 --- a/mne/forward/tests/test_field_interpolation.py +++ b/mne/forward/tests/test_field_interpolation.py @@ -3,16 +3,24 @@ import numpy as np from numpy.polynomial import legendre -from numpy.testing import (assert_allclose, assert_array_equal, assert_equal, - assert_array_almost_equal) +from numpy.testing import ( + assert_allclose, + assert_array_equal, + assert_equal, + assert_array_almost_equal, +) from scipy.interpolate import interp1d import pytest import mne from mne.forward import _make_surface_mapping, make_field_map -from mne.forward._lead_dots import (_comp_sum_eeg, _comp_sums_meg, - _get_legen_table, _do_cross_dots) +from mne.forward._lead_dots import ( + _comp_sum_eeg, + _comp_sums_meg, + _get_legen_table, + _do_cross_dots, +) from mne.forward._make_forward import _create_meg_coils from mne.forward._field_interpolation import _setup_dots from mne.surface import get_meg_helmet_surf, get_head_surf @@ -21,15 +29,14 @@ from mne.io import read_raw_fif -base_dir = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data') -raw_fname = op.join(base_dir, 'test_raw.fif') -evoked_fname = op.join(base_dir, 'test-ave.fif') -raw_ctf_fname = op.join(base_dir, 'test_ctf_raw.fif') +base_dir = op.join(op.dirname(__file__), "..", "..", "io", "tests", "data") +raw_fname = op.join(base_dir, "test_raw.fif") +evoked_fname = op.join(base_dir, "test-ave.fif") +raw_ctf_fname = op.join(base_dir, "test_ctf_raw.fif") data_path = testing.data_path(download=False) -trans_fname = op.join(data_path, 'MEG', 'sample', - 'sample_audvis_trunc-trans.fif') -subjects_dir = op.join(data_path, 'subjects') +trans_fname = 
op.join(data_path, "MEG", "sample", "sample_audvis_trunc-trans.fif") +subjects_dir = op.join(data_path, "subjects") @testing.requires_testing_data @@ -41,29 +48,30 @@ def test_field_map_ctf(): evoked = Epochs(raw, events).average() evoked.pick_channels(evoked.ch_names[:50]) # crappy mapping but faster # smoke test - passing trans_fname as pathlib.Path as additional check - make_field_map(evoked, trans=Path(trans_fname), subject='sample', - subjects_dir=subjects_dir) + make_field_map( + evoked, trans=Path(trans_fname), subject="sample", subjects_dir=subjects_dir + ) def test_legendre_val(): """Test Legendre polynomial (derivative) equivalence.""" rng = np.random.RandomState(0) # check table equiv - xs = np.linspace(-1., 1., 1000) + xs = np.linspace(-1.0, 1.0, 1000) n_terms = 100 # True, numpy vals_np = legendre.legvander(xs, n_terms - 1) # Table approximation - for nc, interp in zip([100, 50], ['nearest', 'linear']): - lut, n_fact = _get_legen_table('eeg', n_coeff=nc, force_calc=True) - lut_fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, interp, - axis=0) + for nc, interp in zip([100, 50], ["nearest", "linear"]): + lut, n_fact = _get_legen_table("eeg", n_coeff=nc, force_calc=True) + lut_fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, interp, axis=0) vals_i = lut_fun(xs) # Need a "1:" here because we omit the first coefficient in our table! - assert_allclose(vals_np[:, 1:vals_i.shape[1] + 1], vals_i, - rtol=1e-2, atol=5e-3) + assert_allclose( + vals_np[:, 1 : vals_i.shape[1] + 1], vals_i, rtol=1e-2, atol=5e-3 + ) # Now let's look at our sums ctheta = rng.rand(20, 30) * 2.0 - 1.0 @@ -74,24 +82,27 @@ def test_legendre_val(): # compare to numpy n = np.arange(1, n_terms, dtype=float)[:, np.newaxis, np.newaxis] coeffs = np.zeros((n_terms,) + beta.shape) - coeffs[1:] = (np.cumprod([beta] * (n_terms - 1), axis=0) * - (2.0 * n + 1.0) * (2.0 * n + 1.0) / n) + coeffs[1:] = ( + np.cumprod([beta] * (n_terms - 1), axis=0) + * (2.0 * n + 1.0) + * (2.0 * n + 1.0) + / n + ) # can't use tensor=False here b/c it isn't in old numpy c2 = np.empty((20, 30)) for ci1 in range(20): for ci2 in range(30): - c2[ci1, ci2] = legendre.legval(ctheta[ci1, ci2], - coeffs[:, ci1, ci2]) + c2[ci1, ci2] = legendre.legval(ctheta[ci1, ci2], coeffs[:, ci1, ci2]) assert_allclose(c1, c2, 1e-2, 1e-3) # close enough... 
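
# A standalone illustration of the table-plus-interpolation scheme the test
# exercises: tabulate Legendre values on a coarse grid, interpolate, and
# compare against exact evaluation. Grid size and degree are arbitrary.
import numpy as np
from numpy.polynomial import legendre
from scipy.interpolate import interp1d

grid = np.linspace(-1, 1, 501)
table = legendre.legvander(grid, 5)  # columns are P0..P5 sampled on the grid
lut_fun = interp1d(grid, table, "linear", axis=0)
xs = np.linspace(-0.9, 0.9, 100)
np.testing.assert_allclose(lut_fun(xs), legendre.legvander(xs, 5), atol=1e-3)
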
# compare fast and slow for MEG ctheta = rng.rand(20 * 30) * 2.0 - 1.0 beta = rng.rand(20 * 30) * 0.8 - lut, n_fact = _get_legen_table('meg', n_coeff=10, force_calc=True) - fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, 'nearest', axis=0) + lut, n_fact = _get_legen_table("meg", n_coeff=10, force_calc=True) + fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, "nearest", axis=0) coeffs = _comp_sums_meg(beta, ctheta, fun, n_fact, False) - lut, n_fact = _get_legen_table('meg', n_coeff=20, force_calc=True) - fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, 'linear', axis=0) + lut, n_fact = _get_legen_table("meg", n_coeff=20, force_calc=True) + fun = interp1d(np.linspace(-1, 1, lut.shape[0]), lut, "linear", axis=0) coeffs = _comp_sums_meg(beta, ctheta, fun, n_fact, False) @@ -99,10 +110,10 @@ def test_legendre_table(): """Test Legendre table calculation.""" # double-check our table generation n = 10 - for ch_type in ['eeg', 'meg']: + for ch_type in ["eeg", "meg"]: lut1, n_fact1 = _get_legen_table(ch_type, n_coeff=25, force_calc=True) - lut1 = lut1[:, :n - 1].copy() - n_fact1 = n_fact1[:n - 1].copy() + lut1 = lut1[:, : n - 1].copy() + n_fact1 = n_fact1[: n - 1].copy() lut2, n_fact2 = _get_legen_table(ch_type, n_coeff=n, force_calc=True) assert_allclose(lut1, lut2) assert_allclose(n_fact1, n_fact2) @@ -111,77 +122,93 @@ def test_legendre_table(): @testing.requires_testing_data def test_make_field_map_eeg(): """Test interpolation of EEG field onto head.""" - evoked = read_evokeds(evoked_fname, condition='Left Auditory') - evoked.info['bads'] = ['MEG 2443', 'EEG 053'] # add some bads - surf = get_head_surf('sample', subjects_dir=subjects_dir) + evoked = read_evokeds(evoked_fname, condition="Left Auditory") + evoked.info["bads"] = ["MEG 2443", "EEG 053"] # add some bads + surf = get_head_surf("sample", subjects_dir=subjects_dir) # we must have trans if surface is in MRI coords - pytest.raises(ValueError, _make_surface_mapping, evoked.info, surf, 'eeg') + pytest.raises(ValueError, _make_surface_mapping, evoked.info, surf, "eeg") evoked.pick_types(meg=False, eeg=True) - fmd = make_field_map(evoked, trans_fname, - subject='sample', subjects_dir=subjects_dir) + fmd = make_field_map( + evoked, trans_fname, subject="sample", subjects_dir=subjects_dir + ) # trans is necessary for EEG only - pytest.raises(RuntimeError, make_field_map, evoked, None, - subject='sample', subjects_dir=subjects_dir) - - fmd = make_field_map(evoked, trans_fname, - subject='sample', subjects_dir=subjects_dir) + pytest.raises( + RuntimeError, + make_field_map, + evoked, + None, + subject="sample", + subjects_dir=subjects_dir, + ) + + fmd = make_field_map( + evoked, trans_fname, subject="sample", subjects_dir=subjects_dir + ) assert len(fmd) == 1 - assert_array_equal(fmd[0]['data'].shape, (642, 59)) # maps data onto surf - assert len(fmd[0]['ch_names']) == 59 + assert_array_equal(fmd[0]["data"].shape, (642, 59)) # maps data onto surf + assert len(fmd[0]["ch_names"]) == 59 @testing.requires_testing_data @pytest.mark.slowtest def test_make_field_map_meg(): """Test interpolation of MEG field onto helmet | head.""" - evoked = read_evokeds(evoked_fname, condition='Left Auditory') + evoked = read_evokeds(evoked_fname, condition="Left Auditory") info = evoked.info surf = get_meg_helmet_surf(info) # let's reduce the number of channels by a bunch to speed it up - info['bads'] = info['ch_names'][:200] + info["bads"] = info["ch_names"][:200] # bad ch_type - pytest.raises(ValueError, _make_surface_mapping, info, surf, 'foo') + 
pytest.raises(ValueError, _make_surface_mapping, info, surf, "foo") # bad mode - pytest.raises(ValueError, _make_surface_mapping, info, surf, 'meg', - mode='foo') + pytest.raises(ValueError, _make_surface_mapping, info, surf, "meg", mode="foo") # no picks evoked_eeg = evoked.copy().pick_types(meg=False, eeg=True) - pytest.raises(RuntimeError, _make_surface_mapping, evoked_eeg.info, - surf, 'meg') + pytest.raises(RuntimeError, _make_surface_mapping, evoked_eeg.info, surf, "meg") # bad surface def - nn = surf['nn'] - del surf['nn'] - pytest.raises(KeyError, _make_surface_mapping, info, surf, 'meg') - surf['nn'] = nn - cf = surf['coord_frame'] - del surf['coord_frame'] - pytest.raises(KeyError, _make_surface_mapping, info, surf, 'meg') - surf['coord_frame'] = cf + nn = surf["nn"] + del surf["nn"] + pytest.raises(KeyError, _make_surface_mapping, info, surf, "meg") + surf["nn"] = nn + cf = surf["coord_frame"] + del surf["coord_frame"] + pytest.raises(KeyError, _make_surface_mapping, info, surf, "meg") + surf["coord_frame"] = cf # now do it with make_field_map evoked.pick_types(meg=True, eeg=False) evoked.info.normalize_proj() # avoid projection warnings - fmd = make_field_map(evoked, None, - subject='sample', subjects_dir=subjects_dir) - assert (len(fmd) == 1) - assert_array_equal(fmd[0]['data'].shape, (304, 106)) # maps data onto surf - assert len(fmd[0]['ch_names']) == 106 + fmd = make_field_map(evoked, None, subject="sample", subjects_dir=subjects_dir) + assert len(fmd) == 1 + assert_array_equal(fmd[0]["data"].shape, (304, 106)) # maps data onto surf + assert len(fmd[0]["ch_names"]) == 106 - pytest.raises(ValueError, make_field_map, evoked, ch_type='foobar') + pytest.raises(ValueError, make_field_map, evoked, ch_type="foobar") # now test the make_field_map on head surf for MEG evoked.pick_types(meg=True, eeg=False) evoked.info.normalize_proj() - fmd = make_field_map(evoked, trans_fname, meg_surf='head', - subject='sample', subjects_dir=subjects_dir) + fmd = make_field_map( + evoked, + trans_fname, + meg_surf="head", + subject="sample", + subjects_dir=subjects_dir, + ) assert len(fmd) == 1 - assert_array_equal(fmd[0]['data'].shape, (642, 106)) # maps data onto surf - assert len(fmd[0]['ch_names']) == 106 + assert_array_equal(fmd[0]["data"].shape, (642, 106)) # maps data onto surf + assert len(fmd[0]["ch_names"]) == 106 - pytest.raises(ValueError, make_field_map, evoked, meg_surf='foobar', - subjects_dir=subjects_dir, trans=trans_fname) + pytest.raises( + ValueError, + make_field_map, + evoked, + meg_surf="foobar", + subjects_dir=subjects_dir, + trans=trans_fname, + ) @testing.requires_testing_data @@ -192,31 +219,45 @@ def test_make_field_map_meeg(): picks = picks[::10] evoked.pick_channels([evoked.ch_names[p] for p in picks]) evoked.info.normalize_proj() - maps = make_field_map(evoked, trans_fname, subject='sample', - subjects_dir=subjects_dir, verbose='debug') - assert_equal(maps[0]['data'].shape, (642, 6)) # EEG->Head - assert_equal(maps[1]['data'].shape, (304, 31)) # MEG->Helmet + maps = make_field_map( + evoked, + trans_fname, + subject="sample", + subjects_dir=subjects_dir, + verbose="debug", + ) + assert_equal(maps[0]["data"].shape, (642, 6)) # EEG->Head + assert_equal(maps[1]["data"].shape, (304, 31)) # MEG->Helmet # reasonable ranges maxs = (1.2, 2.0) # before #4418, was (1.1, 2.0) mins = (-0.8, -1.3) # before #4418, was (-0.6, -1.2) assert_equal(len(maxs), len(maps)) for map_, max_, min_ in zip(maps, maxs, mins): - assert_allclose(map_['data'].max(), max_, rtol=5e-2) - 
assert_allclose(map_['data'].min(), min_, rtol=5e-2) + assert_allclose(map_["data"].max(), max_, rtol=5e-2) + assert_allclose(map_["data"].min(), min_, rtol=5e-2) # calculated from correct looking mapping on 2015/12/26 - assert_allclose(np.sqrt(np.sum(maps[0]['data'] ** 2)), 19.0903, # 16.6088, - atol=1e-3, rtol=1e-3) - assert_allclose(np.sqrt(np.sum(maps[1]['data'] ** 2)), 19.4748, # 20.1245, - atol=1e-3, rtol=1e-3) + assert_allclose( + np.sqrt(np.sum(maps[0]["data"] ** 2)), 19.0903, atol=1e-3, rtol=1e-3 # 16.6088, + ) + assert_allclose( + np.sqrt(np.sum(maps[1]["data"] ** 2)), 19.4748, atol=1e-3, rtol=1e-3 # 20.1245, + ) def _setup_args(info): """Configure args for test_as_meg_type_evoked.""" - coils = _create_meg_coils(info['chs'], 'normal', info['dev_head_t']) - int_rad, _, lut_fun, n_fact = _setup_dots('fast', info, coils, 'meg') - my_origin = np.array([0., 0., 0.04]) - args_dict = dict(intrad=int_rad, volume=False, coils1=coils, r0=my_origin, - ch_type='meg', lut=lut_fun, n_fact=n_fact) + coils = _create_meg_coils(info["chs"], "normal", info["dev_head_t"]) + int_rad, _, lut_fun, n_fact = _setup_dots("fast", info, coils, "meg") + my_origin = np.array([0.0, 0.0, 0.04]) + args_dict = dict( + intrad=int_rad, + volume=False, + coils1=coils, + r0=my_origin, + ch_type="meg", + lut=lut_fun, + n_fact=n_fact, + ) return args_dict @@ -226,23 +267,30 @@ def test_as_meg_type_evoked(): # validation tests raw = read_raw_fif(raw_fname) events = mne.find_events(raw) - picks = pick_types(raw.info, meg=True, eeg=True, stim=True, - ecg=True, eog=True, include=['STI 014'], - exclude='bads') + picks = pick_types( + raw.info, + meg=True, + eeg=True, + stim=True, + ecg=True, + eog=True, + include=["STI 014"], + exclude="bads", + ) epochs = mne.Epochs(raw, events, picks=picks) evoked = epochs.average() with pytest.raises(ValueError, match="Invalid value for the 'ch_type'"): - evoked.as_type('meg') + evoked.as_type("meg") with pytest.raises(ValueError, match="Invalid value for the 'ch_type'"): - evoked.copy().pick_types(meg='grad').as_type('meg') + evoked.copy().pick_types(meg="grad").as_type("meg") # channel names - ch_names = evoked.info['ch_names'] + ch_names = evoked.info["ch_names"] virt_evoked = evoked.copy().pick_channels(ch_names=ch_names[:10:1]) virt_evoked.info.normalize_proj() - virt_evoked = virt_evoked.as_type('mag') - assert (all(ch.endswith('_v') for ch in virt_evoked.info['ch_names'])) + virt_evoked = virt_evoked.as_type("mag") + assert all(ch.endswith("_v") for ch in virt_evoked.info["ch_names"]) # pick from and to channels evoked_from = evoked.copy().pick_channels(ch_names=ch_names[2:10:3]) @@ -252,8 +300,8 @@ def test_as_meg_type_evoked(): # set up things args1, args2 = _setup_args(info_from), _setup_args(info_to) - args1.update(coils2=args2['coils1']) - args2.update(coils2=args1['coils1']) + args1.update(coils2=args2["coils1"]) + args2.update(coils2=args1["coils1"]) # test cross dots cross_dots1 = _do_cross_dots(**args1) @@ -263,14 +311,13 @@ def test_as_meg_type_evoked(): # correlation test evoked = evoked.pick_channels(ch_names=ch_names[:10:]).copy() - data1 = evoked.pick_types(meg='grad').data.ravel() - data2 = evoked.as_type('grad').data.ravel() - assert (np.corrcoef(data1, data2)[0, 1] > 0.95) + data1 = evoked.pick_types(meg="grad").data.ravel() + data2 = evoked.as_type("grad").data.ravel() + assert np.corrcoef(data1, data2)[0, 1] > 0.95 # Do it with epochs - virt_epochs = \ - epochs.copy().load_data().pick_channels(ch_names=ch_names[:10:1]) + virt_epochs = 
epochs.copy().load_data().pick_channels(ch_names=ch_names[:10:1]) virt_epochs.info.normalize_proj() - virt_epochs = virt_epochs.as_type('mag') - assert (all(ch.endswith('_v') for ch in virt_epochs.info['ch_names'])) + virt_epochs = virt_epochs.as_type("mag") + assert all(ch.endswith("_v") for ch in virt_epochs.info["ch_names"]) assert_allclose(virt_epochs.get_data().mean(0), virt_evoked.data) diff --git a/mne/forward/tests/test_forward.py b/mne/forward/tests/test_forward.py index ff244d9e0bf..59f53b349c0 100644 --- a/mne/forward/tests/test_forward.py +++ b/mne/forward/tests/test_forward.py @@ -3,57 +3,65 @@ import pytest import numpy as np -from numpy.testing import (assert_array_almost_equal, assert_equal, - assert_array_equal, assert_allclose) +from numpy.testing import ( + assert_array_almost_equal, + assert_equal, + assert_array_equal, + assert_allclose, +) from mne.datasets import testing -from mne import (read_forward_solution, apply_forward, apply_forward_raw, - average_forward_solutions, write_forward_solution, - convert_forward_solution, SourceEstimate, pick_types_forward, - read_evokeds, VectorSourceEstimate) +from mne import ( + read_forward_solution, + apply_forward, + apply_forward_raw, + average_forward_solutions, + write_forward_solution, + convert_forward_solution, + SourceEstimate, + pick_types_forward, + read_evokeds, + VectorSourceEstimate, +) from mne.io import read_info from mne.label import read_label from mne.utils import requires_mne, run_subprocess -from mne.forward import (restrict_forward_to_stc, restrict_forward_to_label, - Forward, is_fixed_orient, compute_orient_prior, - compute_depth_prior) +from mne.forward import ( + restrict_forward_to_stc, + restrict_forward_to_label, + Forward, + is_fixed_orient, + compute_orient_prior, + compute_depth_prior, +) from mne.channels import equalize_channels data_path = testing.data_path(download=False) -fname_meeg = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) +fname_meeg = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" fname_meeg_grad = ( - data_path - / "MEG" - / "sample" - / "sample_audvis_trunc-meg-eeg-oct-2-grad-fwd.fif" + data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-2-grad-fwd.fif" ) fname_evoked = ( - Path(__file__).parent.parent.parent - / "io" - / "tests" - / "data" - / "test-ave.fif" + Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test-ave.fif" ) def assert_forward_allclose(f1, f2, rtol=1e-7): """Compare two potentially converted forward solutions.""" - assert_allclose(f1['sol']['data'], f2['sol']['data'], rtol=rtol) - assert f1['sol']['ncol'] == f2['sol']['ncol'] - assert f1['sol']['ncol'] == f1['sol']['data'].shape[1] - assert_allclose(f1['source_nn'], f2['source_nn'], rtol=rtol) - if f1['sol_grad'] is not None: - assert (f2['sol_grad'] is not None) - assert_allclose(f1['sol_grad']['data'], f2['sol_grad']['data']) - assert f1['sol_grad']['ncol'] == f2['sol_grad']['ncol'] - assert f1['sol_grad']['ncol'] == f1['sol_grad']['data'].shape[1] + assert_allclose(f1["sol"]["data"], f2["sol"]["data"], rtol=rtol) + assert f1["sol"]["ncol"] == f2["sol"]["ncol"] + assert f1["sol"]["ncol"] == f1["sol"]["data"].shape[1] + assert_allclose(f1["source_nn"], f2["source_nn"], rtol=rtol) + if f1["sol_grad"] is not None: + assert f2["sol_grad"] is not None + assert_allclose(f1["sol_grad"]["data"], f2["sol_grad"]["data"]) + assert f1["sol_grad"]["ncol"] == f2["sol_grad"]["ncol"] + assert f1["sol_grad"]["ncol"] == 
f1["sol_grad"]["data"].shape[1] else: - assert (f2['sol_grad'] is None) - assert f1['source_ori'] == f2['source_ori'] - assert f1['surf_ori'] == f2['surf_ori'] - assert f1['src'][0]['coord_frame'] == f1['src'][0]['coord_frame'] + assert f2["sol_grad"] is None + assert f1["source_ori"] == f2["source_ori"] + assert f1["surf_ori"] == f2["surf_ori"] + assert f1["src"][0]["coord_frame"] == f1["src"][0]["coord_frame"] @testing.requires_testing_data @@ -61,33 +69,33 @@ def test_convert_forward(): """Test converting forward solution between different representations.""" fwd = read_forward_solution(fname_meeg_grad) fwd_repr = repr(fwd) - assert ('306' in fwd_repr) - assert ('60' in fwd_repr) - assert (fwd_repr) - assert (isinstance(fwd, Forward)) + assert "306" in fwd_repr + assert "60" in fwd_repr + assert fwd_repr + assert isinstance(fwd, Forward) # look at surface orientation fwd_surf = convert_forward_solution(fwd, surf_ori=True) # go back fwd_new = convert_forward_solution(fwd_surf, surf_ori=False) - assert (repr(fwd_new)) - assert (isinstance(fwd_new, Forward)) + assert repr(fwd_new) + assert isinstance(fwd_new, Forward) assert_forward_allclose(fwd, fwd_new) del fwd_new gc.collect() # now go to fixed - fwd_fixed = convert_forward_solution(fwd_surf, surf_ori=True, - force_fixed=True, use_cps=False) + fwd_fixed = convert_forward_solution( + fwd_surf, surf_ori=True, force_fixed=True, use_cps=False + ) del fwd_surf gc.collect() - assert (repr(fwd_fixed)) - assert (isinstance(fwd_fixed, Forward)) - assert (is_fixed_orient(fwd_fixed)) + assert repr(fwd_fixed) + assert isinstance(fwd_fixed, Forward) + assert is_fixed_orient(fwd_fixed) # now go back to cartesian (original condition) - fwd_new = convert_forward_solution(fwd_fixed, surf_ori=False, - force_fixed=False) - assert (repr(fwd_new)) - assert (isinstance(fwd_new, Forward)) + fwd_new = convert_forward_solution(fwd_fixed, surf_ori=False, force_fixed=False) + assert repr(fwd_new) + assert isinstance(fwd_new, Forward) assert_forward_allclose(fwd, fwd_new) del fwd, fwd_new, fwd_fixed gc.collect() @@ -100,86 +108,86 @@ def test_io_forward(tmp_path): # do extensive tests with MEEG + grad n_channels, n_src = 366, 108 fwd = read_forward_solution(fname_meeg_grad) - assert (isinstance(fwd, Forward)) + assert isinstance(fwd, Forward) fwd = read_forward_solution(fname_meeg_grad) fwd = convert_forward_solution(fwd, surf_ori=True) - leadfield = fwd['sol']['data'] + leadfield = fwd["sol"]["data"] assert_equal(leadfield.shape, (n_channels, n_src)) - assert_equal(len(fwd['sol']['row_names']), n_channels) - fname_temp = tmp_path / 'test-fwd.fif' - with pytest.warns(RuntimeWarning, match='stored on disk'): + assert_equal(len(fwd["sol"]["row_names"]), n_channels) + fname_temp = tmp_path / "test-fwd.fif" + with pytest.warns(RuntimeWarning, match="stored on disk"): write_forward_solution(fname_temp, fwd, overwrite=True) fwd = read_forward_solution(fname_meeg_grad) fwd = convert_forward_solution(fwd, surf_ori=True) fwd_read = read_forward_solution(fname_temp) fwd_read = convert_forward_solution(fwd_read, surf_ori=True) - leadfield = fwd_read['sol']['data'] + leadfield = fwd_read["sol"]["data"] assert_equal(leadfield.shape, (n_channels, n_src)) - assert_equal(len(fwd_read['sol']['row_names']), n_channels) - assert_equal(len(fwd_read['info']['chs']), n_channels) - assert ('dev_head_t' in fwd_read['info']) - assert ('mri_head_t' in fwd_read) - assert_array_almost_equal(fwd['sol']['data'], fwd_read['sol']['data']) + assert_equal(len(fwd_read["sol"]["row_names"]), 
n_channels) + assert_equal(len(fwd_read["info"]["chs"]), n_channels) + assert "dev_head_t" in fwd_read["info"] + assert "mri_head_t" in fwd_read + assert_array_almost_equal(fwd["sol"]["data"], fwd_read["sol"]["data"]) fwd = read_forward_solution(fname_meeg) - fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - use_cps=False) - with pytest.warns(RuntimeWarning, match='stored on disk'): + fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=False) + with pytest.warns(RuntimeWarning, match="stored on disk"): write_forward_solution(fname_temp, fwd, overwrite=True) fwd_read = read_forward_solution(fname_temp) - fwd_read = convert_forward_solution(fwd_read, surf_ori=True, - force_fixed=True, use_cps=False) - assert (repr(fwd_read)) - assert (isinstance(fwd_read, Forward)) - assert (is_fixed_orient(fwd_read)) + fwd_read = convert_forward_solution( + fwd_read, surf_ori=True, force_fixed=True, use_cps=False + ) + assert repr(fwd_read) + assert isinstance(fwd_read, Forward) + assert is_fixed_orient(fwd_read) assert_forward_allclose(fwd, fwd_read) - fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - use_cps=True) - leadfield = fwd['sol']['data'] + fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True) + leadfield = fwd["sol"]["data"] assert_equal(leadfield.shape, (n_channels, 1494 / 3)) - assert_equal(len(fwd['sol']['row_names']), n_channels) - assert_equal(len(fwd['info']['chs']), n_channels) - assert ('dev_head_t' in fwd['info']) - assert ('mri_head_t' in fwd) - assert (fwd['surf_ori']) - with pytest.warns(RuntimeWarning, match='stored on disk'): + assert_equal(len(fwd["sol"]["row_names"]), n_channels) + assert_equal(len(fwd["info"]["chs"]), n_channels) + assert "dev_head_t" in fwd["info"] + assert "mri_head_t" in fwd + assert fwd["surf_ori"] + with pytest.warns(RuntimeWarning, match="stored on disk"): write_forward_solution(fname_temp, fwd, overwrite=True) fwd_read = read_forward_solution(fname_temp) - fwd_read = convert_forward_solution(fwd_read, surf_ori=True, - force_fixed=True, use_cps=True) - assert (repr(fwd_read)) - assert (isinstance(fwd_read, Forward)) - assert (is_fixed_orient(fwd_read)) + fwd_read = convert_forward_solution( + fwd_read, surf_ori=True, force_fixed=True, use_cps=True + ) + assert repr(fwd_read) + assert isinstance(fwd_read, Forward) + assert is_fixed_orient(fwd_read) assert_forward_allclose(fwd, fwd_read) fwd = read_forward_solution(fname_meeg_grad) - fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - use_cps=True) - leadfield = fwd['sol']['data'] + fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True) + leadfield = fwd["sol"]["data"] assert_equal(leadfield.shape, (n_channels, n_src / 3)) - assert_equal(len(fwd['sol']['row_names']), n_channels) - assert_equal(len(fwd['info']['chs']), n_channels) - assert ('dev_head_t' in fwd['info']) - assert ('mri_head_t' in fwd) - assert (fwd['surf_ori']) - with pytest.warns(RuntimeWarning, match='stored on disk'): + assert_equal(len(fwd["sol"]["row_names"]), n_channels) + assert_equal(len(fwd["info"]["chs"]), n_channels) + assert "dev_head_t" in fwd["info"] + assert "mri_head_t" in fwd + assert fwd["surf_ori"] + with pytest.warns(RuntimeWarning, match="stored on disk"): write_forward_solution(fname_temp, fwd, overwrite=True) fwd_read = read_forward_solution(fname_temp) - fwd_read = convert_forward_solution(fwd_read, surf_ori=True, - force_fixed=True, use_cps=True) - assert (repr(fwd_read)) - assert 
(isinstance(fwd_read, Forward)) - assert (is_fixed_orient(fwd_read)) + fwd_read = convert_forward_solution( + fwd_read, surf_ori=True, force_fixed=True, use_cps=True + ) + assert repr(fwd_read) + assert isinstance(fwd_read, Forward) + assert is_fixed_orient(fwd_read) assert_forward_allclose(fwd, fwd_read) # test warnings on bad filenames fwd = read_forward_solution(fname_meeg_grad) - fwd_badname = tmp_path / 'test-bad-name.fif.gz' - with pytest.warns(RuntimeWarning, match='end with'): + fwd_badname = tmp_path / "test-bad-name.fif.gz" + with pytest.warns(RuntimeWarning, match="end with"): write_forward_solution(fwd_badname, fwd) - with pytest.warns(RuntimeWarning, match='end with'): + with pytest.warns(RuntimeWarning, match="end with"): read_forward_solution(fwd_badname) fwd = read_forward_solution(fname_meeg) @@ -198,53 +206,55 @@ def test_apply_forward(): t_start = 0.123 fwd = read_forward_solution(fname_meeg) - fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - use_cps=True) + fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True) fwd = pick_types_forward(fwd, meg=True) assert isinstance(fwd, Forward) - vertno = [fwd['src'][0]['vertno'], fwd['src'][1]['vertno']] + vertno = [fwd["src"][0]["vertno"], fwd["src"][1]["vertno"]] stc_data = np.ones((len(vertno[0]) + len(vertno[1]), n_times)) stc = SourceEstimate(stc_data, vertno, tmin=t_start, tstep=1.0 / sfreq) - gain_sum = np.sum(fwd['sol']['data'], axis=1) + gain_sum = np.sum(fwd["sol"]["data"], axis=1) # Evoked evoked = read_evokeds(fname_evoked, condition=0) evoked.pick_types(meg=True) - with pytest.warns(RuntimeWarning, match='only .* positive values'): + with pytest.warns(RuntimeWarning, match="only .* positive values"): evoked = apply_forward(fwd, stc, evoked.info, start=start, stop=stop) data = evoked.data times = evoked.times # do some tests - assert_array_almost_equal(evoked.info['sfreq'], sfreq) + assert_array_almost_equal(evoked.info["sfreq"], sfreq) assert_array_almost_equal(np.sum(data, axis=1), n_times * gain_sum) assert_array_almost_equal(times[0], t_start) assert_array_almost_equal(times[-1], t_start + (n_times - 1) / sfreq) # vector stc_vec = VectorSourceEstimate( - fwd['source_nn'][:, :, np.newaxis] * stc.data[:, np.newaxis], - stc.vertices, stc.tmin, stc.tstep) - with pytest.warns(RuntimeWarning, match='very large'): + fwd["source_nn"][:, :, np.newaxis] * stc.data[:, np.newaxis], + stc.vertices, + stc.tmin, + stc.tstep, + ) + with pytest.warns(RuntimeWarning, match="very large"): evoked_2 = apply_forward(fwd, stc_vec, evoked.info) assert np.abs(evoked_2.data).mean() > 1e-5 assert_allclose(evoked.data, evoked_2.data, atol=1e-10) # Raw - with pytest.warns(RuntimeWarning, match='only .* positive values'): - raw_proj = apply_forward_raw(fwd, stc, evoked.info, start=start, - stop=stop) + with pytest.warns(RuntimeWarning, match="only .* positive values"): + raw_proj = apply_forward_raw(fwd, stc, evoked.info, start=start, stop=stop) data, times = raw_proj[:, :] # do some tests - assert_array_almost_equal(raw_proj.info['sfreq'], sfreq) + assert_array_almost_equal(raw_proj.info["sfreq"], sfreq) assert_array_almost_equal(np.sum(data, axis=1), n_times * gain_sum) - atol = 1. 
/ sfreq + atol = 1.0 / sfreq assert_allclose(raw_proj.first_samp / sfreq, t_start, atol=atol) - assert_allclose(raw_proj.last_samp / sfreq, - t_start + (n_times - 1) / sfreq, atol=atol) + assert_allclose( + raw_proj.last_samp / sfreq, t_start + (n_times - 1) / sfreq, atol=atol + ) @testing.requires_testing_data @@ -257,47 +267,47 @@ def test_restrict_forward_to_stc(tmp_path): t_start = 0.123 fwd = read_forward_solution(fname_meeg) - fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - use_cps=True) + fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True) fwd = pick_types_forward(fwd, meg=True) - vertno = [fwd['src'][0]['vertno'][0:15], fwd['src'][1]['vertno'][0:5]] + vertno = [fwd["src"][0]["vertno"][0:15], fwd["src"][1]["vertno"][0:5]] stc_data = np.ones((len(vertno[0]) + len(vertno[1]), n_times)) stc = SourceEstimate(stc_data, vertno, tmin=t_start, tstep=1.0 / sfreq) fwd_out = restrict_forward_to_stc(fwd, stc) - assert (isinstance(fwd_out, Forward)) + assert isinstance(fwd_out, Forward) - assert_equal(fwd_out['sol']['ncol'], 20) - assert_equal(fwd_out['src'][0]['nuse'], 15) - assert_equal(fwd_out['src'][1]['nuse'], 5) - assert_equal(fwd_out['src'][0]['vertno'], fwd['src'][0]['vertno'][0:15]) - assert_equal(fwd_out['src'][1]['vertno'], fwd['src'][1]['vertno'][0:5]) + assert_equal(fwd_out["sol"]["ncol"], 20) + assert_equal(fwd_out["src"][0]["nuse"], 15) + assert_equal(fwd_out["src"][1]["nuse"], 5) + assert_equal(fwd_out["src"][0]["vertno"], fwd["src"][0]["vertno"][0:15]) + assert_equal(fwd_out["src"][1]["vertno"], fwd["src"][1]["vertno"][0:5]) fwd = read_forward_solution(fname_meeg) fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=False) fwd = pick_types_forward(fwd, meg=True) - vertno = [fwd['src'][0]['vertno'][0:15], fwd['src'][1]['vertno'][0:5]] + vertno = [fwd["src"][0]["vertno"][0:15], fwd["src"][1]["vertno"][0:5]] stc_data = np.ones((len(vertno[0]) + len(vertno[1]), n_times)) stc = SourceEstimate(stc_data, vertno, tmin=t_start, tstep=1.0 / sfreq) fwd_out = restrict_forward_to_stc(fwd, stc) - assert_equal(fwd_out['sol']['ncol'], 60) - assert_equal(fwd_out['src'][0]['nuse'], 15) - assert_equal(fwd_out['src'][1]['nuse'], 5) - assert_equal(fwd_out['src'][0]['vertno'], fwd['src'][0]['vertno'][0:15]) - assert_equal(fwd_out['src'][1]['vertno'], fwd['src'][1]['vertno'][0:5]) + assert_equal(fwd_out["sol"]["ncol"], 60) + assert_equal(fwd_out["src"][0]["nuse"], 15) + assert_equal(fwd_out["src"][1]["nuse"], 5) + assert_equal(fwd_out["src"][0]["vertno"], fwd["src"][0]["vertno"][0:15]) + assert_equal(fwd_out["src"][1]["vertno"], fwd["src"][1]["vertno"][0:5]) # Test saving the restricted forward object. This only works if all fields # are properly accounted for. 
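
The hunk below reformats the standard write/read/convert round-trip for forward solutions; as a user-facing recipe it is the following minimal sketch (paths are placeholders, and the calls are the same public `mne` ones exercised by the test):

    import mne

    # a converted solution is written back in its original form, which is why
    # these tests expect a RuntimeWarning mentioning "stored on disk"
    fwd = mne.read_forward_solution("sample-fwd.fif")  # placeholder path
    fwd = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=False)
    mne.write_forward_solution("copy-fwd.fif", fwd, overwrite=True)
    fwd_read = mne.read_forward_solution("copy-fwd.fif")
    # convert to the same representation before comparing the two
    fwd_read = mne.convert_forward_solution(fwd_read, surf_ori=True, force_fixed=False)
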
- fname_copy = tmp_path / 'copy-fwd.fif' - with pytest.warns(RuntimeWarning, match='stored on disk'): + fname_copy = tmp_path / "copy-fwd.fif" + with pytest.warns(RuntimeWarning, match="stored on disk"): write_forward_solution(fname_copy, fwd_out, overwrite=True) fwd_out_read = read_forward_solution(fname_copy) - fwd_out_read = convert_forward_solution(fwd_out_read, surf_ori=True, - force_fixed=False) + fwd_out_read = convert_forward_solution( + fwd_out_read, surf_ori=True, force_fixed=False + ) assert_forward_allclose(fwd_out, fwd_out_read) @@ -305,63 +315,61 @@ def test_restrict_forward_to_stc(tmp_path): def test_restrict_forward_to_label(tmp_path): """Test restriction of source space to label.""" fwd = read_forward_solution(fname_meeg) - fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, - use_cps=True) + fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True) fwd = pick_types_forward(fwd, meg=True) label_path = data_path / "MEG" / "sample" / "labels" - labels = ['Aud-lh', 'Vis-rh'] + labels = ["Aud-lh", "Vis-rh"] label_lh = read_label(label_path / (labels[0] + ".label")) label_rh = read_label(label_path / (labels[1] + ".label")) fwd_out = restrict_forward_to_label(fwd, [label_lh, label_rh]) - src_sel_lh = np.intersect1d(fwd['src'][0]['vertno'], label_lh.vertices) - src_sel_lh = np.searchsorted(fwd['src'][0]['vertno'], src_sel_lh) - vertno_lh = fwd['src'][0]['vertno'][src_sel_lh] + src_sel_lh = np.intersect1d(fwd["src"][0]["vertno"], label_lh.vertices) + src_sel_lh = np.searchsorted(fwd["src"][0]["vertno"], src_sel_lh) + vertno_lh = fwd["src"][0]["vertno"][src_sel_lh] - nuse_lh = fwd['src'][0]['nuse'] - src_sel_rh = np.intersect1d(fwd['src'][1]['vertno'], label_rh.vertices) - src_sel_rh = np.searchsorted(fwd['src'][1]['vertno'], src_sel_rh) - vertno_rh = fwd['src'][1]['vertno'][src_sel_rh] + nuse_lh = fwd["src"][0]["nuse"] + src_sel_rh = np.intersect1d(fwd["src"][1]["vertno"], label_rh.vertices) + src_sel_rh = np.searchsorted(fwd["src"][1]["vertno"], src_sel_rh) + vertno_rh = fwd["src"][1]["vertno"][src_sel_rh] src_sel_rh += nuse_lh - assert_equal(fwd_out['sol']['ncol'], len(src_sel_lh) + len(src_sel_rh)) - assert_equal(fwd_out['src'][0]['nuse'], len(src_sel_lh)) - assert_equal(fwd_out['src'][1]['nuse'], len(src_sel_rh)) - assert_equal(fwd_out['src'][0]['vertno'], vertno_lh) - assert_equal(fwd_out['src'][1]['vertno'], vertno_rh) + assert_equal(fwd_out["sol"]["ncol"], len(src_sel_lh) + len(src_sel_rh)) + assert_equal(fwd_out["src"][0]["nuse"], len(src_sel_lh)) + assert_equal(fwd_out["src"][1]["nuse"], len(src_sel_rh)) + assert_equal(fwd_out["src"][0]["vertno"], vertno_lh) + assert_equal(fwd_out["src"][1]["vertno"], vertno_rh) fwd = read_forward_solution(fname_meeg) fwd = pick_types_forward(fwd, meg=True) label_path = data_path / "MEG" / "sample" / "labels" - labels = ['Aud-lh', 'Vis-rh'] + labels = ["Aud-lh", "Vis-rh"] label_lh = read_label(label_path / (labels[0] + ".label")) label_rh = read_label(label_path / (labels[1] + ".label")) fwd_out = restrict_forward_to_label(fwd, [label_lh, label_rh]) - src_sel_lh = np.intersect1d(fwd['src'][0]['vertno'], label_lh.vertices) - src_sel_lh = np.searchsorted(fwd['src'][0]['vertno'], src_sel_lh) - vertno_lh = fwd['src'][0]['vertno'][src_sel_lh] + src_sel_lh = np.intersect1d(fwd["src"][0]["vertno"], label_lh.vertices) + src_sel_lh = np.searchsorted(fwd["src"][0]["vertno"], src_sel_lh) + vertno_lh = fwd["src"][0]["vertno"][src_sel_lh] - nuse_lh = fwd['src'][0]['nuse'] - src_sel_rh = 
np.intersect1d(fwd['src'][1]['vertno'], label_rh.vertices) - src_sel_rh = np.searchsorted(fwd['src'][1]['vertno'], src_sel_rh) - vertno_rh = fwd['src'][1]['vertno'][src_sel_rh] + nuse_lh = fwd["src"][0]["nuse"] + src_sel_rh = np.intersect1d(fwd["src"][1]["vertno"], label_rh.vertices) + src_sel_rh = np.searchsorted(fwd["src"][1]["vertno"], src_sel_rh) + vertno_rh = fwd["src"][1]["vertno"][src_sel_rh] src_sel_rh += nuse_lh - assert_equal(fwd_out['sol']['ncol'], - 3 * (len(src_sel_lh) + len(src_sel_rh))) - assert_equal(fwd_out['src'][0]['nuse'], len(src_sel_lh)) - assert_equal(fwd_out['src'][1]['nuse'], len(src_sel_rh)) - assert_equal(fwd_out['src'][0]['vertno'], vertno_lh) - assert_equal(fwd_out['src'][1]['vertno'], vertno_rh) + assert_equal(fwd_out["sol"]["ncol"], 3 * (len(src_sel_lh) + len(src_sel_rh))) + assert_equal(fwd_out["src"][0]["nuse"], len(src_sel_lh)) + assert_equal(fwd_out["src"][1]["nuse"], len(src_sel_rh)) + assert_equal(fwd_out["src"][0]["vertno"], vertno_lh) + assert_equal(fwd_out["src"][1]["vertno"], vertno_rh) # Test saving the restricted forward object. This only works if all fields # are properly accounted for. - fname_copy = tmp_path / 'copy-fwd.fif' + fname_copy = tmp_path / "copy-fwd.fif" write_forward_solution(fname_copy, fwd_out, overwrite=True) fwd_out_read = read_forward_solution(fname_copy) assert_forward_allclose(fwd_out, fwd_out_read) @@ -387,20 +395,27 @@ def test_average_forward_solution(tmp_path): # try an easy case fwd_copy = average_forward_solutions([fwd]) - assert (isinstance(fwd_copy, Forward)) - assert_array_equal(fwd['sol']['data'], fwd_copy['sol']['data']) + assert isinstance(fwd_copy, Forward) + assert_array_equal(fwd["sol"]["data"], fwd_copy["sol"]["data"]) # modify a fwd solution, save it, use MNE to average with old one - fwd_copy['sol']['data'] *= 0.5 - fname_copy = str(tmp_path / 'copy-fwd.fif') + fwd_copy["sol"]["data"] *= 0.5 + fname_copy = str(tmp_path / "copy-fwd.fif") write_forward_solution(fname_copy, fwd_copy, overwrite=True) - cmd = ('mne_average_forward_solutions', '--fwd', fname_meeg, '--fwd', - fname_copy, '--out', fname_copy) + cmd = ( + "mne_average_forward_solutions", + "--fwd", + fname_meeg, + "--fwd", + fname_copy, + "--out", + fname_copy, + ) run_subprocess(cmd) # now let's actually do it, with one filename and one fwd fwd_ave = average_forward_solutions([fwd, fwd_copy]) - assert_array_equal(0.75 * fwd['sol']['data'], fwd_ave['sol']['data']) + assert_array_equal(0.75 * fwd["sol"]["data"], fwd_ave["sol"]["data"]) # fwd_ave_mne = read_forward_solution(fname_copy) # assert_array_equal(fwd_ave_mne['sol']['data'], fwd_ave['sol']['data']) @@ -416,32 +431,32 @@ def test_priors(): # Depth prior fwd = read_forward_solution(fname_meeg) assert not is_fixed_orient(fwd) - n_sources = fwd['nsource'] + n_sources = fwd["nsource"] info = read_info(fname_evoked) depth_prior = compute_depth_prior(fwd, info, exp=0.8) assert depth_prior.shape == (3 * n_sources,) - depth_prior = compute_depth_prior(fwd, info, exp=0.) - assert_array_equal(depth_prior, 1.) 
+ depth_prior = compute_depth_prior(fwd, info, exp=0.0) + assert_array_equal(depth_prior, 1.0) with pytest.raises(ValueError, match='must be "whiten"'): - compute_depth_prior(fwd, info, limit_depth_chs='foo') - with pytest.raises(ValueError, match='noise_cov must be a Covariance'): - compute_depth_prior(fwd, info, limit_depth_chs='whiten') + compute_depth_prior(fwd, info, limit_depth_chs="foo") + with pytest.raises(ValueError, match="noise_cov must be a Covariance"): + compute_depth_prior(fwd, info, limit_depth_chs="whiten") fwd_fixed = convert_forward_solution(fwd, force_fixed=True) depth_prior = compute_depth_prior(fwd_fixed, info=info) assert depth_prior.shape == (n_sources,) # Orientation prior - orient_prior = compute_orient_prior(fwd, 1.) - assert_array_equal(orient_prior, 1.) - orient_prior = compute_orient_prior(fwd_fixed, 0.) - assert_array_equal(orient_prior, 1.) - with pytest.raises(ValueError, match='oriented in surface coordinates'): + orient_prior = compute_orient_prior(fwd, 1.0) + assert_array_equal(orient_prior, 1.0) + orient_prior = compute_orient_prior(fwd_fixed, 0.0) + assert_array_equal(orient_prior, 1.0) + with pytest.raises(ValueError, match="oriented in surface coordinates"): compute_orient_prior(fwd, 0.5) fwd_surf_ori = convert_forward_solution(fwd, surf_ori=True) orient_prior = compute_orient_prior(fwd_surf_ori, 0.5) - assert all(np.in1d(orient_prior, (0.5, 1.))) - with pytest.raises(ValueError, match='between 0 and 1'): + assert all(np.in1d(orient_prior, (0.5, 1.0))) + with pytest.raises(ValueError, match="between 0 and 1"): compute_orient_prior(fwd_surf_ori, -0.5) - with pytest.raises(ValueError, match='with fixed orientation'): + with pytest.raises(ValueError, match="with fixed orientation"): compute_orient_prior(fwd_fixed, 0.5) @@ -449,8 +464,8 @@ def test_priors(): def test_equalize_channels(): """Test equalization of channels for instances of Forward.""" fwd1 = read_forward_solution(fname_meeg) - fwd1.pick_channels(['EEG 001', 'EEG 002', 'EEG 003']) - fwd2 = fwd1.copy().pick_channels(['EEG 002', 'EEG 001'], ordered=True) + fwd1.pick_channels(["EEG 001", "EEG 002", "EEG 003"]) + fwd2 = fwd1.copy().pick_channels(["EEG 002", "EEG 001"], ordered=True) fwd1, fwd2 = equalize_channels([fwd1, fwd2]) - assert fwd1.ch_names == ['EEG 001', 'EEG 002'] - assert fwd2.ch_names == ['EEG 001', 'EEG 002'] + assert fwd1.ch_names == ["EEG 001", "EEG 002"] + assert fwd2.ch_names == ["EEG 001", "EEG 002"] diff --git a/mne/forward/tests/test_make_forward.py b/mne/forward/tests/test_make_forward.py index b23a1ec2f6e..627822aca95 100644 --- a/mne/forward/tests/test_make_forward.py +++ b/mne/forward/tests/test_make_forward.py @@ -11,37 +11,50 @@ from mne.datasets import testing from mne.io import read_raw_fif, read_raw_kit, read_raw_bti, read_info from mne.io.constants import FIFF -from mne import (read_forward_solution, write_forward_solution, - make_forward_solution, convert_forward_solution, - setup_volume_source_space, read_source_spaces, create_info, - make_sphere_model, pick_types_forward, pick_info, pick_types, - read_evokeds, read_cov, read_dipole, - get_volume_labels_from_aseg) +from mne import ( + read_forward_solution, + write_forward_solution, + make_forward_solution, + convert_forward_solution, + setup_volume_source_space, + read_source_spaces, + create_info, + make_sphere_model, + pick_types_forward, + pick_info, + pick_types, + read_evokeds, + read_cov, + read_dipole, + get_volume_labels_from_aseg, +) from mne.surface import _get_ico_surface from mne.transforms import 
Transform -from mne.utils import (requires_mne, run_subprocess, catch_logging, - requires_mne_mark, requires_openmeeg_mark) +from mne.utils import ( + requires_mne, + run_subprocess, + catch_logging, + requires_mne_mark, + requires_openmeeg_mark, +) from mne.forward._make_forward import _create_meg_coils, make_forward_dipole from mne.forward._compute_forward import _magnetic_dipole_field_vec from mne.forward import Forward, _do_forward_solution, use_coil_def from mne.dipole import Dipole, fit_dipole from mne.simulation import simulate_evoked from mne.source_estimate import VolSourceEstimate -from mne.source_space import (write_source_spaces, _compare_source_spaces, - setup_source_space) +from mne.source_space import ( + write_source_spaces, + _compare_source_spaces, + setup_source_space, +) from mne.forward.tests.test_forward import assert_forward_allclose data_path = testing.data_path(download=False) -fname_meeg = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) +fname_meeg = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" fname_raw = ( - Path(__file__).parent.parent.parent - / "io" - / "tests" - / "data" - / "test_raw.fif" + Path(__file__).parent.parent.parent / "io" / "tests" / "data" / "test_raw.fif" ) fname_evo = data_path / "MEG" / "sample" / "sample_audvis_trunc-ave.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis_trunc-cov.fif" @@ -49,9 +62,7 @@ fname_trans = data_path / "MEG" / "sample" / "sample_audvis_trunc-trans.fif" subjects_dir = data_path / "subjects" fname_src = subjects_dir / "sample" / "bem" / "sample-oct-4-src.fif" -fname_bem = ( - subjects_dir / "sample" / "bem" / "sample-1280-1280-1280-bem-sol.fif" -) +fname_bem = subjects_dir / "sample" / "bem" / "sample-1280-1280-1280-bem-sol.fif" fname_aseg = subjects_dir / "sample" / "mri" / "aseg.mgz" fname_bem_meg = subjects_dir / "sample" / "bem" / "sample-1280-bem-sol.fif" @@ -70,9 +81,9 @@ def _col_corrs(a, b): a_std = np.sqrt((a * a).mean(0)) b_std = np.sqrt((b * b).mean(0)) all_zero = (a_std == 0) & (b_std == 0) - num[all_zero] = 1. - a_std[all_zero] = 1. - b_std[all_zero] = 1. + num[all_zero] = 1.0 + a_std[all_zero] = 1.0 + b_std[all_zero] = 1.0 return num / (a_std * b_std) @@ -81,67 +92,94 @@ def _rdm(a, b): a_norm = np.linalg.norm(a, axis=0) b_norm = np.linalg.norm(b, axis=0) all_zero = (a_norm == 0) & (b_norm == 0) - a_norm[all_zero] = 1. - b_norm[all_zero] = 1. 
+ a_norm[all_zero] = 1.0 + b_norm[all_zero] = 1.0 return a_norm / b_norm -def _compare_forwards(fwd, fwd_py, n_sensors, n_src, - meg_rtol=1e-4, meg_atol=1e-9, - meg_corr_tol=0.99, meg_rdm_tol=0.01, - eeg_rtol=1e-3, eeg_atol=1e-3, - eeg_corr_tol=0.99, eeg_rdm_tol=0.01): +def _compare_forwards( + fwd, + fwd_py, + n_sensors, + n_src, + meg_rtol=1e-4, + meg_atol=1e-9, + meg_corr_tol=0.99, + meg_rdm_tol=0.01, + eeg_rtol=1e-3, + eeg_atol=1e-3, + eeg_corr_tol=0.99, + eeg_rdm_tol=0.01, +): """Test forwards.""" # check source spaces - assert len(fwd['src']) == len(fwd_py['src']) - _compare_source_spaces(fwd['src'], fwd_py['src'], mode='approx') + assert len(fwd["src"]) == len(fwd_py["src"]) + _compare_source_spaces(fwd["src"], fwd_py["src"], mode="approx") for surf_ori, force_fixed in product([False, True], [False, True]): # use copy here to leave our originals unmodified - fwd = convert_forward_solution(fwd, surf_ori, force_fixed, copy=True, - use_cps=True) - fwd_py = convert_forward_solution(fwd_py, surf_ori, force_fixed, - copy=True, use_cps=True) + fwd = convert_forward_solution( + fwd, surf_ori, force_fixed, copy=True, use_cps=True + ) + fwd_py = convert_forward_solution( + fwd_py, surf_ori, force_fixed, copy=True, use_cps=True + ) check_src = n_src // 3 if force_fixed else n_src - for key in ('nchan', 'source_rr', 'source_ori', - 'surf_ori', 'coord_frame', 'nsource'): - assert_allclose(fwd_py[key], fwd[key], rtol=1e-4, atol=1e-7, - err_msg=key) + for key in ( + "nchan", + "source_rr", + "source_ori", + "surf_ori", + "coord_frame", + "nsource", + ): + assert_allclose(fwd_py[key], fwd[key], rtol=1e-4, atol=1e-7, err_msg=key) # In surf_ori=True only Z matters for source_nn if surf_ori and not force_fixed: ori_sl = slice(2, None, 3) else: ori_sl = slice(None) - assert_allclose(fwd_py['source_nn'][ori_sl], fwd['source_nn'][ori_sl], - rtol=1e-4, atol=1e-6) - assert_allclose(fwd_py['mri_head_t']['trans'], - fwd['mri_head_t']['trans'], rtol=1e-5, atol=1e-8) - - assert fwd_py['sol']['data'].shape == (n_sensors, check_src) - assert len(fwd['sol']['row_names']) == n_sensors - assert len(fwd_py['sol']['row_names']) == n_sensors + assert_allclose( + fwd_py["source_nn"][ori_sl], fwd["source_nn"][ori_sl], rtol=1e-4, atol=1e-6 + ) + assert_allclose( + fwd_py["mri_head_t"]["trans"], + fwd["mri_head_t"]["trans"], + rtol=1e-5, + atol=1e-8, + ) + + assert fwd_py["sol"]["data"].shape == (n_sensors, check_src) + assert len(fwd["sol"]["row_names"]) == n_sensors + assert len(fwd_py["sol"]["row_names"]) == n_sensors # check MEG - fwd_meg = fwd['sol']['data'][:306, ori_sl] - fwd_meg_py = fwd_py['sol']['data'][:306, ori_sl] - assert_allclose(fwd_meg, fwd_meg_py, rtol=meg_rtol, atol=meg_atol, - err_msg='MEG mismatch') + fwd_meg = fwd["sol"]["data"][:306, ori_sl] + fwd_meg_py = fwd_py["sol"]["data"][:306, ori_sl] + assert_allclose( + fwd_meg, fwd_meg_py, rtol=meg_rtol, atol=meg_atol, err_msg="MEG mismatch" + ) meg_corrs = _col_corrs(fwd_meg, fwd_meg_py) - assert_array_less(meg_corr_tol, meg_corrs, err_msg='MEG corr/MAG') + assert_array_less(meg_corr_tol, meg_corrs, err_msg="MEG corr/MAG") meg_rdm = _rdm(fwd_meg, fwd_meg_py) - assert_allclose(meg_rdm, 1, atol=meg_rdm_tol, err_msg='MEG RDM') + assert_allclose(meg_rdm, 1, atol=meg_rdm_tol, err_msg="MEG RDM") # check EEG - if fwd['sol']['data'].shape[0] > 306: - fwd_eeg = fwd['sol']['data'][306:, ori_sl] - fwd_eeg_py = fwd['sol']['data'][306:, ori_sl] - assert_allclose(fwd_eeg, fwd_eeg_py, rtol=eeg_rtol, atol=eeg_atol, - err_msg='EEG mismatch') + if 
fwd["sol"]["data"].shape[0] > 306: + fwd_eeg = fwd["sol"]["data"][306:, ori_sl] + fwd_eeg_py = fwd["sol"]["data"][306:, ori_sl] + assert_allclose( + fwd_eeg, + fwd_eeg_py, + rtol=eeg_rtol, + atol=eeg_atol, + err_msg="EEG mismatch", + ) # To test so-called MAG we use correlation (related to cosine # similarity) and also RDM to test the amplitude mismatch eeg_corrs = _col_corrs(fwd_eeg, fwd_eeg_py) - assert_array_less(eeg_corr_tol, eeg_corrs, err_msg='EEG corr/MAG') + assert_array_less(eeg_corr_tol, eeg_corrs, err_msg="EEG corr/MAG") eeg_rdm = _rdm(fwd_eeg, fwd_eeg_py) - assert_allclose(eeg_rdm, 1, atol=eeg_rdm_tol, err_msg='EEG RDM') + assert_allclose(eeg_rdm, 1, atol=eeg_rdm_tol, err_msg="EEG RDM") def test_magnetic_dipole(): @@ -149,24 +187,24 @@ def test_magnetic_dipole(): info = read_info(fname_raw) picks = pick_types(info, meg=True, eeg=False, exclude=[]) info = pick_info(info, picks[:12]) - coils = _create_meg_coils(info['chs'], 'normal', None) + coils = _create_meg_coils(info["chs"], "normal", None) # magnetic dipole far (meters!) from device origin - r0 = np.array([0., 13., -6.]) - for ch, coil in zip(info['chs'], coils): - rr = (ch['loc'][:3] + r0) / 2. # get halfway closer + r0 = np.array([0.0, 13.0, -6.0]) + for ch, coil in zip(info["chs"], coils): + rr = (ch["loc"][:3] + r0) / 2.0 # get halfway closer far_fwd = _magnetic_dipole_field_vec(r0[np.newaxis, :], [coil]) near_fwd = _magnetic_dipole_field_vec(rr[np.newaxis, :], [coil]) - ratio = 8. if ch['ch_name'][-1] == '1' else 16. # grad vs mag + ratio = 8.0 if ch["ch_name"][-1] == "1" else 16.0 # grad vs mag assert_allclose(np.median(near_fwd / far_fwd), ratio, atol=1e-1) # degenerate case - r0 = coils[0]['rmag'][[0]] - with pytest.raises(RuntimeError, match='Coil too close'): + r0 = coils[0]["rmag"][[0]] + with pytest.raises(RuntimeError, match="Coil too close"): _magnetic_dipole_field_vec(r0, coils[:1]) - with pytest.warns(RuntimeWarning, match='Coil too close'): - fwd = _magnetic_dipole_field_vec(r0, coils[:1], too_close='warning') + with pytest.warns(RuntimeWarning, match="Coil too close"): + fwd = _magnetic_dipole_field_vec(r0, coils[:1], too_close="warning") assert not np.isfinite(fwd).any() - with np.errstate(invalid='ignore'): - fwd = _magnetic_dipole_field_vec(r0, coils[:1], too_close='info') + with np.errstate(invalid="ignore"): + fwd = _magnetic_dipole_field_vec(r0, coils[:1], too_close="info") assert not np.isfinite(fwd).any() @@ -181,116 +219,183 @@ def test_make_forward_solution_kit(tmp_path, fname_src_small): fname_kit_raw = kit_dir / "test_bin_raw.fif" # first use mne-C: convert file, make forward solution - fwd = _do_forward_solution('sample', fname_kit_raw, src=fname_src_small, - bem=fname_bem_meg, mri=trans_path, - eeg=False, meg=True, subjects_dir=subjects_dir) - assert (isinstance(fwd, Forward)) + fwd = _do_forward_solution( + "sample", + fname_kit_raw, + src=fname_src_small, + bem=fname_bem_meg, + mri=trans_path, + eeg=False, + meg=True, + subjects_dir=subjects_dir, + ) + assert isinstance(fwd, Forward) # now let's use python with the same raw file src = read_source_spaces(fname_src_small) - fwd_py = make_forward_solution(fname_kit_raw, trans_path, src, - fname_bem_meg, eeg=False, meg=True) + fwd_py = make_forward_solution( + fname_kit_raw, trans_path, src, fname_bem_meg, eeg=False, meg=True + ) _compare_forwards(fwd, fwd_py, 157, n_src_small) - assert (isinstance(fwd_py, Forward)) + assert isinstance(fwd_py, Forward) # now let's use mne-python all the way raw_py = read_raw_kit(sqd_path, mrk_path, elp_path, 
hsp_path) # without ignore_ref=True, this should throw an error: - with pytest.raises(NotImplementedError, match='Cannot.*KIT reference'): - make_forward_solution(raw_py.info, src=src, eeg=False, meg=True, - bem=fname_bem_meg, trans=trans_path) + with pytest.raises(NotImplementedError, match="Cannot.*KIT reference"): + make_forward_solution( + raw_py.info, + src=src, + eeg=False, + meg=True, + bem=fname_bem_meg, + trans=trans_path, + ) # check that asking for eeg channels (even if they don't exist) is handled - meg_only_info = pick_info(raw_py.info, pick_types(raw_py.info, meg=True, - eeg=False)) - fwd_py = make_forward_solution(meg_only_info, src=src, meg=True, eeg=True, - bem=fname_bem_meg, trans=trans_path, - ignore_ref=True) - _compare_forwards(fwd, fwd_py, 157, n_src_small, - meg_rtol=1e-3, meg_atol=1e-7) + meg_only_info = pick_info(raw_py.info, pick_types(raw_py.info, meg=True, eeg=False)) + fwd_py = make_forward_solution( + meg_only_info, + src=src, + meg=True, + eeg=True, + bem=fname_bem_meg, + trans=trans_path, + ignore_ref=True, + ) + _compare_forwards(fwd, fwd_py, 157, n_src_small, meg_rtol=1e-3, meg_atol=1e-7) @requires_mne def test_make_forward_solution_bti(fname_src_small): """Test BTI end-to-end versus C.""" - bti_pdf = bti_dir / 'test_pdf_linux' - bti_config = bti_dir / 'test_config_linux' - bti_hs = bti_dir / 'test_hs_linux' - fname_bti_raw = bti_dir / 'exported4D_linux_raw.fif' + bti_pdf = bti_dir / "test_pdf_linux" + bti_config = bti_dir / "test_config_linux" + bti_hs = bti_dir / "test_hs_linux" + fname_bti_raw = bti_dir / "exported4D_linux_raw.fif" raw_py = read_raw_bti(bti_pdf, bti_config, bti_hs, preload=False) src = read_source_spaces(fname_src_small) - fwd_py = make_forward_solution(raw_py.info, src=src, eeg=False, meg=True, - bem=fname_bem_meg, trans=trans_path) - fwd = _do_forward_solution('sample', fname_bti_raw, src=fname_src_small, - bem=fname_bem_meg, mri=trans_path, - eeg=False, meg=True, subjects_dir=subjects_dir) + fwd_py = make_forward_solution( + raw_py.info, src=src, eeg=False, meg=True, bem=fname_bem_meg, trans=trans_path + ) + fwd = _do_forward_solution( + "sample", + fname_bti_raw, + src=fname_src_small, + bem=fname_bem_meg, + mri=trans_path, + eeg=False, + meg=True, + subjects_dir=subjects_dir, + ) _compare_forwards(fwd, fwd_py, 248, n_src_small) -@pytest.mark.parametrize('other', [ - pytest.param('MNE-C', marks=requires_mne_mark()), - pytest.param('openmeeg', marks=requires_openmeeg_mark()), -]) +@pytest.mark.parametrize( + "other", + [ + pytest.param("MNE-C", marks=requires_mne_mark()), + pytest.param("openmeeg", marks=requires_openmeeg_mark()), + ], +) def test_make_forward_solution_ctf(tmp_path, fname_src_small, other): """Test CTF w/compensation against MNE-C or OpenMEEG.""" - pytest.importorskip('nibabel') + pytest.importorskip("nibabel") src = read_source_spaces(fname_src_small) raw = read_raw_fif(fname_ctf_raw) assert raw.compensation_grade == 3 - if other == 'openmeeg': - mindist = 20. + if other == "openmeeg": + mindist = 20.0 n_src_want = 51 else: - assert other == 'MNE-C' - mindist = 0. + assert other == "MNE-C" + mindist = 0.0 n_src_want = n_src_small assert n_src_want == 108 - mindist = 20. if other == 'openmeeg' else 0. 
+ mindist = 20.0 if other == "openmeeg" else 0.0 fwd_py = make_forward_solution( - fname_ctf_raw, fname_trans, src, fname_bem_meg, eeg=False, - mindist=mindist, verbose=True) - - if other == 'openmeeg': + fname_ctf_raw, + fname_trans, + src, + fname_bem_meg, + eeg=False, + mindist=mindist, + verbose=True, + ) + + if other == "openmeeg": # TODO: This should be a 1-layer, but it's broken # (some correlations become negative!)... bem_surfaces = read_bem_surfaces(fname_bem) # fname_bem_meg - bem = make_bem_solution(bem_surfaces, solver='openmeeg') + bem = make_bem_solution(bem_surfaces, solver="openmeeg") # TODO: These tolerances are bad tol_kwargs = dict(meg_atol=1, meg_corr_tol=0.65, meg_rdm_tol=0.6) fwd = make_forward_solution( - fname_ctf_raw, fname_trans, src, bem, eeg=False, mindist=mindist, - verbose=True) + fname_ctf_raw, + fname_trans, + src, + bem, + eeg=False, + mindist=mindist, + verbose=True, + ) else: - assert other == 'MNE-C' + assert other == "MNE-C" bem = None tol_kwargs = dict() fwd = _do_forward_solution( - 'sample', fname_ctf_raw, mri=fname_trans, src=fname_src_small, - bem=fname_bem_meg, eeg=False, meg=True, subjects_dir=subjects_dir, - mindist=mindist) + "sample", + fname_ctf_raw, + mri=fname_trans, + src=fname_src_small, + bem=fname_bem_meg, + eeg=False, + meg=True, + subjects_dir=subjects_dir, + mindist=mindist, + ) _compare_forwards(fwd, fwd_py, 274, n_src_want, **tol_kwargs) # CTF with compensation changed in python ctf_raw = read_raw_fif(fname_ctf_raw) - ctf_raw.info['bads'] = ['MRO24-2908'] # test that it works with some bads + ctf_raw.info["bads"] = ["MRO24-2908"] # test that it works with some bads ctf_raw.apply_gradient_compensation(2) fwd_py = make_forward_solution( - ctf_raw.info, fname_trans, src, fname_bem_meg, eeg=False, meg=True, - mindist=mindist) - if other == 'openmeeg': + ctf_raw.info, + fname_trans, + src, + fname_bem_meg, + eeg=False, + meg=True, + mindist=mindist, + ) + if other == "openmeeg": assert bem is not None fwd = make_forward_solution( - ctf_raw.info, fname_trans, src, bem, eeg=False, mindist=mindist, - verbose=True) + ctf_raw.info, + fname_trans, + src, + bem, + eeg=False, + mindist=mindist, + verbose=True, + ) else: fwd = _do_forward_solution( - 'sample', ctf_raw, mri=fname_trans, src=fname_src_small, - bem=fname_bem_meg, eeg=False, meg=True, subjects_dir=subjects_dir, - mindist=mindist) + "sample", + ctf_raw, + mri=fname_trans, + src=fname_src_small, + bem=fname_bem_meg, + eeg=False, + meg=True, + subjects_dir=subjects_dir, + mindist=mindist, + ) _compare_forwards(fwd, fwd_py, 274, n_src_want, **tol_kwargs) - fname_temp = tmp_path / 'test-ctf-fwd.fif' + fname_temp = tmp_path / "test-ctf-fwd.fif" write_forward_solution(fname_temp, fwd_py) fwd_py2 = read_forward_solution(fname_temp) _compare_forwards(fwd_py, fwd_py2, 274, n_src_want, **tol_kwargs) @@ -303,25 +408,32 @@ def test_make_forward_solution_basic(): with catch_logging() as log: # make sure everything can be path-like (gh #10872) fwd_py = make_forward_solution( - Path(fname_raw), Path(fname_trans), Path(fname_src), - Path(fname_bem), mindist=5., verbose=True) + Path(fname_raw), + Path(fname_trans), + Path(fname_src), + Path(fname_bem), + mindist=5.0, + verbose=True, + ) log = log.getvalue() - assert 'Total 258/258 points inside the surface' in log - assert (isinstance(fwd_py, Forward)) + assert "Total 258/258 points inside the surface" in log + assert isinstance(fwd_py, Forward) fwd = read_forward_solution(fname_meeg) - assert (isinstance(fwd, Forward)) + assert isinstance(fwd, 
Forward) _compare_forwards(fwd, fwd_py, 366, 1494, meg_rtol=1e-3) # Homogeneous model - with pytest.raises(RuntimeError, match='homogeneous.*1-layer.*EEG'): - make_forward_solution(fname_raw, fname_trans, fname_src, - fname_bem_meg) + with pytest.raises(RuntimeError, match="homogeneous.*1-layer.*EEG"): + make_forward_solution(fname_raw, fname_trans, fname_src, fname_bem_meg) @requires_openmeeg_mark() -@pytest.mark.parametrize("n_layers", [ - 3, - pytest.param(1, marks=pytest.mark.xfail(raises=RuntimeError)), -]) +@pytest.mark.parametrize( + "n_layers", + [ + 3, + pytest.param(1, marks=pytest.mark.xfail(raises=RuntimeError)), + ], +) @testing.requires_testing_data def test_make_forward_solution_openmeeg(n_layers): """Test making M-EEG forward solution from OpenMEEG.""" @@ -329,33 +441,45 @@ def test_make_forward_solution_openmeeg(n_layers): bem_surfaces = read_bem_surfaces(fname_bem) raw = read_raw_fif(fname_raw) n_sensors = 366 - ch_types = ['eeg', 'meg'] + ch_types = ["eeg", "meg"] if n_layers == 1: - ch_types = ['meg'] + ch_types = ["meg"] bem_surfaces = bem_surfaces[-1:] - assert bem_surfaces[0]['id'] == FIFF.FIFFV_BEM_SURF_ID_BRAIN + assert bem_surfaces[0]["id"] == FIFF.FIFFV_BEM_SURF_ID_BRAIN n_sensors = 306 raw.pick(ch_types) n_sources_kept = 501 // 3 fwds = dict() for solver in ["openmeeg", "mne"]: bem = make_bem_solution(bem_surfaces, solver=solver) - assert bem['solver'] == solver + assert bem["solver"] == solver with catch_logging() as log: # make sure everything can be path-like (gh #10872) fwd = make_forward_solution( - raw.info, Path(fname_trans), Path(fname_src), - bem, mindist=20., verbose=True) + raw.info, + Path(fname_trans), + Path(fname_src), + bem, + mindist=20.0, + verbose=True, + ) log = log.getvalue() - assert 'Total 258/258 points inside the surface' in log - assert (isinstance(fwd, Forward)) + assert "Total 258/258 points inside the surface" in log + assert isinstance(fwd, Forward) fwds[solver] = fwd del fwd - _compare_forwards(fwds["openmeeg"], - fwds["mne"], n_sensors, n_sources_kept * 3, - meg_atol=1, eeg_atol=100, - meg_corr_tol=0.98, eeg_corr_tol=0.98, - meg_rdm_tol=0.1, eeg_rdm_tol=0.2) + _compare_forwards( + fwds["openmeeg"], + fwds["mne"], + n_sensors, + n_sources_kept * 3, + meg_atol=1, + eeg_atol=100, + meg_corr_tol=0.98, + eeg_corr_tol=0.98, + meg_rdm_tol=0.1, + eeg_rdm_tol=0.2, + ) def test_make_forward_solution_discrete(tmp_path, small_surf_src): @@ -363,31 +487,36 @@ def test_make_forward_solution_discrete(tmp_path, small_surf_src): # smoke test for depth weighting and discrete source spaces src = small_surf_src src = src + setup_volume_source_space( - pos=dict(rr=src[0]['rr'][src[0]['vertno'][:3]].copy(), - nn=src[0]['nn'][src[0]['vertno'][:3]].copy())) + pos=dict( + rr=src[0]["rr"][src[0]["vertno"][:3]].copy(), + nn=src[0]["nn"][src[0]["vertno"][:3]].copy(), + ) + ) sphere = make_sphere_model() - fwd = make_forward_solution(fname_raw, fname_trans, src, sphere, - meg=True, eeg=False) + fwd = make_forward_solution( + fname_raw, fname_trans, src, sphere, meg=True, eeg=False + ) convert_forward_solution(fwd, surf_ori=True) n_src_small = 108 # this is the resulting # of verts in fwd -@pytest.fixture(scope='module', params=[testing._pytest_param()]) +@pytest.fixture(scope="module", params=[testing._pytest_param()]) def small_surf_src(): """Create a small surface source space.""" - pytest.importorskip('nibabel') - src = setup_source_space('sample', 'oct2', subjects_dir=subjects_dir, - add_dist=False) - assert sum(s['nuse'] for s in src) * 3 == n_src_small 
+ pytest.importorskip("nibabel") + src = setup_source_space( + "sample", "oct2", subjects_dir=subjects_dir, add_dist=False + ) + assert sum(s["nuse"] for s in src) * 3 == n_src_small return src @pytest.fixture() def fname_src_small(tmp_path, small_surf_src): """Create a small source space.""" - fname_src_small = tmp_path / 'sample-oct-2-src.fif' + fname_src_small = tmp_path / "sample-oct-2-src.fif" write_source_spaces(fname_src_small, small_surf_src) return fname_src_small @@ -396,39 +525,65 @@ def fname_src_small(tmp_path, small_surf_src): @pytest.mark.timeout(90) # can take longer than 60 s on Travis def test_make_forward_solution_sphere(tmp_path, fname_src_small): """Test making a forward solution with a sphere model.""" - out_name = tmp_path / 'tmp-fwd.fif' - run_subprocess(['mne_forward_solution', '--meg', '--eeg', - '--meas', fname_raw, '--src', fname_src_small, - '--mri', fname_trans, '--fwd', out_name]) + out_name = tmp_path / "tmp-fwd.fif" + run_subprocess( + [ + "mne_forward_solution", + "--meg", + "--eeg", + "--meas", + fname_raw, + "--src", + fname_src_small, + "--mri", + fname_trans, + "--fwd", + out_name, + ] + ) fwd = read_forward_solution(out_name) sphere = make_sphere_model(verbose=True) src = read_source_spaces(fname_src_small) - fwd_py = make_forward_solution(fname_raw, fname_trans, src, sphere, - meg=True, eeg=True, verbose=True) - _compare_forwards(fwd, fwd_py, 366, 108, - meg_rtol=5e-1, meg_atol=1e-6, - eeg_rtol=5e-1, eeg_atol=5e-1) + fwd_py = make_forward_solution( + fname_raw, fname_trans, src, sphere, meg=True, eeg=True, verbose=True + ) + _compare_forwards( + fwd, + fwd_py, + 366, + 108, + meg_rtol=5e-1, + meg_atol=1e-6, + eeg_rtol=5e-1, + eeg_atol=5e-1, + ) # Since the above is pretty lax, let's check a different way for meg, eeg in zip([True, False], [False, True]): fwd_ = pick_types_forward(fwd, meg=meg, eeg=eeg) fwd_py_ = pick_types_forward(fwd, meg=meg, eeg=eeg) - assert_allclose(np.corrcoef(fwd_['sol']['data'].ravel(), - fwd_py_['sol']['data'].ravel())[0, 1], - 1.0, rtol=1e-3) + assert_allclose( + np.corrcoef(fwd_["sol"]["data"].ravel(), fwd_py_["sol"]["data"].ravel())[ + 0, 1 + ], + 1.0, + rtol=1e-3, + ) # Number of layers in the sphere model doesn't matter for MEG # (as long as no sources are omitted due to distance) - assert len(sphere['layers']) == 4 - fwd = make_forward_solution(fname_raw, fname_trans, src, sphere, - meg=True, eeg=False) + assert len(sphere["layers"]) == 4 + fwd = make_forward_solution( + fname_raw, fname_trans, src, sphere, meg=True, eeg=False + ) sphere_1 = make_sphere_model(head_radius=None) - assert len(sphere_1['layers']) == 0 - assert_array_equal(sphere['r0'], sphere_1['r0']) - fwd_1 = make_forward_solution(fname_raw, fname_trans, src, sphere, - meg=True, eeg=False) + assert len(sphere_1["layers"]) == 0 + assert_array_equal(sphere["r0"], sphere_1["r0"]) + fwd_1 = make_forward_solution( + fname_raw, fname_trans, src, sphere, meg=True, eeg=False + ) _compare_forwards(fwd, fwd_1, 306, 108, meg_rtol=1e-12, meg_atol=1e-12) # Homogeneous model sphere = make_sphere_model(head_radius=None) - with pytest.raises(RuntimeError, match='zero shells.*EEG'): + with pytest.raises(RuntimeError, match="zero shells.*EEG"): make_forward_solution(fname_raw, fname_trans, src, sphere) @@ -436,7 +591,7 @@ def test_make_forward_solution_sphere(tmp_path, fname_src_small): @testing.requires_testing_data def test_forward_mixed_source_space(tmp_path): """Test making the forward solution for a mixed source space.""" - pytest.importorskip('nibabel') + 
pytest.importorskip("nibabel") # get the surface source space rng = np.random.RandomState(0) surf = read_source_spaces(fname_src) @@ -444,42 +599,49 @@ def test_forward_mixed_source_space(tmp_path): # setup two volume source spaces label_names = get_volume_labels_from_aseg(fname_aseg) vol_labels = rng.choice(label_names, 2) - with pytest.warns(RuntimeWarning, match='Found no usable.*CC_Mid_Ant.*'): - vol1 = setup_volume_source_space('sample', pos=20., mri=fname_aseg, - volume_label=vol_labels[0], - add_interpolator=False) - vol2 = setup_volume_source_space('sample', pos=20., mri=fname_aseg, - volume_label=vol_labels[1], - add_interpolator=False) + with pytest.warns(RuntimeWarning, match="Found no usable.*CC_Mid_Ant.*"): + vol1 = setup_volume_source_space( + "sample", + pos=20.0, + mri=fname_aseg, + volume_label=vol_labels[0], + add_interpolator=False, + ) + vol2 = setup_volume_source_space( + "sample", + pos=20.0, + mri=fname_aseg, + volume_label=vol_labels[1], + add_interpolator=False, + ) # merge surfaces and volume src = surf + vol1 + vol2 # calculate forward solution fwd = make_forward_solution(fname_raw, fname_trans, src, fname_bem) - assert (repr(fwd)) + assert repr(fwd) # extract source spaces - src_from_fwd = fwd['src'] + src_from_fwd = fwd["src"] # get the coordinate frame of each source space - coord_frames = np.array([s['coord_frame'] for s in src_from_fwd]) + coord_frames = np.array([s["coord_frame"] for s in src_from_fwd]) # assert that all source spaces are in head coordinates - assert ((coord_frames == FIFF.FIFFV_COORD_HEAD).all()) + assert (coord_frames == FIFF.FIFFV_COORD_HEAD).all() # run tests for SourceSpaces.export_volume - fname_img = tmp_path / 'temp-image.mgz' + fname_img = tmp_path / "temp-image.mgz" # head coordinates and mri_resolution, but trans file - with pytest.raises(ValueError, match='trans containing mri to head'): + with pytest.raises(ValueError, match="trans containing mri to head"): src_from_fwd.export_volume(fname_img, mri_resolution=True, trans=None) # head coordinates and mri_resolution, but wrong trans file - vox_mri_t = vol1[0]['vox_mri_t'] - with pytest.raises(ValueError, match='head<->mri, got mri_voxel->mri'): - src_from_fwd.export_volume(fname_img, mri_resolution=True, - trans=vox_mri_t) + vox_mri_t = vol1[0]["vox_mri_t"] + with pytest.raises(ValueError, match="head<->mri, got mri_voxel->mri"): + src_from_fwd.export_volume(fname_img, mri_resolution=True, trans=vox_mri_t) @pytest.mark.slowtest @@ -490,11 +652,11 @@ def test_make_forward_dipole(tmp_path): evoked = read_evokeds(fname_evo)[0] cov = read_cov(fname_cov) - cov['projs'] = [] # avoid proj warning + cov["projs"] = [] # avoid proj warning dip_c = read_dipole(fname_dip) # Only use magnetometers for speed! - picks = pick_types(evoked.info, meg='mag', eeg=False)[::8] + picks = pick_types(evoked.info, meg="mag", eeg=False)[::8] evoked.pick_channels([evoked.ch_names[p] for p in picks]) evoked.info.normalize_proj() info = evoked.info @@ -503,18 +665,19 @@ def test_make_forward_dipole(tmp_path): # in the test dataset. 
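
Building the hand-made Dipole used below needs only equal-length arrays; a minimal self-contained sketch with purely illustrative values (three unevenly spaced times are exactly what later triggers the "unevenly spaced" RuntimeWarning matched by this test):

    import numpy as np
    from mne import Dipole, make_sphere_model

    dip = Dipole(
        times=np.array([0.0, 0.001, 0.005]),   # seconds; deliberately uneven
        pos=np.array([[0.0, 0.0, 0.04]] * 3),  # meters, head coordinates
        amplitude=np.array([10e-9] * 3),       # dipole moment in Am
        ori=np.array([[1.0, 0.0, 0.0]] * 3),   # unit orientation vectors
        gof=np.array([100.0] * 3),             # goodness of fit, percent
    )
    sphere = make_sphere_model(head_radius=0.1)
    # the forward projection then uses the surrounding test's info and trans:
    # fwd, stc = mne.make_forward_dipole(dip, sphere, info, trans=fname_trans)
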
n_test_dipoles = 3 # minimum 3 needed to get uneven sampling in time dipsel = np.sort(rng.permutation(np.arange(len(dip_c)))[:n_test_dipoles]) - dip_test = Dipole(times=dip_c.times[dipsel], - pos=dip_c.pos[dipsel], - amplitude=dip_c.amplitude[dipsel], - ori=dip_c.ori[dipsel], - gof=dip_c.gof[dipsel]) + dip_test = Dipole( + times=dip_c.times[dipsel], + pos=dip_c.pos[dipsel], + amplitude=dip_c.amplitude[dipsel], + ori=dip_c.ori[dipsel], + gof=dip_c.gof[dipsel], + ) sphere = make_sphere_model(head_radius=0.1) # Warning emitted due to uneven sampling in time - with pytest.warns(RuntimeWarning, match='unevenly spaced'): - fwd, stc = make_forward_dipole(dip_test, sphere, info, - trans=fname_trans) + with pytest.warns(RuntimeWarning, match="unevenly spaced"): + fwd, stc = make_forward_dipole(dip_test, sphere, info, trans=fname_trans) # stc is list of VolSourceEstimate's assert isinstance(stc, list) @@ -526,8 +689,7 @@ def test_make_forward_dipole(tmp_path): times, pos, amplitude, ori, gof = [], [], [], [], [] nave = 400 # add a tiny amount of noise to the simulated evokeds for s in stc: - evo_test = simulate_evoked(fwd, s, info, cov, - nave=nave, random_state=rng) + evo_test = simulate_evoked(fwd, s, info, cov, nave=nave, random_state=rng) # evo_test.add_proj(make_eeg_average_ref_proj(evo_test.info)) dfit, resid = fit_dipole(evo_test, cov, sphere, None) times += dfit.times.tolist() @@ -544,14 +706,16 @@ def test_make_forward_dipole(tmp_path): diff = dip_test.pos - dip_fit.pos corr = np.corrcoef(dip_test.pos.ravel(), dip_fit.pos.ravel())[0, 1] dist = np.sqrt(np.mean(np.sum(diff * diff, axis=1))) - gc_dist = 180 / np.pi * \ - np.mean(np.arccos(np.sum(dip_test.ori * dip_fit.ori, axis=1))) + gc_dist = ( + 180 / np.pi * np.mean(np.arccos(np.sum(dip_test.ori * dip_fit.ori, axis=1))) + ) amp_err = np.sqrt(np.mean((dip_test.amplitude - dip_fit.amplitude) ** 2)) # Make sure each coordinate is close to reference # NB tolerance should be set relative to snr of simulated evoked! - assert_allclose(dip_fit.pos, dip_test.pos, rtol=0, atol=1e-2, - err_msg='position mismatch') + assert_allclose( + dip_fit.pos, dip_test.pos, rtol=0, atol=1e-2, err_msg="position mismatch" + ) assert dist < 1e-2 # within 1 cm assert corr > 0.985 assert gc_dist < 20 # less than 20 degrees @@ -560,20 +724,22 @@ def test_make_forward_dipole(tmp_path): # Make sure rejection works with BEM: one dipole at z=1m # NB _make_forward.py:_prepare_for_forward will raise a RuntimeError # if no points are left after min_dist exclusions, hence 2 dips here! 
- dip_outside = Dipole(times=[0., 0.001], - pos=[[0., 0., 1.0], [0., 0., 0.040]], - amplitude=[100e-9, 100e-9], - ori=[[1., 0., 0.], [1., 0., 0.]], gof=1) - with pytest.raises(ValueError, match='outside the inner skull'): + dip_outside = Dipole( + times=[0.0, 0.001], + pos=[[0.0, 0.0, 1.0], [0.0, 0.0, 0.040]], + amplitude=[100e-9, 100e-9], + ori=[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], + gof=1, + ) + with pytest.raises(ValueError, match="outside the inner skull"): make_forward_dipole(dip_outside, fname_bem, info, fname_trans) # if we get this far, can safely assume the code works with BEMs too # -> use sphere again below for speed # Now make an evenly sampled set of dipoles, some simultaneous, # should return a VolSourceEstimate regardless - times = [0., 0., 0., 0.001, 0.001, 0.002] - pos = np.random.rand(6, 3) * 0.020 + \ - np.array([0., 0., 0.040])[np.newaxis, :] + times = [0.0, 0.0, 0.0, 0.001, 0.001, 0.002] + pos = np.random.rand(6, 3) * 0.020 + np.array([0.0, 0.0, 0.040])[np.newaxis, :] amplitude = np.random.rand(6) * 100e-9 ori = np.eye(6, 3) + np.eye(6, 3, -3) gof = np.arange(len(times)) / len(times) # arbitrary @@ -581,61 +747,63 @@ def test_make_forward_dipole(tmp_path): dip_even_samp = Dipole(times, pos, amplitude, ori, gof) # I/O round-trip - fname = str(tmp_path / 'test-fwd.fif') - with pytest.warns(RuntimeWarning, match='free orientation'): + fname = str(tmp_path / "test-fwd.fif") + with pytest.warns(RuntimeWarning, match="free orientation"): write_forward_solution(fname, fwd) - fwd_read = convert_forward_solution( - read_forward_solution(fname), force_fixed=True) + fwd_read = convert_forward_solution(read_forward_solution(fname), force_fixed=True) assert_forward_allclose(fwd, fwd_read, rtol=1e-6) - fwd, stc = make_forward_dipole(dip_even_samp, sphere, info, - trans=fname_trans) + fwd, stc = make_forward_dipole(dip_even_samp, sphere, info, trans=fname_trans) assert isinstance(stc, VolSourceEstimate) - assert_allclose(stc.times, np.arange(0., 0.003, 0.001)) + assert_allclose(stc.times, np.arange(0.0, 0.003, 0.001)) # Test passing a list of Dipoles instead of a single Dipole object - fwd2, stc2 = make_forward_dipole([dip_even_samp[0], dip_even_samp[1:]], - sphere, info, trans=fname_trans) - assert_array_equal(fwd['sol']['data'], fwd2['sol']['data']) + fwd2, stc2 = make_forward_dipole( + [dip_even_samp[0], dip_even_samp[1:]], sphere, info, trans=fname_trans + ) + assert_array_equal(fwd["sol"]["data"], fwd2["sol"]["data"]) assert_array_equal(stc.data, stc2.data) @testing.requires_testing_data def test_make_forward_no_meg(tmp_path): """Test that we can make and I/O forward solution with no MEG channels.""" - pos = dict(rr=[[0.05, 0, 0]], nn=[[0, 0, 1.]]) + pos = dict(rr=[[0.05, 0, 0]], nn=[[0, 0, 1.0]]) src = setup_volume_source_space(pos=pos) bem = make_sphere_model() trans = None - montage = make_standard_montage('standard_1020') - info = create_info(['Cz'], 1000., 'eeg').set_montage(montage) + montage = make_standard_montage("standard_1020") + info = create_info(["Cz"], 1000.0, "eeg").set_montage(montage) fwd = make_forward_solution(info, trans, src, bem) - fname = tmp_path / 'test-fwd.fif' + fname = tmp_path / "test-fwd.fif" write_forward_solution(fname, fwd) fwd_read = read_forward_solution(fname) - assert_allclose(fwd['sol']['data'], fwd_read['sol']['data']) + assert_allclose(fwd["sol"]["data"], fwd_read["sol"]["data"]) def test_use_coil_def(tmp_path): """Test use_coil_def.""" - info = create_info(1, 1000., 'mag') - info['chs'][0]['coil_type'] = 9999 - info['chs'][0]['loc'][:] = 
[0, 0, 0.02, 1, 0, 0, 0, 1, 0, 0, 0, 1] - sphere = make_sphere_model((0., 0., 0.), 0.01) + info = create_info(1, 1000.0, "mag") + info["chs"][0]["coil_type"] = 9999 + info["chs"][0]["loc"][:] = [0, 0, 0.02, 1, 0, 0, 0, 1, 0, 0, 0, 1] + sphere = make_sphere_model((0.0, 0.0, 0.0), 0.01) src = setup_volume_source_space(pos=5, sphere=sphere) - trans = Transform('head', 'mri', None) - with pytest.raises(RuntimeError, match='coil definition not found'): + trans = Transform("head", "mri", None) + with pytest.raises(RuntimeError, match="coil definition not found"): make_forward_solution(info, trans, src, sphere) - coil_fname = tmp_path / 'coil_def.dat' - with open(coil_fname, 'w') as fid: - fid.write("""# custom cube coil def + coil_fname = tmp_path / "coil_def.dat" + with open(coil_fname, "w") as fid: + fid.write( + """# custom cube coil def 1 9999 2 8 3e-03 0.000e+00 "Test" - 0.1250 -0.750e-03 -0.750e-03 -0.750e-03 0.000 0.000""") - with pytest.raises(RuntimeError, match='Could not interpret'): + 0.1250 -0.750e-03 -0.750e-03 -0.750e-03 0.000 0.000""" + ) + with pytest.raises(RuntimeError, match="Could not interpret"): with use_coil_def(coil_fname): make_forward_solution(info, trans, src, sphere) - with open(coil_fname, 'w') as fid: - fid.write("""# custom cube coil def + with open(coil_fname, "w") as fid: + fid.write( + """# custom cube coil def 1 9999 2 8 3e-03 0.000e+00 "Test" 0.1250 -0.750e-03 -0.750e-03 -0.750e-03 0.000 0.000 1.000 0.1250 -0.750e-03 0.750e-03 -0.750e-03 0.000 0.000 1.000 @@ -644,7 +812,8 @@ def test_use_coil_def(tmp_path): 0.1250 -0.750e-03 -0.750e-03 0.750e-03 0.000 0.000 1.000 0.1250 -0.750e-03 0.750e-03 0.750e-03 0.000 0.000 1.000 0.1250 0.750e-03 -0.750e-03 0.750e-03 0.000 0.000 1.000 - 0.1250 0.750e-03 0.750e-03 0.750e-03 0.000 0.000 1.000""") + 0.1250 0.750e-03 0.750e-03 0.750e-03 0.000 0.000 1.000""" + ) with use_coil_def(coil_fname): make_forward_solution(info, trans, src, sphere) @@ -653,27 +822,27 @@ def test_use_coil_def(tmp_path): @testing.requires_testing_data def test_sensors_inside_bem(): """Test that sensors inside the BEM are problematic.""" - rr = _get_ico_surface(1)['rr'] + rr = _get_ico_surface(1)["rr"] rr /= np.linalg.norm(rr, axis=1, keepdims=True) rr *= 0.1 assert len(rr) == 42 - info = create_info(len(rr), 1000., 'mag') - info['dev_head_t'] = Transform('meg', 'head', np.eye(4)) - for ii, ch in enumerate(info['chs']): - ch['loc'][:] = np.concatenate((rr[ii], np.eye(3).ravel())) - trans = Transform('head', 'mri', np.eye(4)) - trans['trans'][2, 3] = 0.03 - sphere_noshell = make_sphere_model((0., 0., 0.), None) - sphere = make_sphere_model((0., 0., 0.), 1.01) - with pytest.raises(RuntimeError, match='.* 15 MEG.*inside the scalp.*'): + info = create_info(len(rr), 1000.0, "mag") + info["dev_head_t"] = Transform("meg", "head", np.eye(4)) + for ii, ch in enumerate(info["chs"]): + ch["loc"][:] = np.concatenate((rr[ii], np.eye(3).ravel())) + trans = Transform("head", "mri", np.eye(4)) + trans["trans"][2, 3] = 0.03 + sphere_noshell = make_sphere_model((0.0, 0.0, 0.0), None) + sphere = make_sphere_model((0.0, 0.0, 0.0), 1.01) + with pytest.raises(RuntimeError, match=".* 15 MEG.*inside the scalp.*"): make_forward_solution(info, trans, fname_src, fname_bem) make_forward_solution(info, trans, fname_src, fname_bem_meg) # okay make_forward_solution(info, trans, fname_src, sphere_noshell) # okay - with pytest.raises(RuntimeError, match='.* 42 MEG.*outermost sphere sh.*'): + with pytest.raises(RuntimeError, match=".* 42 MEG.*outermost sphere sh.*"): 
make_forward_solution(info, trans, fname_src, sphere) - sphere = make_sphere_model((0., 0., 2.0), 1.01) # weird, but okay + sphere = make_sphere_model((0.0, 0.0, 2.0), 1.01) # weird, but okay make_forward_solution(info, trans, fname_src, sphere) - for ch in info['chs']: - ch['loc'][:3] *= 0.1 - with pytest.raises(RuntimeError, match='.* 42 MEG.*the inner skull.*'): + for ch in info["chs"]: + ch["loc"][:3] *= 0.1 + with pytest.raises(RuntimeError, match=".* 42 MEG.*the inner skull.*"): make_forward_solution(info, trans, fname_src, fname_bem_meg) diff --git a/mne/gui/__init__.py b/mne/gui/__init__.py index 7dffe749732..0bd08f62ad5 100644 --- a/mne/gui/__init__.py +++ b/mne/gui/__init__.py @@ -8,14 +8,32 @@ @verbose -def coregistration(tabbed=False, split=True, width=None, inst=None, - subject=None, subjects_dir=None, guess_mri_subject=None, - height=None, head_opacity=None, head_high_res=None, - trans=None, scrollable=True, *, - orient_to_surface=True, scale_by_distance=True, - mark_inside=True, interaction=None, scale=None, - advanced_rendering=None, head_inside=True, - fullscreen=None, show=True, block=False, verbose=None): +def coregistration( + tabbed=False, + split=True, + width=None, + inst=None, + subject=None, + subjects_dir=None, + guess_mri_subject=None, + height=None, + head_opacity=None, + head_high_res=None, + trans=None, + scrollable=True, + *, + orient_to_surface=True, + scale_by_distance=True, + mark_inside=True, + interaction=None, + scale=None, + advanced_rendering=None, + head_inside=True, + fullscreen=None, + show=True, + block=False, + verbose=None, +): """Coregister an MRI with a subject's head shape. The GUI can be launched through the command line interface: @@ -127,13 +145,13 @@ def coregistration(tabbed=False, split=True, width=None, inst=None, .. youtube:: ALV5qqMHLlQ """ unsupported_params = { - 'tabbed': (tabbed, False), - 'split': (split, True), - 'scrollable': (scrollable, True), - 'head_inside': (head_inside, True), - 'guess_mri_subject': guess_mri_subject, - 'scale': scale, - 'advanced_rendering': advanced_rendering, + "tabbed": (tabbed, False), + "split": (split, True), + "scrollable": (scrollable, True), + "head_inside": (head_inside, True), + "guess_mri_subject": guess_mri_subject, + "scale": scale, + "advanced_rendering": advanced_rendering, } for key, val in unsupported_params.items(): if isinstance(val, tuple): @@ -141,45 +159,44 @@ def coregistration(tabbed=False, split=True, width=None, inst=None, else: to_raise = val is not None if to_raise: - warn(f"The parameter {key} is not supported with" - " the pyvistaqt 3d backend. It will be ignored.") + warn( + f"The parameter {key} is not supported with" + " the pyvistaqt 3d backend. It will be ignored." 
+ ) config = get_config() if guess_mri_subject is None: - guess_mri_subject = config.get( - 'MNE_COREG_GUESS_MRI_SUBJECT', 'true') == 'true' + guess_mri_subject = config.get("MNE_COREG_GUESS_MRI_SUBJECT", "true") == "true" if head_high_res is None: - head_high_res = config.get('MNE_COREG_HEAD_HIGH_RES', 'true') == 'true' + head_high_res = config.get("MNE_COREG_HEAD_HIGH_RES", "true") == "true" if advanced_rendering is None: - advanced_rendering = \ - config.get('MNE_COREG_ADVANCED_RENDERING', 'true') == 'true' + advanced_rendering = ( + config.get("MNE_COREG_ADVANCED_RENDERING", "true") == "true" + ) if head_opacity is None: - head_opacity = config.get('MNE_COREG_HEAD_OPACITY', 0.8) + head_opacity = config.get("MNE_COREG_HEAD_OPACITY", 0.8) if head_inside is None: - head_inside = \ - config.get('MNE_COREG_HEAD_INSIDE', 'true').lower() == 'true' + head_inside = config.get("MNE_COREG_HEAD_INSIDE", "true").lower() == "true" if width is None: - width = config.get('MNE_COREG_WINDOW_WIDTH', 800) + width = config.get("MNE_COREG_WINDOW_WIDTH", 800) if height is None: - height = config.get('MNE_COREG_WINDOW_HEIGHT', 600) + height = config.get("MNE_COREG_WINDOW_HEIGHT", 600) if subjects_dir is None: - if 'SUBJECTS_DIR' in config: - subjects_dir = config['SUBJECTS_DIR'] - elif 'MNE_COREG_SUBJECTS_DIR' in config: - subjects_dir = config['MNE_COREG_SUBJECTS_DIR'] + if "SUBJECTS_DIR" in config: + subjects_dir = config["SUBJECTS_DIR"] + elif "MNE_COREG_SUBJECTS_DIR" in config: + subjects_dir = config["MNE_COREG_SUBJECTS_DIR"] if orient_to_surface is None: - orient_to_surface = (config.get('MNE_COREG_ORIENT_TO_SURFACE', '') == - 'true') + orient_to_surface = config.get("MNE_COREG_ORIENT_TO_SURFACE", "") == "true" if scale_by_distance is None: - scale_by_distance = (config.get('MNE_COREG_SCALE_BY_DISTANCE', '') == - 'true') + scale_by_distance = config.get("MNE_COREG_SCALE_BY_DISTANCE", "") == "true" if interaction is None: - interaction = config.get('MNE_COREG_INTERACTION', 'terrain') + interaction = config.get("MNE_COREG_INTERACTION", "terrain") if mark_inside is None: - mark_inside = config.get('MNE_COREG_MARK_INSIDE', '') == 'true' + mark_inside = config.get("MNE_COREG_MARK_INSIDE", "") == "true" if scale is None: - scale = config.get('MNE_COREG_SCENE_SCALE', 0.16) + scale = config.get("MNE_COREG_SCENE_SCALE", 0.16) if fullscreen is None: - fullscreen = config.get('MNE_COREG_FULLSCREEN', '') == 'true' + fullscreen = config.get("MNE_COREG_FULLSCREEN", "") == "true" head_opacity = float(head_opacity) head_inside = bool(head_inside) width = int(width) @@ -188,23 +205,44 @@ def coregistration(tabbed=False, split=True, width=None, inst=None, from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING from ._coreg import CoregistrationUI + if MNE_3D_BACKEND_TESTING: show = block = False return CoregistrationUI( - info_file=inst, subject=subject, subjects_dir=subjects_dir, - head_resolution=head_high_res, head_opacity=head_opacity, - orient_glyphs=orient_to_surface, scale_by_distance=scale_by_distance, - mark_inside=mark_inside, trans=trans, size=(width, height), show=show, - block=block, interaction=interaction, fullscreen=fullscreen, - verbose=verbose + info_file=inst, + subject=subject, + subjects_dir=subjects_dir, + head_resolution=head_high_res, + head_opacity=head_opacity, + orient_glyphs=orient_to_surface, + scale_by_distance=scale_by_distance, + mark_inside=mark_inside, + trans=trans, + size=(width, height), + show=show, + block=block, + interaction=interaction, + fullscreen=fullscreen, + 
+        verbose=verbose,
     )


-@deprecated('Use the :mod:`mne-gui-addons:mne_gui_addons` package instead, '
-            'will be removed in version 1.5.0')
+@deprecated(
+    "Use the :mod:`mne-gui-addons:mne_gui_addons` package instead, "
+    "will be removed in version 1.5.0"
+)
 @verbose
-def locate_ieeg(info, trans, base_image, subject=None, subjects_dir=None,
-                groups=None, show=True, block=False, verbose=None):
+def locate_ieeg(
+    info,
+    trans,
+    base_image,
+    subject=None,
+    subjects_dir=None,
+    groups=None,
+    show=True,
+    block=False,
+    verbose=None,
+):
     """Locate intracranial electrode contacts.

     Parameters
     ----------
@@ -242,55 +280,79 @@ def locate_ieeg(info, trans, base_image, subject=None, subjects_dir=None,
     from ..viz.backends._utils import _qt_app_exec
     from ._ieeg_locate import IntracranialElectrodeLocator
     from qtpy.QtWidgets import QApplication
+
     mne_gui = None
     # get application
     app = QApplication.instance()
     if app is None:
         app = QApplication(["Intracranial Electrode Locator"])
     gui = IntracranialElectrodeLocator(
-        info, trans, base_image, subject=subject, subjects_dir=subjects_dir,
-        groups=groups, show=show, verbose=verbose)
+        info,
+        trans,
+        base_image,
+        subject=subject,
+        subjects_dir=subjects_dir,
+        groups=groups,
+        show=show,
+        verbose=verbose,
+    )
     if block:
         _qt_app_exec(app)
-    return mne_gui.locate_ieeg(
-        info=info, trans=trans, base_image=base_image,
-        subject=subject, subjects_dir=subjects_dir,
-        groups=groups, show=show, block=block) if mne_gui else gui
+    return (
+        mne_gui.locate_ieeg(
+            info=info,
+            trans=trans,
+            base_image=base_image,
+            subject=subject,
+            subjects_dir=subjects_dir,
+            groups=groups,
+            show=show,
+            block=block,
+        )
+        if mne_gui
+        else gui
+    )


 class _GUIScraper:
     """Scrape GUI outputs."""

     def __repr__(self):
-        return '<GUIScraper>'
+        return "<GUIScraper>"

     def __call__(self, block, block_vars, gallery_conf):
         from ._ieeg_locate import IntracranialElectrodeLocator
         from ._coreg import CoregistrationUI
+
         gui_classes = (
             IntracranialElectrodeLocator,
             CoregistrationUI,
         )
         try:
-            from mne_gui_addons._ieeg_locate import IntracranialElectrodeLocator  # noqa: E501
+            from mne_gui_addons._ieeg_locate import (
+                IntracranialElectrodeLocator,
+            )  # noqa: E501
         except Exception:
             pass
         else:
             gui_classes = gui_classes + (IntracranialElectrodeLocator,)
         from sphinx_gallery.scrapers import figure_rst
         from qtpy import QtGui
-        for gui in block_vars['example_globals'].values():
-            if (isinstance(gui, gui_classes) and
-                    not getattr(gui, '_scraped', False) and
-                    gallery_conf['builder_name'] == 'html'):
+
+        for gui in block_vars["example_globals"].values():
+            if (
+                isinstance(gui, gui_classes)
+                and not getattr(gui, "_scraped", False)
+                and gallery_conf["builder_name"] == "html"
+            ):
                 gui._scraped = True  # monkey-patch but it's easy enough
-                img_fname = next(block_vars['image_path_iterator'])
+                img_fname = next(block_vars["image_path_iterator"])
                 # TODO fix in window refactor
-                window = gui if hasattr(gui, 'grab') else gui._renderer._window
+                window = gui if hasattr(gui, "grab") else gui._renderer._window
                 # window is QWindow
                 # https://doc.qt.io/qt-5/qwidget.html#grab
                 pixmap = window.grab()
-                if hasattr(gui, '_renderer'):  # if no renderer, no need
+                if hasattr(gui, "_renderer"):  # if no renderer, no need
                     # Now the tricky part: we need to get the 3D renderer,
                    # extract the image from it, and put it in the correct
                    # place in the pixmap. The easiest way to do this is
@@ -302,8 +364,8 @@ def __call__(self, block, block_vars, gallery_conf):
                     # https://doc.qt.io/qt-5/qwidget.html#mapTo
                     # https://doc.qt.io/qt-5/qpainter.html#drawPixmap-1
                     QtGui.QPainter(pixmap).drawPixmap(
-                        plotter.mapTo(window, plotter.rect().topLeft()),
-                        sub_pixmap)
+                        plotter.mapTo(window, plotter.rect().topLeft()), sub_pixmap
+                    )
                     # https://doc.qt.io/qt-5/qpixmap.html#save
                     pixmap.save(img_fname)
                 try:  # for compatibility with both GUIs, will be refactored
@@ -311,6 +373,5 @@ def __call__(self, block, block_vars, gallery_conf):
                 except Exception:
                     pass
                 gui.close()
-                return figure_rst(
-                    [img_fname], gallery_conf['src_dir'], 'GUI')
-        return ''
+                return figure_rst([img_fname], gallery_conf["src_dir"], "GUI")
+        return ""
diff --git a/mne/gui/_core.py b/mne/gui/_core.py
index b40f16621b3..03ba89b79c4 100644
--- a/mne/gui/_core.py
+++ b/mne/gui/_core.py
@@ -11,9 +11,16 @@

 from qtpy import QtCore
 from qtpy.QtCore import Slot, Qt
-from qtpy.QtWidgets import (QMainWindow, QGridLayout,
-                            QVBoxLayout, QHBoxLayout, QLabel,
-                            QMessageBox, QWidget, QLineEdit)
+from qtpy.QtWidgets import (
+    QMainWindow,
+    QGridLayout,
+    QVBoxLayout,
+    QHBoxLayout,
+    QLabel,
+    QMessageBox,
+    QWidget,
+    QLineEdit,
+)

 from matplotlib import patheffects
 from matplotlib.backends.backend_qt5agg import FigureCanvas
@@ -24,28 +31,35 @@
 from ..viz.utils import safe_event
 from ..surface import _read_mri_surface, _marching_cubes
 from ..transforms import apply_trans, _frame_to_str
-from ..utils import (logger, _check_fname, verbose, warn, get_subjects_dir,
-                     _import_nibabel)
+from ..utils import (
+    logger,
+    _check_fname,
+    verbose,
+    warn,
+    get_subjects_dir,
+    _import_nibabel,
+)
 from ..viz.backends._utils import _qt_safe_window

-_IMG_LABELS = [['I', 'P'], ['I', 'L'], ['P', 'L']]
+_IMG_LABELS = [["I", "P"], ["I", "L"], ["P", "L"]]
 _ZOOM_STEP_SIZE = 5


 @verbose
 def _load_image(img, verbose=None):
     """Load data from a 3D image file (e.g. CT, MR)."""
-    nib = _import_nibabel('use GUI')
+    nib = _import_nibabel("use GUI")
     if not isinstance(img, nib.spatialimages.SpatialImage):
-        logger.debug(f'Loading {img}')
-        _check_fname(img, overwrite='read', must_exist=True)
+        logger.debug(f"Loading {img}")
+        _check_fname(img, overwrite="read", must_exist=True)
         img = nib.load(img)
     # get data
     orig_data = np.array(img.dataobj).astype(np.float32)
     # reorient data to RAS
     ornt = nib.orientations.axcodes2ornt(
-        nib.orientations.aff2axcodes(img.affine)).astype(int)
-    ras_ornt = nib.orientations.axcodes2ornt('RAS')
+        nib.orientations.aff2axcodes(img.affine)
+    ).astype(int)
+    ras_ornt = nib.orientations.axcodes2ornt("RAS")
     ornt_trans = nib.orientations.ornt_transform(ornt, ras_ornt)
     img_data = nib.orientations.apply_orientation(orig_data, ornt_trans)
     orig_mgh = nib.MGHImage(orig_data, img.affine)
@@ -55,14 +69,20 @@ def _load_image(img, verbose=None):
     return img_data, vox_ras_t, vox_scan_ras_t


-def _make_mpl_plot(width=4, height=4, dpi=300, tight=True, hide_axes=True,
-                   facecolor='black', invert=True):
+def _make_mpl_plot(
+    width=4,
+    height=4,
+    dpi=300,
+    tight=True,
+    hide_axes=True,
+    facecolor="black",
+    invert=True,
+):
     fig = Figure(figsize=(width, height), dpi=dpi)
     canvas = FigureCanvas(fig)
     ax = fig.subplots()
     if tight:
-        fig.subplots_adjust(bottom=0, left=0, right=1, top=1,
-                            wspace=0, hspace=0)
+        fig.subplots_adjust(bottom=0, left=0, right=1, top=1, wspace=0, hspace=0)
     ax.set_facecolor(facecolor)
     # clean up excess plot text, invert
     if invert:
@@ -82,9 +102,8 @@ class SliceBrowser(QMainWindow):
         (0, 1),
     )

-    @_qt_safe_window(splash='_renderer.figure.splash', window='')
-    def __init__(self, base_image=None, subject=None, subjects_dir=None,
-                 verbose=None):
+    @_qt_safe_window(splash="_renderer.figure.splash", window="")
+    def __init__(self, base_image=None, subject=None, subjects_dir=None, verbose=None):
         """GUI for browsing slices of anatomical images."""
         # initialize QMainWindow class
         super(SliceBrowser, self).__init__()
@@ -92,10 +111,11 @@ def __init__(self, base_image=None, subject=None, subjects_dir=None,
         self._verbose = verbose
         # if bad/None subject, will raise an informative error when loading MRI
-        subject = os.environ.get('SUBJECT') if subject is None else subject
+        subject = os.environ.get("SUBJECT") if subject is None else subject
         subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=False))
-        self._subject_dir = op.join(subjects_dir, subject) \
-            if subject and subjects_dir else None
+        self._subject_dir = (
+            op.join(subjects_dir, subject) if subject and subjects_dir else None
+        )
         self._load_image_data(base_image=base_image)

         # GUI design
@@ -108,10 +128,11 @@ def __init__(self, base_image=None, subject=None, subjects_dir=None,
             self._plt_grid.addWidget(canvas, i // 2, i % 2)
             self._figs.append(fig)
         self._renderer = _get_renderer(
-            name='Slice Browser', size=(400, 400), bgcolor='w')
+            name="Slice Browser", size=(400, 400), bgcolor="w"
+        )
         self._plt_grid.addWidget(self._renderer.plotter, 1, 1)

-        self._set_ras([0., 0., 0.], update_plots=False)
+        self._set_ras([0.0, 0.0, 0.0], update_plots=False)

         self._plot_images()
@@ -141,10 +162,14 @@ def _load_image_data(self, base_image=None):
             self._head = None
             self._lh = self._rh = None
         else:
-            mri_img = 'brain' if op.isfile(op.join(
-                self._subject_dir, 'mri', 'brain.mgz')) else 'T1'
+            mri_img = (
+                "brain"
+                if op.isfile(op.join(self._subject_dir, "mri", "brain.mgz"))
+                else "T1"
+            )
             self._mri_data, vox_ras_t, vox_scan_ras_t = _load_image(
-                op.join(self._subject_dir, 'mri', f'{mri_img}.mgz'))
+                op.join(self._subject_dir, "mri", f"{mri_img}.mgz")
+            )

         # ready alternate base image if provided, otherwise use brain/T1
         if base_image is None:
@@ -153,19 +178,22 @@ def _load_image_data(self, base_image=None):
             self._vox_ras_t = vox_ras_t
             self._vox_scan_ras_t = vox_scan_ras_t
         else:
-            self._base_data, self._vox_ras_t, self._vox_scan_ras_t = \
-                _load_image(base_image)
+            self._base_data, self._vox_ras_t, self._vox_scan_ras_t = _load_image(
+                base_image
+            )
             if self._mri_data is not None:
-                if self._mri_data.shape != self._base_data.shape or \
-                        not np.allclose(self._vox_ras_t, vox_ras_t, rtol=1e-6):
+                if self._mri_data.shape != self._base_data.shape or not np.allclose(
+                    self._vox_ras_t, vox_ras_t, rtol=1e-6
+                ):
                     raise ValueError(
-                        'Base image is not aligned to MRI, got '
-                        f'Base shape={self._base_data.shape}, '
-                        f'MRI shape={self._mri_data.shape}, '
-                        f'Base affine={vox_ras_t} and '
-                        f'MRI affine={self._vox_ras_t}, '
-                        'please provide an aligned image or do not use the '
-                        '``subject`` and ``subjects_dir`` arguments')
+                        "Base image is not aligned to MRI, got "
+                        f"Base shape={self._base_data.shape}, "
+                        f"MRI shape={self._mri_data.shape}, "
+                        f"Base affine={vox_ras_t} and "
+                        f"MRI affine={self._vox_ras_t}, "
+                        "please provide an aligned image or do not use the "
+                        "``subject`` and ``subjects_dir`` arguments"
+                    )

         self._ras_vox_t = np.linalg.inv(self._vox_ras_t)
         self._scan_ras_vox_t = np.linalg.inv(self._vox_scan_ras_t)
@@ -176,113 +204,171 @@ def _load_image_data(self, base_image=None):
         # number. This code assumes 1mm isotropic...
         img_delta = 0.5
         self._img_extents = list(
-            [-img_delta, self._voxel_sizes[idx[0]] - img_delta,
-             -img_delta, self._voxel_sizes[idx[1]] - img_delta]
-            for idx in self._xy_idx)
+            [
+                -img_delta,
+                self._voxel_sizes[idx[0]] - img_delta,
+                -img_delta,
+                self._voxel_sizes[idx[1]] - img_delta,
+            ]
+            for idx in self._xy_idx
+        )

         if self._subject_dir is not None:
-            if op.exists(op.join(self._subject_dir, 'surf', 'lh.seghead')):
+            if op.exists(op.join(self._subject_dir, "surf", "lh.seghead")):
                 self._head = _read_mri_surface(
-                    op.join(self._subject_dir, 'surf', 'lh.seghead'))
-                assert _frame_to_str[self._head['coord_frame']] == 'mri'
+                    op.join(self._subject_dir, "surf", "lh.seghead")
+                )
+                assert _frame_to_str[self._head["coord_frame"]] == "mri"
             else:
-                warn('`seghead` not found, using marching cubes on base image '
-                     'for head plot, use :ref:`mne.bem.make_scalp_surfaces` '
-                     'to add the scalp surface instead')
+                warn(
+                    "`seghead` not found, using marching cubes on base image "
+                    "for head plot, use :ref:`mne.bem.make_scalp_surfaces` "
+                    "to add the scalp surface instead"
+                )
                 self._head = None
         if self._subject_dir is not None:
             # allow ?h.pial.T1 if ?h.pial doesn't exist
             # end with '' for better file not found error
-            for img in ('', '.T1', '.T2', ''):
+            for img in ("", ".T1", ".T2", ""):
                 surf_fname = op.join(
-                    self._subject_dir, 'surf', '{hemi}' + f'.pial{img}')
-                if op.isfile(surf_fname.format(hemi='lh')):
+                    self._subject_dir, "surf", "{hemi}" + f".pial{img}"
+                )
+                if op.isfile(surf_fname.format(hemi="lh")):
                     break
-            if op.exists(surf_fname.format(hemi='lh')):
-                self._lh = _read_mri_surface(surf_fname.format(hemi='lh'))
-                assert _frame_to_str[self._lh['coord_frame']] == 'mri'
-                self._rh = _read_mri_surface(surf_fname.format(hemi='rh'))
-                assert _frame_to_str[self._rh['coord_frame']] == 'mri'
+            if op.exists(surf_fname.format(hemi="lh")):
+                self._lh = _read_mri_surface(surf_fname.format(hemi="lh"))
+                assert _frame_to_str[self._lh["coord_frame"]] == "mri"
+                self._rh = _read_mri_surface(surf_fname.format(hemi="rh"))
+                assert _frame_to_str[self._rh["coord_frame"]] == "mri"
             else:
-                warn('`pial` surface not found, skipping adding to 3D '
-                     'plot. This indicates the Freesurfer recon-all '
-                     'has not finished or has been modified and '
-                     'these files have been deleted.')
+                warn(
+                    "`pial` surface not found, skipping adding to 3D "
+                    "plot. This indicates the Freesurfer recon-all "
+                    "has not finished or has been modified and "
+                    "these files have been deleted."
+                )
                 self._lh = self._rh = None

     def _plot_images(self):
         """Use the MRI or CT to make plots."""
         # Plot sagittal (0), coronal (1) or axial (2) view
-        self._images = dict(base=list(), cursor_v=list(), cursor_h=list(),
-                            bounds=list())
+        self._images = dict(
+            base=list(), cursor_v=list(), cursor_h=list(), bounds=list()
+        )
         img_min = np.nanmin(self._base_data)
         img_max = np.nanmax(self._base_data)
-        text_kwargs = dict(fontsize='medium', weight='bold', color='#66CCEE',
-                           family='monospace', ha='center', va='center',
-                           path_effects=[patheffects.withStroke(
-                               linewidth=4, foreground="k", alpha=0.75)])
+        text_kwargs = dict(
+            fontsize="medium",
+            weight="bold",
+            color="#66CCEE",
+            family="monospace",
+            ha="center",
+            va="center",
+            path_effects=[
+                patheffects.withStroke(linewidth=4, foreground="k", alpha=0.75)
+            ],
+        )
         xyz = apply_trans(self._ras_vox_t, self._ras)
         for axis in range(3):
             plot_x_idx, plot_y_idx = self._xy_idx[axis]
             fig = self._figs[axis]
             ax = fig.axes[0]
-            img_data = np.take(self._base_data, self._current_slice[axis],
-                               axis=axis).T
-            self._images['base'].append(ax.imshow(
-                img_data, cmap='gray', aspect='auto', zorder=1,
-                vmin=img_min, vmax=img_max))
+            img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T
+            self._images["base"].append(
+                ax.imshow(
+                    img_data,
+                    cmap="gray",
+                    aspect="auto",
+                    zorder=1,
+                    vmin=img_min,
+                    vmax=img_max,
+                )
+            )
             img_extent = self._img_extents[axis]  # x0, x1, y0, y1
             w, h = np.diff(np.array(img_extent).reshape(2, 2), axis=1)[:, 0]
-            self._images['bounds'].append(Rectangle(
-                img_extent[::2], w, h, edgecolor='w', facecolor='none',
-                alpha=0.25, lw=0.5, zorder=1.5))
-            ax.add_patch(self._images['bounds'][-1])
+            self._images["bounds"].append(
+                Rectangle(
+                    img_extent[::2],
+                    w,
+                    h,
+                    edgecolor="w",
+                    facecolor="none",
+                    alpha=0.25,
+                    lw=0.5,
+                    zorder=1.5,
+                )
+            )
+            ax.add_patch(self._images["bounds"][-1])
             v_x = (xyz[plot_x_idx],) * 2
             v_y = img_extent[2:4]
-            self._images['cursor_v'].append(ax.plot(
-                v_x, v_y, color='lime', linewidth=0.5, alpha=0.5, zorder=8)[0])
+            self._images["cursor_v"].append(
+                ax.plot(v_x, v_y, color="lime", linewidth=0.5, alpha=0.5, zorder=8)[0]
+            )
             h_y = (xyz[plot_y_idx],) * 2
             h_x = img_extent[0:2]
-            self._images['cursor_h'].append(ax.plot(
-                h_x, h_y, color='lime', linewidth=0.5, alpha=0.5, zorder=8)[0])
+            self._images["cursor_h"].append(
+                ax.plot(h_x, h_y, color="lime", linewidth=0.5, alpha=0.5, zorder=8)[0]
+            )
             # label axes
-            self._figs[axis].text(0.5, 0.075, _IMG_LABELS[axis][0],
-                                  **text_kwargs)
-            self._figs[axis].text(0.075, 0.5, _IMG_LABELS[axis][1],
-                                  **text_kwargs)
+            self._figs[axis].text(0.5, 0.075, _IMG_LABELS[axis][0], **text_kwargs)
+            self._figs[axis].text(0.075, 0.5, _IMG_LABELS[axis][1], **text_kwargs)
             self._figs[axis].axes[0].axis(img_extent)
+            self._figs[axis].canvas.mpl_connect("scroll_event", self._on_scroll)
             self._figs[axis].canvas.mpl_connect(
-                'scroll_event', self._on_scroll)
-            self._figs[axis].canvas.mpl_connect(
-                'button_release_event', partial(self._on_click, axis=axis))
+                "button_release_event", partial(self._on_click, axis=axis)
+            )

         # add head and brain in mm (convert from m)
         if self._head is None:
-            logger.debug('Using marching cubes on the base image for the '
-                         '3D visualization panel')
+            logger.debug(
+                "Using marching cubes on the base image for the "
+                "3D visualization panel"
+            )
             # in this case, leave in voxel coordinates
-            rr, tris = _marching_cubes(np.where(
-                self._base_data < np.quantile(self._base_data, 0.95), 0, 1),
-                [1])[0]
+            rr, tris = _marching_cubes(
+                np.where(self._base_data < np.quantile(self._base_data, 0.95), 0, 1),
+                [1],
+            )[0]
             # marching cubes transposes dimensions so flip
             rr = apply_trans(self._vox_ras_t, rr[:, ::-1])
             self._renderer.mesh(
-                *rr.T, triangles=tris, color='gray', opacity=0.2,
-                reset_camera=False, render=False)
+                *rr.T,
+                triangles=tris,
+                color="gray",
+                opacity=0.2,
+                reset_camera=False,
+                render=False,
+            )
             self._renderer.set_camera(focalpoint=rr.mean(axis=0))
         else:
             self._renderer.mesh(
-                *self._head['rr'].T * 1000, triangles=self._head['tris'],
-                color='gray', opacity=0.2, reset_camera=False, render=False)
+                *self._head["rr"].T * 1000,
+                triangles=self._head["tris"],
+                color="gray",
+                opacity=0.2,
+                reset_camera=False,
+                render=False,
+            )
         if self._lh is not None and self._rh is not None:
             self._renderer.mesh(
-                *self._lh['rr'].T * 1000, triangles=self._lh['tris'],
-                color='white', opacity=0.2, reset_camera=False, render=False)
+                *self._lh["rr"].T * 1000,
+                triangles=self._lh["tris"],
+                color="white",
+                opacity=0.2,
+                reset_camera=False,
+                render=False,
+            )
             self._renderer.mesh(
-                *self._rh['rr'].T * 1000, triangles=self._rh['tris'],
-                color='white', opacity=0.2, reset_camera=False, render=False)
-        self._renderer.set_camera(azimuth=90, elevation=90, distance=300,
-                                  focalpoint=tuple(self._ras))
+                *self._rh["rr"].T * 1000,
+                triangles=self._rh["tris"],
+                color="white",
+                opacity=0.2,
+                reset_camera=False,
+                render=False,
+            )
+        self._renderer.set_camera(
+            azimuth=90, elevation=90, distance=300, focalpoint=tuple(self._ras)
+        )
         # update plots
         self._draw()
         self._renderer._update()
@@ -291,19 +377,19 @@ def _configure_status_bar(self, hbox=None):
         """Make a bar at the bottom with information in it."""
         hbox = QHBoxLayout() if hbox is None else hbox

-        self._intensity_label = QLabel('')  # update later
+        self._intensity_label = QLabel("")  # update later
         hbox.addWidget(self._intensity_label)

-        VOX_label = QLabel('VOX =')
-        self._VOX_textbox = QLineEdit('')  # update later
+        VOX_label = QLabel("VOX =")
+        self._VOX_textbox = QLineEdit("")  # update later
         self._VOX_textbox.setMaximumHeight(25)
         self._VOX_textbox.setMinimumWidth(75)
         self._VOX_textbox.focusOutEvent = self._update_VOX
         hbox.addWidget(VOX_label)
         hbox.addWidget(self._VOX_textbox)

-        RAS_label = QLabel('RAS =')
-        self._RAS_textbox = QLineEdit('')  # update later
+        RAS_label = QLabel("RAS =")
+        self._RAS_textbox = QLineEdit("")  # update later
         self._RAS_textbox.setMaximumHeight(25)
         self._RAS_textbox.setMinimumWidth(150)
         self._RAS_textbox.focusOutEvent = self._update_RAS
@@ -318,7 +404,8 @@ def _update_camera(self, render=False):
             # needs fix, distance moves when focal point updates
             distance=self._renderer.plotter.camera.distance * 0.9,
             focalpoint=tuple(self._ras),
-            reset_camera=False)
+            reset_camera=False,
+        )

     def _on_scroll(self, event):
         """Process mouse scroll wheel event to zoom."""
@@ -328,8 +415,8 @@ def _zoom(self, sign=1, draw=False):
         """Zoom in on the image."""
         delta = _ZOOM_STEP_SIZE * sign
         for axis, fig in enumerate(self._figs):
-            xcur = self._images['cursor_v'][axis].get_xdata()[0]
-            ycur = self._images['cursor_h'][axis].get_ydata()[0]
+            xcur = self._images["cursor_v"][axis].get_xdata()[0]
+            ycur = self._images["cursor_h"][axis].get_ydata()[0]
             rx, ry = [self._voxel_ratios[idx] for idx in self._xy_idx[axis]]
             xmin, xmax = fig.axes[0].get_xlim()
             ymin, ymax = fig.axes[0].get_ylim()
@@ -352,37 +439,38 @@ def _zoom(self, sign=1, draw=False):

     @Slot()
     def _update_RAS(self, event):
         """Interpret user input to the RAS textbox."""
-        ras = self._convert_text(self._RAS_textbox.text(), 'ras')
+        ras = self._convert_text(self._RAS_textbox.text(), "ras")
         if ras is not None:
             self._set_ras(ras)

     @Slot()
     def _update_VOX(self, event):
         """Interpret user input to the VOX textbox."""
-        ras = self._convert_text(self._VOX_textbox.text(), 'vox')
+        ras = self._convert_text(self._VOX_textbox.text(), "vox")
         if ras is not None:
             self._set_ras(ras)

     def _convert_text(self, text, text_kind):
-        text = text.replace('\n', '')
-        vals = text.split(',')
+        text = text.replace("\n", "")
+        vals = text.split(",")
         if len(vals) != 3:
-            vals = text.split(' ')  # spaces also okay as in freesurfer
+            vals = text.split(" ")  # spaces also okay as in freesurfer
         vals = [var.lstrip().rstrip() for var in vals]
         try:
             vals = np.array([float(var) for var in vals]).reshape(3)
         except Exception:
             self._update_moved()  # resets RAS label
             return
-        if text_kind == 'vox':
+        if text_kind == "vox":
             vox = vals
             ras = apply_trans(self._vox_ras_t, vox)
         else:
-            assert text_kind == 'ras'
+            assert text_kind == "ras"
             ras = vals
             vox = apply_trans(self._ras_vox_t, ras)
-        wrong_size = any(var < 0 or var > n - 1 for var, n in
-                         zip(vox, self._voxel_sizes))
+        wrong_size = any(
+            var < 0 or var > n - 1 for var, n in zip(vox, self._voxel_sizes)
+        )
         if wrong_size:
             self._update_moved()  # resets RAS label
             return
@@ -405,18 +493,18 @@ def set_RAS(self, ras):

     def _set_ras(self, ras, update_plots=True):
         ras = np.asarray(ras, dtype=float)
         assert ras.shape == (3,)
-        msg = ', '.join(f'{x:0.2f}' for x in ras)
-        logger.debug(f'Trying RAS: ({msg}) mm')
+        msg = ", ".join(f"{x:0.2f}" for x in ras)
+        logger.debug(f"Trying RAS: ({msg}) mm")
         # clip to valid
         vox = apply_trans(self._ras_vox_t, ras)
-        vox = np.array([
-            np.clip(d, 0, self._voxel_sizes[ii] - 1)
-            for ii, d in enumerate(vox)])
+        vox = np.array(
+            [np.clip(d, 0, self._voxel_sizes[ii] - 1) for ii, d in enumerate(vox)]
+        )
         # transform back, make read-only
         self._ras_safe = apply_trans(self._vox_ras_t, vox)
-        self._ras_safe.flags['WRITEABLE'] = False
-        msg = ', '.join(f'{x:0.2f}' for x in self._ras_safe)
-        logger.debug(f'Setting RAS: ({msg}) mm')
+        self._ras_safe.flags["WRITEABLE"] = False
+        msg = ", ".join(f"{x:0.2f}" for x in self._ras_safe)
+        logger.debug(f"Setting RAS: ({msg}) mm")
         if update_plots:
             self._move_cursors_to_pos()
@@ -440,15 +528,14 @@ def _current_slice(self):

     def _draw(self, axis=None):
         """Update the figures with a draw call."""
-        for axis in (range(3) if axis is None else [axis]):
+        for axis in range(3) if axis is None else [axis]:
             self._figs[axis].canvas.draw()

     def _update_base_images(self, axis=None, draw=False):
         """Update the base images."""
         for axis in range(3) if axis is None else [axis]:
-            img_data = np.take(self._base_data, self._current_slice[axis],
-                               axis=axis).T
-            self._images['base'][axis].set_data(img_data)
+            img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T
+            self._images["base"][axis].set_data(img_data)
             if draw:
                 self._draw(axis)
@@ -462,23 +549,25 @@ def _move_cursors_to_pos(self):
position.""" for axis in range(3): x, y = self._vox[list(self._xy_idx[axis])] - self._images['cursor_v'][axis].set_xdata([x, x]) - self._images['cursor_h'][axis].set_ydata([y, y]) + self._images["cursor_v"][axis].set_xdata([x, x]) + self._images["cursor_h"][axis].set_ydata([y, y]) self._update_images(draw=True) self._update_moved() def _show_help(self): """Show the help menu.""" QMessageBox.information( - self, 'Help', + self, + "Help", "Help:\n" "'+'/'-': zoom\nleft/right arrow: left/right\n" "up/down arrow: superior/inferior\n" - "left angle bracket/right angle bracket: anterior/posterior") + "left angle bracket/right angle bracket: anterior/posterior", + ) def keyPressEvent(self, event): """Execute functions when the user presses a key.""" - if event.key() == 'escape': + if event.key() == "escape": self.close() elif event.key() == QtCore.Qt.Key_Return: @@ -487,25 +576,37 @@ def keyPressEvent(self, event): widget.clearFocus() self.setFocus() # removing focus calls focus out event - elif event.text() == 'h': + elif event.text() == "h": self._show_help() - elif event.text() in ('=', '+', '-'): - self._zoom(sign=-2 * (event.text() == '-') + 1, draw=True) + elif event.text() in ("=", "+", "-"): + self._zoom(sign=-2 * (event.text() == "-") + 1, draw=True) # Changing slices - elif event.key() in (QtCore.Qt.Key_Up, QtCore.Qt.Key_Down, - QtCore.Qt.Key_Left, QtCore.Qt.Key_Right, - QtCore.Qt.Key_Comma, QtCore.Qt.Key_Period, - QtCore.Qt.Key_PageUp, QtCore.Qt.Key_PageDown): + elif event.key() in ( + QtCore.Qt.Key_Up, + QtCore.Qt.Key_Down, + QtCore.Qt.Key_Left, + QtCore.Qt.Key_Right, + QtCore.Qt.Key_Comma, + QtCore.Qt.Key_Period, + QtCore.Qt.Key_PageUp, + QtCore.Qt.Key_PageDown, + ): ras = np.array(self._ras) if event.key() in (QtCore.Qt.Key_Up, QtCore.Qt.Key_Down): ras[2] += 2 * (event.key() == QtCore.Qt.Key_Up) - 1 elif event.key() in (QtCore.Qt.Key_Left, QtCore.Qt.Key_Right): ras[0] += 2 * (event.key() == QtCore.Qt.Key_Right) - 1 else: - ras[1] += 2 * (event.key() == QtCore.Qt.Key_PageUp or - event.key() == QtCore.Qt.Key_Period) - 1 + ras[1] += ( + 2 + * ( + event.key() == QtCore.Qt.Key_PageUp + or event.key() == QtCore.Qt.Key_Period + ) + - 1 + ) self._set_ras(ras) def _on_click(self, event, axis): @@ -516,18 +617,17 @@ def _on_click(self, event, axis): logger.debug(f'Clicked {"XYZ"[axis]} ({axis}) axis at pos {pos}') xyz = self._vox xyz[list(self._xy_idx[axis])] = pos - logger.debug(f'Using voxel {list(xyz)}') + logger.debug(f"Using voxel {list(xyz)}") ras = apply_trans(self._vox_ras_t, xyz) self._set_ras(ras) def _update_moved(self): """Update when cursor position changes.""" - self._RAS_textbox.setText('{:.2f}, {:.2f}, {:.2f}'.format( - *self._ras)) - self._VOX_textbox.setText('{:3d}, {:3d}, {:3d}'.format( - *self._current_slice)) - self._intensity_label.setText('intensity = {:.2f}'.format( - self._base_data[tuple(self._current_slice)])) + self._RAS_textbox.setText("{:.2f}, {:.2f}, {:.2f}".format(*self._ras)) + self._VOX_textbox.setText("{:3d}, {:3d}, {:3d}".format(*self._current_slice)) + self._intensity_label.setText( + "intensity = {:.2f}".format(self._base_data[tuple(self._current_slice)]) + ) @safe_event def closeEvent(self, event): diff --git a/mne/gui/_coreg.py b/mne/gui/_coreg.py index d5126ba2fee..a9f26038107 100644 --- a/mne/gui/_coreg.py +++ b/mne/gui/_coreg.py @@ -21,23 +21,49 @@ from ..io.meas_info import _empty_info from ..io._read_raw import supported as raw_supported_types from ..bem import make_bem_solution, write_bem_solution -from ..coreg import (Coregistration, 
_is_mri_subject, scale_mri, bem_fname, - _mri_subject_has_bem, fid_fname, _map_fid_name_to_idx, - _find_head_bem) -from ..viz._3d import (_plot_head_surface, _plot_head_fiducials, - _plot_head_shape_points, _plot_mri_fiducials, - _plot_hpi_coils, _plot_sensors, _plot_helmet) +from ..coreg import ( + Coregistration, + _is_mri_subject, + scale_mri, + bem_fname, + _mri_subject_has_bem, + fid_fname, + _map_fid_name_to_idx, + _find_head_bem, +) +from ..viz._3d import ( + _plot_head_surface, + _plot_head_fiducials, + _plot_head_shape_points, + _plot_mri_fiducials, + _plot_hpi_coils, + _plot_sensors, + _plot_helmet, +) from ..viz.backends._utils import _qt_app_exec, _qt_safe_window from ..viz.utils import safe_event -from ..transforms import (read_trans, write_trans, _ensure_trans, _get_trans, - rotation_angles, _get_transforms_to_coord_frame) -from ..utils import (get_subjects_dir, check_fname, _check_fname, fill_doc, - verbose, logger, _validate_type) +from ..transforms import ( + read_trans, + write_trans, + _ensure_trans, + _get_trans, + rotation_angles, + _get_transforms_to_coord_frame, +) +from ..utils import ( + get_subjects_dir, + check_fname, + _check_fname, + fill_doc, + verbose, + logger, + _validate_type, +) from ..surface import _DistanceQuery, _CheckInside from ..channels import read_dig_fif -class _WorkerData(): +class _WorkerData: def __init__(self, name, params=None): self._name = name self._params = params @@ -50,9 +76,9 @@ def _get_subjects(sdir): dir_content = os.listdir(sdir) subjects = [s for s in dir_content if _is_mri_subject(s, sdir)] if len(subjects) == 0: - subjects.append('') + subjects.append("") else: - subjects = [''] + subjects = [""] return sorted(subjects) @@ -136,28 +162,47 @@ class CoregistrationUI(HasTraits): _scale_mode = Unicode() _icp_fid_match = Unicode() - @_qt_safe_window(splash='_renderer.figure.splash', - window='_renderer.figure.plotter') + @_qt_safe_window( + splash="_renderer.figure.splash", window="_renderer.figure.plotter" + ) @verbose - def __init__(self, info_file, *, subject=None, subjects_dir=None, - fiducials='auto', head_resolution=None, - head_opacity=None, hpi_coils=None, - head_shape_points=None, eeg_channels=None, orient_glyphs=None, - scale_by_distance=None, mark_inside=None, - sensor_opacity=None, trans=None, size=None, bgcolor=None, - show=True, block=False, fullscreen=False, - interaction='terrain', verbose=None): + def __init__( + self, + info_file, + *, + subject=None, + subjects_dir=None, + fiducials="auto", + head_resolution=None, + head_opacity=None, + hpi_coils=None, + head_shape_points=None, + eeg_channels=None, + orient_glyphs=None, + scale_by_distance=None, + mark_inside=None, + sensor_opacity=None, + trans=None, + size=None, + bgcolor=None, + show=True, + block=False, + fullscreen=False, + interaction="terrain", + verbose=None, + ): from ..viz.backends.renderer import _get_renderer def _get_default(var, val): return var if var is not None else val + self._actors = dict() self._surfaces = dict() self._widgets = dict() self._verbose = verbose self._plot_locked = False self._params_locked = False - self._refresh_rate_ms = max(int(round(1000. 
/ 60.)), 1) + self._refresh_rate_ms = max(int(round(1000.0 / 60.0)), 1) self._redraws_pending = set() self._parameter_mutex = threading.Lock() self._redraw_mutex = threading.Lock() @@ -176,8 +221,8 @@ def _get_default(var, val): self._mri_scale_modified = False self._accept_close_event = True self._fid_colors = tuple( - DEFAULTS['coreg'][f'{key}_color'] for key in - ('lpa', 'nasion', 'rpa')) + DEFAULTS["coreg"][f"{key}_color"] for key in ("lpa", "nasion", "rpa") + ) self._defaults = dict( size=_get_default(size, (800, 600)), bgcolor=_get_default(bgcolor, "grey"), @@ -198,8 +243,8 @@ def _get_default(var, val): subject_to="", scale_modes=["None", "uniform", "3-axis"], scale_mode="None", - icp_fid_matches=('nearest', 'matched'), - icp_fid_match='matched', + icp_fid_matches=("nearest", "matched"), + icp_fid_match="matched", icp_n_iterations=20, omit_hsp_distance=10.0, lock_head_opacity=self._head_opacity < 1.0, @@ -221,7 +266,7 @@ def _get_default(var, val): subject = _get_default(subject, _get_subjects(subjects_dir)[0]) # setup the window - splash = 'Initializing coregistration GUI...' if show else False + splash = "Initializing coregistration GUI..." if show else False self._renderer = _get_renderer( size=self._defaults["size"], bgcolor=self._defaults["bgcolor"], @@ -233,13 +278,15 @@ def _get_default(var, val): self._renderer.set_interaction(interaction) # coregistration model setup - self._immediate_redraw = (self._renderer._kind != 'qt') + self._immediate_redraw = self._renderer._kind != "qt" self._info = info self._fiducials = fiducials self.coreg = Coregistration( - info=self._info, subject=subject, subjects_dir=subjects_dir, + info=self._info, + subject=subject, + subjects_dir=subjects_dir, fiducials=fiducials, - on_defects='ignore' # safe due to interactive visual inspection + on_defects="ignore", # safe due to interactive visual inspection ) fid_accurate = self.coreg._fid_accurate for fid in self._defaults["weights"].keys(): @@ -286,8 +333,8 @@ def _get_default(var, val): # internally self._set_fiducials_file(self.coreg._fid_filename) else: - self._set_head_resolution('high') - self._forward_widget_command('high_res_head', "set_value", True) + self._set_head_resolution("high") + self._forward_widget_command("high_res_head", "set_value", True) self._set_lock_fids(True) # hack to make the dig disappear self._update_fiducials_label() self._update_fiducials() @@ -301,20 +348,21 @@ def _get_default(var, val): if show: self._renderer.show() # update the view once shown - views = {True: dict(azimuth=90, elevation=90), # front - False: dict(azimuth=180, elevation=90)} # left + views = { + True: dict(azimuth=90, elevation=90), # front + False: dict(azimuth=180, elevation=90), + } # left self._renderer.set_camera(distance=None, **views[self._lock_fids]) self._redraw() # XXX: internal plotter/renderer should not be exposed if not self._immediate_redraw: - self._renderer.plotter.add_callback( - self._redraw, self._refresh_rate_ms) + self._renderer.plotter.add_callback(self._redraw, self._refresh_rate_ms) self._renderer.plotter.show_axes() # initialization does not count as modification by the user self._trans_modified = False self._mri_fids_modified = False self._mri_scale_modified = False - if block and self._renderer._kind != 'notebook': + if block and self._renderer._kind != "notebook": _qt_app_exec(self._renderer.figure.store["app"]) def _set_subjects_dir(self, subjects_dir): @@ -330,10 +378,8 @@ def _set_subjects_dir(self, subjects_dir): ) ) subjects = _get_subjects(subjects_dir) - 
low_res_path = _find_head_bem( - subjects[0], subjects_dir, high_res=False) - high_res_path = _find_head_bem( - subjects[0], subjects_dir, high_res=True) + low_res_path = _find_head_bem(subjects[0], subjects_dir, high_res=False) + high_res_path = _find_head_bem(subjects[0], subjects_dir, high_res=True) valid = low_res_path is not None or high_res_path is not None except Exception: valid = False @@ -352,7 +398,7 @@ def _set_lock_fids(self, state): def _set_fiducials_file(self, fname): if fname is None: - fids = 'auto' + fids = "auto" else: fname = str( _check_fname( @@ -373,17 +419,11 @@ def _set_fiducials_file(self, fname): if fname is None: self._set_lock_fids(False) - self._forward_widget_command( - 'reload_mri_fids', 'set_enabled', False - ) + self._forward_widget_command("reload_mri_fids", "set_enabled", False) else: self._set_lock_fids(True) - self._forward_widget_command( - 'reload_mri_fids', 'set_enabled', True - ) - self._display_message( - f"Loading MRI fiducials from {fname}... Done!" - ) + self._forward_widget_command("reload_mri_fids", "set_enabled", True) + self._display_message(f"Loading MRI fiducials from {fname}... Done!") def _set_current_fiducial(self, fid): self._current_fiducial = fid.lower() @@ -394,17 +434,23 @@ def _set_info_file(self, fname): # info file can be anything supported by read_raw try: - check_fname(fname, 'info', tuple(raw_supported_types.keys()), - endings_err=tuple(raw_supported_types.keys())) + check_fname( + fname, + "info", + tuple(raw_supported_types.keys()), + endings_err=tuple(raw_supported_types.keys()), + ) fname = str(_check_fname(fname, overwrite="read")) # cast to str # ctf ds `files` are actually directories - if fname.endswith(('.ds',)): + if fname.endswith((".ds",)): info_file = _check_fname( - fname, overwrite='read', must_exist=True, need_dir=True) + fname, overwrite="read", must_exist=True, need_dir=True + ) else: info_file = _check_fname( - fname, overwrite='read', must_exist=True, need_dir=False) + fname, overwrite="read", must_exist=True, need_dir=False + ) valid = True except OSError: valid = False @@ -450,14 +496,12 @@ def _set_grow_hair(self, value): def _set_subject_to(self, value): self._subject_to = value - self._forward_widget_command( - "save_subject", "set_enabled", len(value) > 0) + self._forward_widget_command("save_subject", "set_enabled", len(value) > 0) if self._check_subject_exists(): style = dict(border="2px solid #ff0000") else: style = dict(border="initial") - self._forward_widget_command( - "subject_to", "set_style", style) + self._forward_widget_command("subject_to", "set_style", style) def _set_scale_mode(self, mode): self._scale_mode = mode @@ -470,7 +514,7 @@ def _set_fiducial(self, value, coord): coords = ["X", "Y", "Z"] coord_idx = coords.index(coord) - self.coreg.fiducials.dig[fid_idx]['r'][coord_idx] = value / 1e3 + self.coreg.fiducials.dig[fid_idx]["r"][coord_idx] = value / 1e3 self._update_plot("mri_fids") def _set_parameter(self, value, mode_name, coord, plot_locked=False): @@ -482,10 +526,9 @@ def _set_parameter(self, value, mode_name, coord, plot_locked=False): return if mode_name == "scale" and self._scale_mode == "uniform": with self._lock(params=True): - self._forward_widget_command( - ["sY", "sZ"], "set_value", value) + self._forward_widget_command(["sY", "sZ"], "set_value", value) with self._parameter_mutex: - self. 
_set_parameter_safe(value, mode_name, coord) + self._set_parameter_safe(value, mode_name, coord) if not plot_locked: self._update_plot("sensors") @@ -521,9 +564,9 @@ def _set_icp_fid_match(self, method): def _set_point_weight(self, weight, point): funcs = { - 'hpi': '_set_hpi_coils', - 'hsp': '_set_head_shape_points', - 'eeg': '_set_eeg_channels', + "hpi": "_set_hpi_coils", + "hsp": "_set_head_shape_points", + "eeg": "_set_eeg_channels", } if point in funcs.keys(): getattr(self, funcs[point])(weight > 0) @@ -567,70 +610,90 @@ def _lock_fids_changed(self, change=None): # MRI fiducials "save_mri_fids", # View options - "helmet", "head_opacity", "high_res_head", + "helmet", + "head_opacity", + "high_res_head", # Digitization source - "info_file", "grow_hair", "omit_distance", "omit", "reset_omit", + "info_file", + "grow_hair", + "omit_distance", + "omit", + "reset_omit", # Scaling - "scaling_mode", "sX", "sY", "sZ", + "scaling_mode", + "sX", + "sY", + "sZ", # Transformation - "tX", "tY", "tZ", - "rX", "rY", "rZ", + "tX", + "tY", + "tZ", + "rX", + "rY", + "rZ", # Fitting buttons - "fit_fiducials", "fit_icp", + "fit_fiducials", + "fit_icp", # Transformation I/O - "save_trans", "load_trans", + "save_trans", + "load_trans", "reset_trans", # ICP - "icp_n_iterations", "icp_fid_match", "reset_fitting_options", + "icp_n_iterations", + "icp_fid_match", + "reset_fitting_options", # Weights - "hsp_weight", "eeg_weight", "hpi_weight", - "lpa_weight", "nasion_weight", "rpa_weight", + "hsp_weight", + "eeg_weight", + "hpi_weight", + "lpa_weight", + "nasion_weight", + "rpa_weight", ] fits_widgets = ["fits_fiducials", "fits_icp"] fid_widgets = ["fid_X", "fid_Y", "fid_Z", "fids_file", "fids"] if self._lock_fids: self._forward_widget_command(locked_widgets, "set_enabled", True) self._forward_widget_command( - 'head_opacity', 'set_value', self._old_head_opacity + "head_opacity", "set_value", self._old_head_opacity ) self._scale_mode_changed() self._display_message() self._update_distance_estimation() else: self._old_head_opacity = self._head_opacity - self._forward_widget_command( - 'head_opacity', 'set_value', 1.0 - ) + self._forward_widget_command("head_opacity", "set_value", 1.0) self._forward_widget_command(locked_widgets, "set_enabled", False) self._forward_widget_command(fits_widgets, "set_enabled", False) - self._display_message("Placing MRI fiducials - " - f"{self._current_fiducial.upper()}") + self._display_message( + "Placing MRI fiducials - " f"{self._current_fiducial.upper()}" + ) self._set_sensors_visibility(self._lock_fids) self._forward_widget_command("lock_fids", "set_value", self._lock_fids) - self._forward_widget_command(fid_widgets, "set_enabled", - not self._lock_fids) + self._forward_widget_command(fid_widgets, "set_enabled", not self._lock_fids) @observe("_current_fiducial") def _current_fiducial_changed(self, change=None): self._update_fiducials() self._follow_fiducial_view() if not self._lock_fids: - self._display_message("Placing MRI fiducials - " - f"{self._current_fiducial.upper()}") + self._display_message( + "Placing MRI fiducials - " f"{self._current_fiducial.upper()}" + ) @observe("_info_file") def _info_file_changed(self, change=None): if not self._info_file: return - elif self._info_file.endswith(('.fif', '.fif.gz')): + elif self._info_file.endswith((".fif", ".fif.gz")): fid, tree, _ = fiff_open(self._info_file) fid.close() if len(dir_tree_find(tree, FIFF.FIFFB_MEAS_INFO)) > 0: self._info = read_info(self._info_file, verbose=False) elif len(dir_tree_find(tree, 
FIFF.FIFFB_ISOTRAK)) > 0: self._info = _empty_info(1) - self._info['dig'] = read_dig_fif(fname=self._info_file).dig + self._info["dig"] = read_dig_fif(fname=self._info_file).dig self._info._unlocked = False else: self._info = read_raw(self._info_file).info @@ -689,10 +752,12 @@ def _scale_mode_changed(self, change=None): mode = None if self._scale_mode == "None" else self._scale_mode self.coreg.set_scale_mode(mode) if self._lock_fids: - self._forward_widget_command(locked_widgets, "set_enabled", - mode is not None) - self._forward_widget_command("fits_fiducials", "set_enabled", - mode not in (None, "3-axis")) + self._forward_widget_command( + locked_widgets, "set_enabled", mode is not None + ) + self._forward_widget_command( + "fits_fiducials", "set_enabled", mode not in (None, "3-axis") + ) if self._scale_mode == "uniform": self._forward_widget_command(["sY", "sZ"], "set_enabled", False) @@ -712,13 +777,15 @@ def _run_worker(self, queue, jobs): def _configure_dialogs(self): from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING + for name, buttons in zip( - ["overwrite_subject", "overwrite_subject_exit"], - [["Yes", "No"], ["Yes", "Discard", "Cancel"]]): + ["overwrite_subject", "overwrite_subject_exit"], + [["Yes", "No"], ["Yes", "Discard", "Cancel"]], + ): self._widgets[name] = self._renderer._dialog_create( title="CoregistrationUI", text="The name of the output subject used to " - "save the scaled anatomy already exists.", + "save the scaled anatomy already exists.", info_text="Do you want to overwrite?", callback=self._overwrite_subject_callback, buttons=buttons, @@ -731,11 +798,13 @@ def _configure_worker(self): "_parameter_queue": dict(set_parameter=self._set_parameter), } for queue_name, jobs in work_plan.items(): - t = threading.Thread(target=partial( - self._run_worker, - queue=getattr(self, queue_name), - jobs=jobs, - )) + t = threading.Thread( + target=partial( + self._run_worker, + queue=getattr(self, queue_name), + jobs=jobs, + ) + ) t.daemon = True t.start() @@ -744,14 +813,15 @@ def _configure_picking(self): self._on_mouse_move, self._on_button_press, self._on_button_release, - self._on_pick + self._on_pick, ) def _configure_legend(self): - colors = \ - [np.array(DEFAULTS['coreg'][f"{fid.lower()}_color"]).astype(float) - for fid in self._defaults['fiducials']] - labels = list(zip(self._defaults['fiducials'], colors)) + colors = [ + np.array(DEFAULTS["coreg"][f"{fid.lower()}_color"]).astype(float) + for fid in self._defaults["fiducials"] + ] + labels = list(zip(self._defaults["fiducials"], colors)) mri_fids_legend_actor = self._renderer.legend(labels=labels) self._update_actor("mri_fids_legend", mri_fids_legend_actor) @@ -772,16 +842,16 @@ def _redraw(self, *, verbose=None): # We need at least "head" before "hsp", because the grow_hair param # for head sets the rr that are used for inside/outside hsp redraws_ordered = sorted( - self._redraws_pending, - key=lambda key: list(draw_map).index(key)) - logger.debug(f'Redrawing {redraws_ordered}') + self._redraws_pending, key=lambda key: list(draw_map).index(key) + ) + logger.debug(f"Redrawing {redraws_ordered}") for ki, key in enumerate(redraws_ordered): - logger.debug(f'{ki}. Drawing {repr(key)}') + logger.debug(f"{ki}. 
Drawing {repr(key)}") draw_map[key]() self._redraws_pending.clear() self._renderer._update() # necessary for MacOS - if platform.system() == 'Darwin': + if platform.system() == "Darwin": self._renderer._process_events() def _on_mouse_move(self, vtk_picker, event): @@ -813,8 +883,10 @@ def _on_pick(self, vtk_picker, event): return pos = np.array(vtk_picker.GetPickPosition()) vtk_cell = mesh.GetCell(cell_id) - cell = [vtk_cell.GetPointId(point_id) for point_id - in range(vtk_cell.GetNumberOfPoints())] + cell = [ + vtk_cell.GetPointId(point_id) + for point_id in range(vtk_cell.GetNumberOfPoints()) + ] vertices = mesh.points[cell] idx = np.argmin(abs(vertices - pos), axis=0) vertex_id = cell[idx[0]] @@ -828,14 +900,16 @@ def _on_pick(self, vtk_picker, event): self._update_plot("mri_fids") def _reset_fitting_parameters(self): - self._forward_widget_command("icp_n_iterations", "set_value", - self._defaults["icp_n_iterations"]) - self._forward_widget_command("icp_fid_match", "set_value", - self._defaults["icp_fid_match"]) - weights_widgets = [f"{w}_weight" - for w in self._defaults["weights"].keys()] - self._forward_widget_command(weights_widgets, "set_value", - list(self._defaults["weights"].values())) + self._forward_widget_command( + "icp_n_iterations", "set_value", self._defaults["icp_n_iterations"] + ) + self._forward_widget_command( + "icp_fid_match", "set_value", self._defaults["icp_fid_match"] + ) + weights_widgets = [f"{w}_weight" for w in self._defaults["weights"].keys()] + self._forward_widget_command( + weights_widgets, "set_value", list(self._defaults["weights"].values()) + ) def _reset_fiducials(self): self._set_current_fiducial(self._defaults["fiducial"]) @@ -843,21 +917,22 @@ def _reset_fiducials(self): def _omit_hsp(self): self.coreg.omit_head_shape_points(self._omit_hsp_distance / 1e3) n_omitted = np.sum(~self.coreg._extra_points_filter) - n_remaining = len(self.coreg._dig_dict['hsp']) - n_omitted + n_remaining = len(self.coreg._dig_dict["hsp"]) - n_omitted self._update_plot("hsp") self._update_distance_estimation() self._display_message( - f"{n_omitted} head shape points omitted, " - f"{n_remaining} remaining.") + f"{n_omitted} head shape points omitted, " f"{n_remaining} remaining." + ) def _reset_omit_hsp_filter(self): self.coreg._extra_points_filter = None self.coreg._update_params(force_update=True) self._update_plot("hsp") self._update_distance_estimation() - n_total = len(self.coreg._dig_dict['hsp']) + n_total = len(self.coreg._dig_dict["hsp"]) self._display_message( - f"No head shape point is omitted, the total is {n_total}.") + f"No head shape point is omitted, the total is {n_total}." 
+ ) @verbose def _update_plot(self, changes="all", verbose=None): @@ -866,9 +941,8 @@ def _update_plot(self, changes="all", verbose=None): try: fun_name = inspect.currentframe().f_back.f_back.f_code.co_name except Exception: # just in case one of these attrs is missing - fun_name = 'unknown' - logger.debug( - f'Updating plots based on {fun_name}: {repr(changes)}') + fun_name = "unknown" + logger.debug(f"Updating plots based on {fun_name}: {repr(changes)}") if self._plot_locked: return if self._info is None: @@ -876,15 +950,20 @@ def _update_plot(self, changes="all", verbose=None): self._to_cf_t = dict(mri=dict(trans=np.eye(4)), head=None) else: self._to_cf_t = _get_transforms_to_coord_frame( - self._info, self.coreg.trans, coord_frame=self._coord_frame) + self._info, self.coreg.trans, coord_frame=self._coord_frame + ) all_keys = ( - 'head', 'mri_fids', # MRI first - 'hsp', 'hpi', 'eeg', 'head_fids', # then dig - 'helmet', - ) - if changes == 'all': + "head", + "mri_fids", # MRI first + "hsp", + "hpi", + "eeg", + "head_fids", # then dig + "helmet", + ) + if changes == "all": changes = list(all_keys) - elif changes == 'sensors': + elif changes == "sensors": changes = all_keys[2:] # omit MRI ones elif isinstance(changes, str): changes = [changes] @@ -894,7 +973,7 @@ def _update_plot(self, changes="all", verbose=None): # it would reduce "jerkiness" of the updates, but this should at least # work okay bad = changes.difference(set(all_keys)) - assert len(bad) == 0, f'Unknown changes: {bad}' + assert len(bad) == 0, f"Unknown changes: {bad}" self._redraws_pending.update(changes) if self._immediate_redraw: self._redraw() @@ -913,15 +992,24 @@ def _lock(self, plot=False, params=False, scale_mode=False, fitting=False): self.coreg._scale_mode = None if fitting: widgets = [ - "sX", "sY", "sZ", - "tX", "tY", "tZ", - "rX", "rY", "rZ", - "fit_icp", "fit_fiducials", "fits_icp", "fits_fiducials" + "sX", + "sY", + "sZ", + "tX", + "tY", + "tZ", + "rX", + "rY", + "rZ", + "fit_icp", + "fit_fiducials", + "fits_icp", + "fits_fiducials", ] states = [ self._forward_widget_command( - w, "is_enabled", None, - input_value=False, output_value=True) + w, "is_enabled", None, input_value=False, output_value=True + ) for w in widgets ] self._forward_widget_command(widgets, "set_enabled", False) @@ -939,21 +1027,19 @@ def _lock(self, plot=False, params=False, scale_mode=False, fitting=False): self._forward_widget_command(w, "set_enabled", states[idx]) def _display_message(self, msg=""): - self._forward_widget_command('status_message', 'set_value', msg) + self._forward_widget_command("status_message", "set_value", msg) + self._forward_widget_command("status_message", "show", None, input_value=False) self._forward_widget_command( - 'status_message', 'show', None, input_value=False - ) - self._forward_widget_command( - 'status_message', 'update', None, input_value=False + "status_message", "update", None, input_value=False ) if msg: logger.info(msg) def _follow_fiducial_view(self): fid = self._current_fiducial.lower() - view = dict(lpa='left', rpa='right', nasion='front') - kwargs = dict(front=(90., 90.), left=(180, 90), right=(0., 90)) - kwargs = dict(zip(('azimuth', 'elevation'), kwargs[view[fid]])) + view = dict(lpa="left", rpa="right", nasion="front") + kwargs = dict(front=(90.0, 90.0), left=(180, 90), right=(0.0, 90)) + kwargs = dict(zip(("azimuth", "elevation"), kwargs[view[fid]])) if not self._lock_fids: self._renderer.set_camera(distance=None, **kwargs) @@ -963,35 +1049,39 @@ def _update_fiducials(self): return idx = 
_map_fid_name_to_idx(name=fid) - val = self.coreg.fiducials.dig[idx]['r'] * 1e3 + val = self.coreg.fiducials.dig[idx]["r"] * 1e3 with self._lock(plot=True): - self._forward_widget_command( - ["fid_X", "fid_Y", "fid_Z"], "set_value", val) + self._forward_widget_command(["fid_X", "fid_Y", "fid_Z"], "set_value", val) def _update_distance_estimation(self): - value = self.coreg._get_fiducials_distance_str() + '\n' + \ - self.coreg._get_point_distance_str() + value = ( + self.coreg._get_fiducials_distance_str() + + "\n" + + self.coreg._get_point_distance_str() + ) dists = self.coreg.compute_dig_mri_distances() * 1e3 if self._hsp_weight > 0: - value += "\nHSP <-> MRI (mean/min/max): "\ - f"{np.mean(dists):.2f} "\ + value += ( + "\nHSP <-> MRI (mean/min/max): " + f"{np.mean(dists):.2f} " f"/ {np.min(dists):.2f} / {np.max(dists):.2f} mm" + ) self._forward_widget_command("fit_label", "set_value", value) def _update_parameters(self): with self._lock(plot=True, params=True): # rotation deg = np.rad2deg(self.coreg._rotation) - logger.debug(f' Rotation: {deg}') + logger.debug(f" Rotation: {deg}") self._forward_widget_command(["rX", "rY", "rZ"], "set_value", deg) # translation mm = self.coreg._translation * 1e3 - logger.debug(f' Translation: {mm}') + logger.debug(f" Translation: {mm}") self._forward_widget_command(["tX", "tY", "tZ"], "set_value", mm) # scale sc = self.coreg._scale * 1e2 - logger.debug(f' Scale: {sc}') + logger.debug(f" Scale: {sc}") self._forward_widget_command(["sX", "sY", "sZ"], "set_value", sc) def _reset(self, keep_trans=False): @@ -1011,8 +1101,9 @@ def _reset(self, keep_trans=False): self._update_parameters() self._update_distance_estimation() - def _forward_widget_command(self, names, command, value, - input_value=True, output_value=False): + def _forward_widget_command( + self, names, command, value, input_value=True, output_value=False + ): """Invoke a method of one or more widgets if the widgets exist. Parameters @@ -1035,11 +1126,7 @@ def _forward_widget_command(self, names, command, value, ``None`` if ``output_value`` is ``False``, and the return value of ``command`` otherwise. 
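

         Notes
         -----
         For example, ``self._forward_widget_command("fid_X", "set_value",
         10.0)`` looks up ``self._widgets["fid_X"]`` and, when that widget
         exists, calls its ``set_value(10.0)`` method; names missing from
         ``self._widgets`` are silently skipped (``10.0`` is an illustrative
         value only).

         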
""" - _validate_type( - item=names, - types=(str, list), - item_name='names' - ) + _validate_type(item=names, types=(str, list), item_name="names") if isinstance(names, str): names = [names] @@ -1058,8 +1145,7 @@ def _forward_widget_command(self, names, command, value, return ret def _set_sensors_visibility(self, state): - sensors = ["head_fiducials", "hpi_coils", "head_shape_points", - "eeg_channels"] + sensors = ["head_fiducials", "hpi_coils", "head_shape_points", "eeg_channels"] for sensor in sensors: if sensor in self._actors and self._actors[sensor] is not None: actors = self._actors[sensor] @@ -1070,14 +1156,18 @@ def _set_sensors_visibility(self, state): def _update_actor(self, actor_name, actor): # XXX: internal plotter/renderer should not be exposed - self._renderer.plotter.remove_actor(self._actors.get(actor_name), - render=False) + self._renderer.plotter.remove_actor(self._actors.get(actor_name), render=False) self._actors[actor_name] = actor def _add_mri_fiducials(self): mri_fids_actors = _plot_mri_fiducials( - self._renderer, self.coreg._fid_points, self._subjects_dir, - self._subject, self._to_cf_t, self._fid_colors) + self._renderer, + self.coreg._fid_points, + self._subjects_dir, + self._subject, + self._to_cf_t, + self._fid_colors, + ) # disable picking on the markers for actor in mri_fids_actors: actor.SetPickable(False) @@ -1085,19 +1175,24 @@ def _add_mri_fiducials(self): def _add_head_fiducials(self): head_fids_actors = _plot_head_fiducials( - self._renderer, self._info, self._to_cf_t, self._fid_colors) + self._renderer, self._info, self._to_cf_t, self._fid_colors + ) self._update_actor("head_fiducials", head_fids_actors) def _add_hpi_coils(self): if self._hpi_coils: hpi_actors = _plot_hpi_coils( - self._renderer, self._info, self._to_cf_t, + self._renderer, + self._info, + self._to_cf_t, opacity=self._defaults["sensor_opacity"], scale=DEFAULTS["coreg"]["extra_scale"], orient_glyphs=self._orient_glyphs, scale_by_distance=self._scale_by_distance, - surf=self._head_geo, check_inside=self._check_inside, - nearest=self._nearest) + surf=self._head_geo, + check_inside=self._check_inside, + nearest=self._nearest, + ) else: hpi_actors = None self._update_actor("hpi_coils", hpi_actors) @@ -1105,13 +1200,18 @@ def _add_hpi_coils(self): def _add_head_shape_points(self): if self._head_shape_points: hsp_actors = _plot_head_shape_points( - self._renderer, self._info, self._to_cf_t, + self._renderer, + self._info, + self._to_cf_t, opacity=self._defaults["sensor_opacity"], orient_glyphs=self._orient_glyphs, scale_by_distance=self._scale_by_distance, - mark_inside=self._mark_inside, surf=self._head_geo, + mark_inside=self._mark_inside, + surf=self._head_geo, mask=self.coreg._extra_points_filter, - check_inside=self._check_inside, nearest=self._nearest) + check_inside=self._check_inside, + nearest=self._nearest, + ) else: hsp_actors = None self._update_actor("head_shape_points", hsp_actors) @@ -1122,14 +1222,23 @@ def _add_eeg_channels(self): picks = pick_types(self._info, eeg=(len(eeg) > 0), fnirs=True) if len(picks) > 0: actors = _plot_sensors( - self._renderer, self._info, self._to_cf_t, picks, - meg=False, eeg=eeg, fnirs=["sources", "detectors"], - warn_meg=False, head_surf=self._head_geo, units='m', + self._renderer, + self._info, + self._to_cf_t, + picks, + meg=False, + eeg=eeg, + fnirs=["sources", "detectors"], + warn_meg=False, + head_surf=self._head_geo, + units="m", sensor_opacity=self._defaults["sensor_opacity"], orient_glyphs=self._orient_glyphs, 
scale_by_distance=self._scale_by_distance, - surf=self._head_geo, check_inside=self._check_inside, - nearest=self._nearest) + surf=self._head_geo, + check_inside=self._check_inside, + nearest=self._nearest, + ) sens_actors = actors["eeg"] sens_actors.extend(actors["fnirs"]) else: @@ -1141,22 +1250,34 @@ def _add_eeg_channels(self): def _add_head_surface(self): bem = None if self._head_resolution: - surface = 'head-dense' - key = 'high' + surface = "head-dense" + key = "high" else: - surface = 'head' - key = 'low' + surface = "head" + key = "low" try: head_actor, head_surf, _ = _plot_head_surface( - self._renderer, surface, self._subject, - self._subjects_dir, bem, self._coord_frame, self._to_cf_t, - alpha=self._head_opacity) + self._renderer, + surface, + self._subject, + self._subjects_dir, + bem, + self._coord_frame, + self._to_cf_t, + alpha=self._head_opacity, + ) except OSError: head_actor, head_surf, _ = _plot_head_surface( - self._renderer, "head", self._subject, self._subjects_dir, - bem, self._coord_frame, self._to_cf_t, - alpha=self._head_opacity) - key = 'low' + self._renderer, + "head", + self._subject, + self._subjects_dir, + bem, + self._coord_frame, + self._to_cf_t, + alpha=self._head_opacity, + ) + key = "low" self._update_actor("head", head_actor) # mark head surface mesh to restrict picking head_surf._picking_target = True @@ -1170,16 +1291,16 @@ def _add_head_surface(self): nn = self._surfaces["head"].point_normals assert nn.shape == (len(rr), 3), nn.shape self._head_geo = dict(rr=rr, tris=tris, nn=nn) - self._check_inside = _CheckInside(head_surf, mode='pyvista') + self._check_inside = _CheckInside(head_surf, mode="pyvista") self._nearest = _DistanceQuery(rr) def _add_helmet(self): if self._helmet: - logger.debug('Drawing helmet') - head_mri_t = _get_trans(self.coreg.trans, 'head', 'mri')[0] + logger.debug("Drawing helmet") + head_mri_t = _get_trans(self.coreg.trans, "head", "mri")[0] helmet_actor, _, _ = _plot_helmet( - self._renderer, self._info, self._to_cf_t, head_mri_t, - self._coord_frame) + self._renderer, self._info, self._to_cf_t, head_mri_t, self._coord_frame + ) else: helmet_actor = None self._update_actor("helmet", helmet_actor) @@ -1199,7 +1320,8 @@ def _fits_fiducials(self): ) end = time.time() self._display_message( - f"Fitting fiducials finished in {end - start:.2f} seconds.") + f"Fitting fiducials finished in {end - start:.2f} seconds." + ) self._update_plot("sensors") self._update_parameters() self._update_distance_estimation() @@ -1214,13 +1336,12 @@ def _fits_icp(self): def _fit_icp_real(self, *, update_head): with self._lock(params=True, fitting=True): self._current_icp_iterations = 0 - updates = ['hsp', 'hpi', 'eeg', 'head_fids', 'helmet'] + updates = ["hsp", "hpi", "eeg", "head_fids", "helmet"] if update_head: - updates.insert(0, 'head') + updates.insert(0, "head") def callback(iteration, n_iterations): - self._display_message( - f"Fitting ICP - iteration {iteration + 1}") + self._display_message(f"Fitting ICP - iteration {iteration + 1}") self._update_plot(updates) self._current_icp_iterations += 1 self._update_distance_estimation() @@ -1240,11 +1361,13 @@ def callback(iteration, n_iterations): self._display_message() self._display_message( f"Fitting ICP finished in {end - start:.2f} seconds and " - f"{self._current_icp_iterations} iterations.") + f"{self._current_icp_iterations} iterations." 
+ ) del self._current_icp_iterations def _task_save_subject(self): from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING + if MNE_3D_BACKEND_TESTING: self._save_subject() else: @@ -1252,12 +1375,21 @@ def _task_save_subject(self): def _task_set_parameter(self, value, mode_name, coord): from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING + if MNE_3D_BACKEND_TESTING: self._set_parameter(value, mode_name, coord, self._plot_locked) else: - self._parameter_queue.put(_WorkerData("set_parameter", dict( - value=value, mode_name=mode_name, coord=coord, - plot_locked=self._plot_locked))) + self._parameter_queue.put( + _WorkerData( + "set_parameter", + dict( + value=value, + mode_name=mode_name, + coord=coord, + plot_locked=self._plot_locked, + ), + ) + ) def _overwrite_subject_callback(self, button_name): if button_name == "Yes": @@ -1270,9 +1402,10 @@ def _overwrite_subject_callback(self, button_name): def _check_subject_exists(self): if not self._subject_to: return False - subject_dirname = os.path.join('{subjects_dir}', '{subject}') - dest = subject_dirname.format(subject=self._subject_to, - subjects_dir=self._subjects_dir) + subject_dirname = os.path.join("{subjects_dir}", "{subject}") + dest = subject_dirname.format( + subject=self._subject_to, subjects_dir=self._subjects_dir + ) return os.path.exists(dest) def _save_subject(self, exit_mode=False): @@ -1286,19 +1419,19 @@ def _save_subject_callback(self, overwrite=False): self._display_message(f"Saving {self._subject_to}...") default_cursor = self._renderer._window_get_cursor() self._renderer._window_set_cursor( - self._renderer._window_new_cursor("WaitCursor")) + self._renderer._window_new_cursor("WaitCursor") + ) # prepare bem bem_names = [] if self._scale_mode != "None": - can_prepare_bem = _mri_subject_has_bem( - self._subject, self._subjects_dir) + can_prepare_bem = _mri_subject_has_bem(self._subject, self._subjects_dir) else: can_prepare_bem = False if can_prepare_bem: - pattern = bem_fname.format(subjects_dir=self._subjects_dir, - subject=self._subject, - name='(.+-bem)') + pattern = bem_fname.format( + subjects_dir=self._subjects_dir, subject=self._subject, name="(.+-bem)" + ) bem_dir, pattern = os.path.split(pattern) for filename in os.listdir(bem_dir): match = re.match(pattern, filename) @@ -1309,10 +1442,15 @@ def _save_subject_callback(self, overwrite=False): try: self._display_message(f"Scaling {self._subject_to}...") scale_mri( - subject_from=self._subject, subject_to=self._subject_to, - scale=self.coreg._scale, overwrite=overwrite, - subjects_dir=self._subjects_dir, skip_fiducials=True, - labels=True, annot=True, on_defects='ignore' + subject_from=self._subject, + subject_to=self._subject_to, + scale=self.coreg._scale, + overwrite=overwrite, + subjects_dir=self._subjects_dir, + skip_fiducials=True, + labels=True, + annot=True, + on_defects="ignore", ) except Exception: logger.error(f"Error scaling {self._subject_to}") @@ -1324,16 +1462,17 @@ def _save_subject_callback(self, overwrite=False): for bem_name in bem_names: try: self._display_message(f"Computing {bem_name} solution...") - bem_file = bem_fname.format(subjects_dir=self._subjects_dir, - subject=self._subject_to, - name=bem_name) + bem_file = bem_fname.format( + subjects_dir=self._subjects_dir, + subject=self._subject_to, + name=bem_name, + ) bemsol = make_bem_solution(bem_file) - write_bem_solution(bem_file[:-4] + '-sol.fif', bemsol) + write_bem_solution(bem_file[:-4] + "-sol.fif", bemsol) except Exception: logger.error(f"Error computing {bem_name} solution") 
else: - self._display_message(f"Computing {bem_name} solution..." - " Done!") + self._display_message(f"Computing {bem_name} solution..." " Done!") self._display_message(f"Saving {self._subject_to}... Done!") self._renderer._window_set_cursor(default_cursor) self._mri_scale_modified = False @@ -1342,7 +1481,7 @@ def _save_mri_fiducials(self, fname): self._display_message(f"Saving {fname}...") dig_montage = self.coreg.fiducials write_fiducials( - fname=fname, pts=dig_montage.dig, coord_frame='mri', overwrite=True + fname=fname, pts=dig_montage.dig, coord_frame="mri", overwrite=True ) self._set_fiducials_file(fname) self._display_message(f"Saving {fname}... Done!") @@ -1350,13 +1489,13 @@ def _save_mri_fiducials(self, fname): def _save_trans(self, fname): write_trans(fname, self.coreg.trans, overwrite=True) - self._display_message( - f"{fname} transform file is saved.") + self._display_message(f"{fname} transform file is saved.") self._trans_modified = False def _load_trans(self, fname): - mri_head_t = _ensure_trans(read_trans(fname, return_all=True), - 'mri', 'head')['trans'] + mri_head_t = _ensure_trans(read_trans(fname, return_all=True), "mri", "head")[ + "trans" + ] rot_x, rot_y, rot_z = rotation_angles(mri_head_t) x, y, z = mri_head_t[:3, 3] self.coreg._update_params( @@ -1366,17 +1505,16 @@ def _load_trans(self, fname): self._update_parameters() self._update_distance_estimation() self._update_plot() - self._display_message( - f"{fname} transform file is loaded.") + self._display_message(f"{fname} transform file is loaded.") def _update_fiducials_label(self): if self._fiducials_file is None: text = ( - '
<p><strong>No custom MRI fiducials loaded!</strong></p>'
-                '<p>MRI fiducials could not be found in the standard '
-                'location. The displayed initial MRI fiducial locations '
-                '(diamonds) were derived from fsaverage. Place, lock, and '
-                'save fiducials to discard this message.</p>'
+                "<p><strong>No custom MRI fiducials loaded!</strong></p>"
+                "<p>MRI fiducials could not be found in the standard "
+                "location. The displayed initial MRI fiducial locations "
+                "(diamonds) were derived from fsaverage. Place, lock, and "
+                "save fiducials to discard this message.</p>"
             )
         else:
             assert self._fiducials_file == fid_fname.format(
@@ -1384,30 +1522,24 @@ def _update_fiducials_label(self):
             )
             assert self.coreg._fid_accurate is True
             text = (
-                f'<p><strong>MRI fiducials (diamonds) loaded from '
-                f'standard location:</strong></p>'
-                f'<p>{self._fiducials_file}</p>'
+                f"<p><strong>MRI fiducials (diamonds) loaded from "
+                f"standard location:</strong></p>"
+                f"<p>{self._fiducials_file}</p>
" ) - self._forward_widget_command( - 'mri_fiducials_label', 'set_value', text - ) + self._forward_widget_command("mri_fiducials_label", "set_value", text) def _configure_dock(self): - if self._renderer._kind == 'notebook': + if self._renderer._kind == "notebook": collapse = True # collapsible and collapsed else: collapse = None # not collapsible - self._renderer._dock_initialize( - name="Input", area="left", max_width="350px" - ) + self._renderer._dock_initialize(name="Input", area="left", max_width="350px") mri_subject_layout = self._renderer._dock_add_group_box( name="MRI Subject", collapse=collapse, ) - subjects_dir_layout = self._renderer._dock_add_layout( - vertical=False - ) + subjects_dir_layout = self._renderer._dock_add_layout(vertical=False) self._widgets["subjects_dir_field"] = self._renderer._dock_add_text( name="subjects_dir_field", value=self._subjects_dir, @@ -1422,7 +1554,7 @@ def _configure_dock(self): is_directory=True, icon=True, tooltip="Load the path to the directory containing the " - "FreeSurfer subjects", + "FreeSurfer subjects", layout=subjects_dir_layout, ) self._renderer._layout_add_widget( @@ -1444,38 +1576,33 @@ def _configure_dock(self): collapse=collapse, ) # Add MRI fiducials I/O widgets - self._widgets['mri_fiducials_label'] = self._renderer._dock_add_label( - value='', # Will be filled via _update_fiducials_label() + self._widgets["mri_fiducials_label"] = self._renderer._dock_add_label( + value="", # Will be filled via _update_fiducials_label() layout=mri_fiducials_layout, - selectable=True + selectable=True, ) # Reload & Save buttons go into their own layout widget - mri_fiducials_button_layout = self._renderer._dock_add_layout( - vertical=False - ) + mri_fiducials_button_layout = self._renderer._dock_add_layout(vertical=False) self._renderer._layout_add_widget( - layout=mri_fiducials_layout, - widget=mri_fiducials_button_layout + layout=mri_fiducials_layout, widget=mri_fiducials_button_layout ) self._widgets["reload_mri_fids"] = self._renderer._dock_add_button( - name='Reload MRI Fid.', + name="Reload MRI Fid.", callback=lambda: self._set_fiducials_file(self._fiducials_file), tooltip="Reload MRI fiducials from the standard location", layout=mri_fiducials_button_layout, ) # Disable reload button until we've actually loaded a fiducial file # (happens in _set_fiducials_file method) - self._forward_widget_command('reload_mri_fids', 'set_enabled', False) + self._forward_widget_command("reload_mri_fids", "set_enabled", False) self._widgets["save_mri_fids"] = self._renderer._dock_add_button( name="Save MRI Fid.", callback=lambda: self._save_mri_fiducials( - fid_fname.format( - subjects_dir=self._subjects_dir, subject=self._subject - ) + fid_fname.format(subjects_dir=self._subjects_dir, subject=self._subject) ), tooltip="Save MRI fiducials to the standard location. 
Fiducials " - "must be locked first!", + "must be locked first!", layout=mri_fiducials_button_layout, ) self._widgets["lock_fids"] = self._renderer._dock_add_check_box( @@ -1497,7 +1624,7 @@ def _configure_dock(self): name = f"fid_{coord}" self._widgets[name] = self._renderer._dock_add_spin_box( name=coord, - value=0., + value=0.0, rng=[-1e3, 1e3], callback=partial( self._set_fiducial, @@ -1509,16 +1636,13 @@ def _configure_dock(self): tooltip=f"Set the {coord} fiducial coordinate", layout=fiducial_coords_layout, ) - self._renderer._layout_add_widget( - mri_fiducials_layout, fiducial_coords_layout) + self._renderer._layout_add_widget(mri_fiducials_layout, fiducial_coords_layout) dig_source_layout = self._renderer._dock_add_group_box( name="Info source with digitization", collapse=collapse, ) - info_file_layout = self._renderer._dock_add_layout( - vertical=False - ) + info_file_layout = self._renderer._dock_add_layout(vertical=False) self._widgets["info_file_field"] = self._renderer._dock_add_text( name="info_file_field", value=self._info_file, @@ -1531,8 +1655,7 @@ def _configure_dock(self): desc="Load", func=self._set_info_file, icon=True, - tooltip="Load the FIFF file with digitization data for " - "coregistration", + tooltip="Load the FIFF file with digitization data for " "coregistration", layout=info_file_layout, ) self._renderer._layout_add_widget( @@ -1561,7 +1684,7 @@ def _configure_dock(self): name="Omit", callback=self._omit_hsp, tooltip="Exclude the head shape points that are far away from " - "the MRI head", + "the MRI head", layout=omit_hsp_layout_2, ) self._widgets["reset_omit"] = self._renderer._dock_add_button( @@ -1629,7 +1752,7 @@ def _configure_dock(self): self._widgets[name] = self._renderer._dock_add_spin_box( name=name, value=attr[coords.index(coord)] * 1e2, - rng=[1., 10000.], # percent + rng=[1.0, 10000.0], # percent callback=partial( self._set_parameter, mode_name="scale", @@ -1647,18 +1770,17 @@ def _configure_dock(self): name="Fit fiducials with scaling", callback=self._fits_fiducials, tooltip="Find MRI scaling, rotation, and translation to fit all " - "3 fiducials", + "3 fiducials", layout=fit_scale_layout, ) self._widgets["fits_icp"] = self._renderer._dock_add_button( name="Fit ICP with scaling", callback=self._fits_icp, tooltip="Find MRI scaling, rotation, and translation to match the " - "head shape points", + "head shape points", layout=fit_scale_layout, ) - self._renderer._layout_add_widget( - scale_params_layout, fit_scale_layout) + self._renderer._layout_add_widget(scale_params_layout, fit_scale_layout) subject_to_layout = self._renderer._dock_add_layout(vertical=False) self._widgets["subject_to"] = self._renderer._dock_add_text( name="subject-to", @@ -1673,8 +1795,7 @@ def _configure_dock(self): tooltip="Save scaled anatomy", layout=subject_to_layout, ) - self._renderer._layout_add_widget( - mri_scaling_layout, subject_to_layout) + self._renderer._layout_add_widget(mri_scaling_layout, subject_to_layout) param_layout = self._renderer._dock_add_group_box( name="Translation (t) and Rotation (r)", collapse=collapse, @@ -1699,8 +1820,8 @@ def _configure_dock(self): double=True, step=1, tooltip=f"Set the {coord} {mode_name.lower()}" - f" parameter (in {unit})", - layout=coord_layout + f" parameter (in {unit})", + layout=coord_layout, ) self._renderer._layout_add_widget(param_layout, coord_layout) @@ -1714,8 +1835,7 @@ def _configure_dock(self): self._widgets["fit_icp"] = self._renderer._dock_add_button( name="Fit ICP", callback=self._fit_icp, - tooltip="Find 
rotation and translation to match the " - "head shape points", + tooltip="Find rotation and translation to match the " "head shape points", layout=fit_layout, ) self._renderer._layout_add_widget(param_layout, fit_layout) @@ -1731,7 +1851,7 @@ def _configure_dock(self): func=self._save_trans, tooltip="Save the transform file to disk", layout=save_trans_layout, - filter='Head->MRI transformation (*-trans.fif *_trans.fif)', + filter="Head->MRI transformation (*-trans.fif *_trans.fif)", initial_directory=str(Path(self._info_file).parent), ) self._widgets["load_trans"] = self._renderer._dock_add_file_button( @@ -1740,7 +1860,7 @@ def _configure_dock(self): func=self._load_trans, tooltip="Load the transform file from disk", layout=save_trans_layout, - filter='Head->MRI transformation (*-trans.fif *_trans.fif)', + filter="Head->MRI transformation (*-trans.fif *_trans.fif)", initial_directory=str(Path(self._info_file).parent), ) self._renderer._layout_add_widget(trans_layout, save_trans_layout) @@ -1782,15 +1902,14 @@ def _configure_dock(self): name="Weights", layout=fitting_options_layout, ) - for point, fid in zip(("HSP", "EEG", "HPI"), - self._defaults["fiducials"]): + for point, fid in zip(("HSP", "EEG", "HPI"), self._defaults["fiducials"]): weight_layout = self._renderer._dock_add_layout(vertical=False) point_lower = point.lower() name = f"{point_lower}_weight" self._widgets[name] = self._renderer._dock_add_spin_box( name=point, value=getattr(self, f"_{point_lower}_weight"), - rng=[0., 100.], + rng=[0.0, 100.0], callback=partial(self._set_point_weight, point=point_lower), compact=True, double=True, @@ -1803,7 +1922,7 @@ def _configure_dock(self): self._widgets[name] = self._renderer._dock_add_spin_box( name=fid, value=getattr(self, f"_{fid_lower}_weight"), - rng=[0., 100.], + rng=[0.0, 100.0], callback=partial(self._set_point_weight, point=fid_lower), compact=True, double=True, @@ -1811,23 +1930,21 @@ def _configure_dock(self): layout=weight_layout, ) self._renderer._layout_add_widget(weights_layout, weight_layout) - self._widgets['reset_fitting_options'] = ( - self._renderer._dock_add_button( - name="Reset Fitting Options", - callback=self._reset_fitting_parameters, - tooltip="Reset all the fitting parameters to default value", - layout=fitting_options_layout, - ) + self._widgets["reset_fitting_options"] = self._renderer._dock_add_button( + name="Reset Fitting Options", + callback=self._reset_fitting_parameters, + tooltip="Reset all the fitting parameters to default value", + layout=fitting_options_layout, ) self._renderer._dock_add_stretch() def _configure_status_bar(self): self._renderer._status_bar_initialize() - self._widgets['status_message'] = self._renderer._status_bar_add_label( + self._widgets["status_message"] = self._renderer._status_bar_add_label( "", stretch=1 ) self._forward_widget_command( - 'status_message', 'hide', value=None, input_value=False + "status_message", "hide", value=None, input_value=False ) def _clean(self): @@ -1851,17 +1968,16 @@ def close(self): def _close_dialog_callback(self, button_name): from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING + self._accept_close_event = True if button_name == "Save": if self._trans_modified: - self._forward_widget_command( - "save_trans", "set_value", None) + self._forward_widget_command("save_trans", "set_value", None) # cancel means _save_trans is not called if self._trans_modified: self._accept_close_event = False if self._mri_fids_modified: - self._forward_widget_command( - "save_mri_fids", "set_value", None) + 
self._forward_widget_command("save_mri_fids", "set_value", None)
         if self._mri_scale_modified:
             if self._subject_to:
                 self._save_subject(exit_mode=True)
@@ -1869,7 +1985,7 @@ def _close_dialog_callback(self, button_name):
             dialog = self._renderer._dialog_create(
                 title="CoregistrationUI",
                 text="The name of the output subject used to "
-                     "save the scaled anatomy is not set.",
+                "save the scaled anatomy is not set.",
                 info_text="Please set a subject name",
                 callback=lambda x: None,
                 buttons=["Ok"],
@@ -1883,9 +1999,9 @@ def _close_dialog_callback(self, button_name):
         assert button_name == "Discard"
 
     def _close_callback(self):
-        if self._trans_modified or self._mri_fids_modified or \
-                self._mri_scale_modified:
+        if self._trans_modified or self._mri_fids_modified or self._mri_scale_modified:
             from ..viz.backends.renderer import MNE_3D_BACKEND_TESTING
+
             # prepare the dialog's text
             text = "The following is/are not saved:"
             text += "<ul>
    " diff --git a/mne/gui/_ieeg_locate.py b/mne/gui/_ieeg_locate.py index a23590d7317..fcc9788e6ce 100644 --- a/mne/gui/_ieeg_locate.py +++ b/mne/gui/_ieeg_locate.py @@ -11,10 +11,18 @@ from qtpy import QtCore, QtGui from qtpy.QtCore import Slot, Signal -from qtpy.QtWidgets import (QVBoxLayout, QHBoxLayout, QLabel, - QMessageBox, QWidget, QAbstractItemView, - QListView, QSlider, QPushButton, - QComboBox) +from qtpy.QtWidgets import ( + QVBoxLayout, + QHBoxLayout, + QLabel, + QMessageBox, + QWidget, + QAbstractItemView, + QListView, + QSlider, + QPushButton, + QComboBox, +) from matplotlib.colors import LinearSegmentedColormap @@ -29,20 +37,34 @@ _RADIUS_SCALAR = 0.4 _TUBE_SCALAR = 0.1 _BOLT_SCALAR = 30 # mm -_CH_MENU_WIDTH = 30 if platform.system() == 'Windows' else 10 +_CH_MENU_WIDTH = 30 if platform.system() == "Windows" else 10 # 20 colors generated to be evenly spaced in a cube, worked better than # matplotlib color cycle -_UNIQUE_COLORS = [(0.1, 0.42, 0.43), (0.9, 0.34, 0.62), (0.47, 0.51, 0.3), - (0.47, 0.55, 0.99), (0.79, 0.68, 0.06), (0.34, 0.74, 0.05), - (0.58, 0.87, 0.13), (0.86, 0.98, 0.4), (0.92, 0.91, 0.66), - (0.77, 0.38, 0.34), (0.9, 0.37, 0.1), (0.2, 0.62, 0.9), - (0.22, 0.65, 0.64), (0.14, 0.94, 0.8), (0.34, 0.31, 0.68), - (0.59, 0.28, 0.74), (0.46, 0.19, 0.94), (0.37, 0.93, 0.7), - (0.56, 0.86, 0.55), (0.67, 0.69, 0.44)] +_UNIQUE_COLORS = [ + (0.1, 0.42, 0.43), + (0.9, 0.34, 0.62), + (0.47, 0.51, 0.3), + (0.47, 0.55, 0.99), + (0.79, 0.68, 0.06), + (0.34, 0.74, 0.05), + (0.58, 0.87, 0.13), + (0.86, 0.98, 0.4), + (0.92, 0.91, 0.66), + (0.77, 0.38, 0.34), + (0.9, 0.37, 0.1), + (0.2, 0.62, 0.9), + (0.22, 0.65, 0.64), + (0.14, 0.94, 0.8), + (0.34, 0.31, 0.68), + (0.59, 0.28, 0.74), + (0.46, 0.19, 0.94), + (0.37, 0.93, 0.7), + (0.56, 0.86, 0.55), + (0.67, 0.69, 0.44), +] _N_COLORS = len(_UNIQUE_COLORS) -_CMAP = LinearSegmentedColormap.from_list( - 'ch_colors', _UNIQUE_COLORS, N=_N_COLORS) +_CMAP = LinearSegmentedColormap.from_list("ch_colors", _UNIQUE_COLORS, N=_N_COLORS) class ComboBox(QComboBox): @@ -59,8 +81,17 @@ def showPopup(self): class IntracranialElectrodeLocator(SliceBrowser): """Locate electrode contacts using a coregistered MRI and CT.""" - def __init__(self, info, trans, base_image, subject=None, - subjects_dir=None, groups=None, show=True, verbose=None): + def __init__( + self, + info, + trans, + base_image, + subject=None, + subjects_dir=None, + groups=None, + show=True, + verbose=None, + ): """GUI for locating intracranial electrodes. .. note:: Images will be displayed using orientation information @@ -68,7 +99,7 @@ def __init__(self, info, trans, base_image, subject=None, dimensions [256, 256, 256] for display. 
""" if not info.ch_names: - raise ValueError('No channels found in `info` to locate') + raise ValueError("No channels found in `info` to locate") # store info for modification self._info = info @@ -82,29 +113,33 @@ def __init__(self, info, trans, base_image, subject=None, # initialize channel data self._ch_index = 0 # load data, apply trans - self._head_mri_t = _get_trans(trans, 'head', 'mri')[0] + self._head_mri_t = _get_trans(trans, "head", "mri")[0] self._mri_head_t = invert_transform(self._head_mri_t) # ensure channel positions in head montage = info.get_montage() - if montage and montage.get_positions()['coord_frame'] != 'head': - raise RuntimeError('Channel positions in the ``info`` object must ' - 'be in the "head" coordinate frame.') + if montage and montage.get_positions()["coord_frame"] != "head": + raise RuntimeError( + "Channel positions in the ``info`` object must " + 'be in the "head" coordinate frame.' + ) # load channels, convert from m to mm - self._chs = {name: apply_trans(self._head_mri_t, ch['loc'][:3]) * 1000 - for name, ch in zip(info.ch_names, info['chs'])} + self._chs = { + name: apply_trans(self._head_mri_t, ch["loc"][:3]) * 1000 + for name, ch in zip(info.ch_names, info["chs"]) + } self._ch_names = list(self._chs.keys()) self._group_channels(groups) # Initialize GUI super(IntracranialElectrodeLocator, self).__init__( - base_image=base_image, subject=subject, subjects_dir=subjects_dir) + base_image=base_image, subject=subject, subjects_dir=subjects_dir + ) # set current position as current contact location if exists if not np.isnan(self._chs[self._ch_names[self._ch_index]]).any(): - self._set_ras(self._chs[self._ch_names[self._ch_index]], - update_plots=False) + self._set_ras(self._chs[self._ch_names[self._ch_index]], update_plots=False) # add plots of contacts on top self._plot_ch_images() @@ -125,7 +160,7 @@ def __init__(self, info, trans, base_image, subject=None, def _configure_ui(self): # data is loaded for an abstract base image, associate with ct self._ct_data = self._base_data - self._images['ct'] = self._images['base'] + self._images["ct"] = self._images["base"] self._ct_maxima = None # don't compute until turned on toolbar = self._configure_toolbar() @@ -160,8 +195,7 @@ def _configure_channel_sidebar(self): self._color_list_item(name=name) ch_list.setModel(self._ch_list_model) ch_list.clicked.connect(self._go_to_ch) - ch_list.setCurrentIndex( - self._ch_list_model.index(self._ch_index, 0)) + ch_list.setCurrentIndex(self._ch_list_model.index(self._ch_index, 0)) ch_list.keyPressEvent = self.keyPressEvent return ch_list @@ -193,8 +227,11 @@ def color_ch_radius(ch_image, xf, yf, group, radius): dist = np.linalg.norm(xyz - self._current_slice) if proj or dist < self._radius: group = self._groups[name] - r = self._radius if proj else \ - self._radius - np.round(abs(dist)).astype(int) + r = ( + self._radius + if proj + else self._radius - np.round(abs(dist)).astype(int) + ) xf, yf = (xyz / vxyz)[list(self._xy_idx[axis])] ch_image = color_ch_radius(ch_image, xf, yf, group, r) return ch_image @@ -202,34 +239,52 @@ def color_ch_radius(ch_image, xf, yf, group, radius): @verbose def _save_ch_coords(self, info=None, verbose=None): """Save the location of the electrode contacts.""" - logger.info('Saving channel positions to `info`') + logger.info("Saving channel positions to `info`") if info is None: info = self._info montage = info.get_montage() - montage_kwargs = montage.get_positions() if montage else \ - dict(ch_pos=dict(), coord_frame='head') - for ch in 
info['chs']: + montage_kwargs = ( + montage.get_positions() + if montage + else dict(ch_pos=dict(), coord_frame="head") + ) + for ch in info["chs"]: # surface RAS-> head and mm->m - montage_kwargs['ch_pos'][ch['ch_name']] = apply_trans( - self._mri_head_t, self._chs[ch['ch_name']].copy() / 1000) + montage_kwargs["ch_pos"][ch["ch_name"]] = apply_trans( + self._mri_head_t, self._chs[ch["ch_name"]].copy() / 1000 + ) info.set_montage(make_dig_montage(**montage_kwargs)) def _plot_ch_images(self): img_delta = 0.5 - ch_deltas = list(img_delta * (self._voxel_sizes[ii] / _CH_PLOT_SIZE) - for ii in range(3)) + ch_deltas = list( + img_delta * (self._voxel_sizes[ii] / _CH_PLOT_SIZE) for ii in range(3) + ) self._ch_extents = list( - [-ch_delta, self._voxel_sizes[idx[0]] - ch_delta, - -ch_delta, self._voxel_sizes[idx[1]] - ch_delta] - for idx, ch_delta in zip(self._xy_idx, ch_deltas)) - self._images['chs'] = list() + [ + -ch_delta, + self._voxel_sizes[idx[0]] - ch_delta, + -ch_delta, + self._voxel_sizes[idx[1]] - ch_delta, + ] + for idx, ch_delta in zip(self._xy_idx, ch_deltas) + ) + self._images["chs"] = list() for axis in range(3): fig = self._figs[axis] ax = fig.axes[0] - self._images['chs'].append(ax.imshow( - self._make_ch_image(axis), aspect='auto', - extent=self._ch_extents[axis], zorder=3, - cmap=_CMAP, alpha=self._ch_alpha, vmin=0, vmax=_N_COLORS)) + self._images["chs"].append( + ax.imshow( + self._make_ch_image(axis), + aspect="auto", + extent=self._ch_extents[axis], + zorder=3, + cmap=_CMAP, + alpha=self._ch_alpha, + vmin=0, + vmax=_N_COLORS, + ) + ) self._3d_chs = dict() for name in self._chs: self._plot_3d_ch(name) @@ -237,12 +292,14 @@ def _plot_ch_images(self): def _plot_3d_ch(self, name, render=False): """Plot a single 3D channel.""" if name in self._3d_chs: - self._renderer.plotter.remove_actor( - self._3d_chs.pop(name), render=False) + self._renderer.plotter.remove_actor(self._3d_chs.pop(name), render=False) if not any(np.isnan(self._chs[name])): self._3d_chs[name] = self._renderer.sphere( - tuple(self._chs[name]), scale=1, - color=_CMAP(self._groups[name])[:3], opacity=self._ch_alpha)[0] + tuple(self._chs[name]), + scale=1, + color=_CMAP(self._groups[name])[:3], + opacity=self._ch_alpha, + )[0] # The actor scale is managed differently than the glyph scale # in order not to recreate objects, we use the actor scale self._3d_chs[name].SetOrigin(self._chs[name]) @@ -254,14 +311,14 @@ def _configure_toolbar(self): """Make a bar with buttons for user interactions.""" hbox = QHBoxLayout() - help_button = QPushButton('Help') + help_button = QPushButton("Help") help_button.released.connect(self._show_help) hbox.addWidget(help_button) hbox.addStretch(8) - hbox.addWidget(QLabel('Snap to Center')) - self._snap_button = QPushButton('Off') + hbox.addWidget(QLabel("Snap to Center")) + self._snap_button = QPushButton("Off") self._snap_button.setMaximumWidth(25) # not too big hbox.addWidget(self._snap_button) self._snap_button.released.connect(self._toggle_snap) @@ -269,17 +326,17 @@ def _configure_toolbar(self): hbox.addStretch(1) - self._toggle_brain_button = QPushButton('Show Brain') + self._toggle_brain_button = QPushButton("Show Brain") self._toggle_brain_button.released.connect(self._toggle_show_brain) hbox.addWidget(self._toggle_brain_button) hbox.addStretch(1) - mark_button = QPushButton('Mark') + mark_button = QPushButton("Mark") hbox.addWidget(mark_button) mark_button.released.connect(self.mark_channel) - remove_button = QPushButton('Remove') + remove_button = QPushButton("Remove") 
hbox.addWidget(remove_button) remove_button.released.connect(self.remove_channel) @@ -287,16 +344,16 @@ def _configure_toolbar(self): group_model = self._group_selector.model() for i in range(_N_COLORS): - self._group_selector.addItem(' ') + self._group_selector.addItem(" ") color = QtGui.QColor() color.setRgb(*(255 * np.array(_CMAP(i))).round().astype(int)) brush = QtGui.QBrush(color) brush.setStyle(QtCore.Qt.SolidPattern) - group_model.setData(group_model.index(i, 0), - brush, QtCore.Qt.BackgroundRole) + group_model.setData( + group_model.index(i, 0), brush, QtCore.Qt.BackgroundRole + ) self._group_selector.clicked.connect(self._select_group) - self._group_selector.currentIndexChanged.connect( - self._select_group) + self._group_selector.currentIndexChanged.connect(self._select_group) hbox.addWidget(self._group_selector) # update background color for current selection @@ -326,33 +383,33 @@ def make_slider(smin, smax, sval, sfun=None): slider_hbox = QHBoxLayout() ch_vbox = QVBoxLayout() - ch_vbox.addWidget(make_label('ch alpha')) - ch_vbox.addWidget(make_label('ch radius')) + ch_vbox.addWidget(make_label("ch alpha")) + ch_vbox.addWidget(make_label("ch radius")) slider_hbox.addLayout(ch_vbox) ch_slider_vbox = QVBoxLayout() - self._alpha_slider = make_slider(0, 100, self._ch_alpha * 100, - self._update_ch_alpha) + self._alpha_slider = make_slider( + 0, 100, self._ch_alpha * 100, self._update_ch_alpha + ) ch_plot_max = _CH_PLOT_SIZE // 50 # max 1 / 50 of plot size ch_slider_vbox.addWidget(self._alpha_slider) - self._radius_slider = make_slider(0, ch_plot_max, self._radius, - self._update_radius) + self._radius_slider = make_slider( + 0, ch_plot_max, self._radius, self._update_radius + ) ch_slider_vbox.addWidget(self._radius_slider) slider_hbox.addLayout(ch_slider_vbox) ct_vbox = QVBoxLayout() - ct_vbox.addWidget(make_label('CT min')) - ct_vbox.addWidget(make_label('CT max')) + ct_vbox.addWidget(make_label("CT min")) + ct_vbox.addWidget(make_label("CT max")) slider_hbox.addLayout(ct_vbox) ct_slider_vbox = QVBoxLayout() ct_min = int(round(np.nanmin(self._ct_data))) ct_max = int(round(np.nanmax(self._ct_data))) - self._ct_min_slider = make_slider( - ct_min, ct_max, ct_min, self._update_ct_scale) + self._ct_min_slider = make_slider(ct_min, ct_max, ct_min, self._update_ct_scale) ct_slider_vbox.addWidget(self._ct_min_slider) - self._ct_max_slider = make_slider( - ct_min, ct_max, ct_max, self._update_ct_scale) + self._ct_max_slider = make_slider(ct_min, ct_max, ct_max, self._update_ct_scale) ct_slider_vbox.addWidget(self._ct_max_slider) slider_hbox.addLayout(ct_slider_vbox) return slider_hbox @@ -362,22 +419,19 @@ def _configure_status_bar(self, hbox=None): hbox.addStretch(3) - self._toggle_show_mip_button = QPushButton('Show Max Intensity Proj') - self._toggle_show_mip_button.released.connect( - self._toggle_show_mip) + self._toggle_show_mip_button = QPushButton("Show Max Intensity Proj") + self._toggle_show_mip_button.released.connect(self._toggle_show_mip) hbox.addWidget(self._toggle_show_mip_button) - self._toggle_show_max_button = QPushButton('Show Maxima') - self._toggle_show_max_button.released.connect( - self._toggle_show_max) + self._toggle_show_max_button = QPushButton("Show Maxima") + self._toggle_show_max_button.released.connect(self._toggle_show_max) hbox.addWidget(self._toggle_show_max_button) - self._intensity_label = QLabel('') # update later + self._intensity_label = QLabel("") # update later hbox.addWidget(self._intensity_label) # add SliceBrowser navigation items - 
super(IntracranialElectrodeLocator, self)._configure_status_bar( - hbox=hbox) + super(IntracranialElectrodeLocator, self)._configure_status_bar(hbox=hbox) return hbox def _move_cursors_to_pos(self): @@ -389,8 +443,8 @@ def _group_channels(self, groups): if groups is not None: for name in self._ch_names: if name not in groups: - raise ValueError(f'{name} not found in ``groups``') - _validate_type(groups[name], (float, int), f'groups[{name}]') + raise ValueError(f"{name} not found in ``groups``") + _validate_type(groups[name], (float, int), f"groups[{name}]") self.groups = groups else: i = 0 @@ -398,8 +452,13 @@ def _group_channels(self, groups): base_names = dict() for name in self._ch_names: # strip all numbers from the name - base_name = ''.join([letter for letter in name if - not letter.isdigit() and letter != ' ']) + base_name = "".join( + [ + letter + for letter in name + if not letter.isdigit() and letter != " " + ] + ) if base_name in base_names: # look up group number by base name self._groups[name] = base_names[base_name] @@ -415,22 +474,24 @@ def _update_lines(self, group, only_2D=False): line.remove() self._lines_2D.pop(group) if only_2D: # if not in projection, don't add 2D lines - if self._toggle_show_mip_button.text() == \ - 'Show Max Intensity Proj': + if self._toggle_show_mip_button.text() == "Show Max Intensity Proj": return elif group in self._lines: # if updating 3D, remove first - self._renderer.plotter.remove_actor( - self._lines[group], render=False) - pos = np.array([ - self._chs[ch] for i, ch in enumerate(self._ch_names) - if self._groups[ch] == group and i in self._seeg_idx and - not np.isnan(self._chs[ch]).any()]) + self._renderer.plotter.remove_actor(self._lines[group], render=False) + pos = np.array( + [ + self._chs[ch] + for i, ch in enumerate(self._ch_names) + if self._groups[ch] == group + and i in self._seeg_idx + and not np.isnan(self._chs[ch]).any() + ] + ) if len(pos) < 2: # not enough points for line return # first, the insertion will be the point farthest from the origin # brains are a longer posterior-anterior, scale for this (80%) - insert_idx = np.argmax(np.linalg.norm(pos * np.array([1, 0.8, 1]), - axis=1)) + insert_idx = np.argmax(np.linalg.norm(pos * np.array([1, 0.8, 1]), axis=1)) # second, find the farthest point from the insertion target_idx = np.argmax(np.linalg.norm(pos[insert_idx] - pos, axis=1)) # third, make a unit vector and to add to the insertion for the bolt @@ -438,20 +499,31 @@ def _update_lines(self, group, only_2D=False): elec_v /= np.linalg.norm(elec_v) if not only_2D: self._lines[group] = self._renderer.tube( - [pos[target_idx]], [pos[insert_idx] + elec_v * _BOLT_SCALAR], - radius=self._radius * _TUBE_SCALAR, color=_CMAP(group)[:3])[0] - if self._toggle_show_mip_button.text() == 'Hide Max Intensity Proj': + [pos[target_idx]], + [pos[insert_idx] + elec_v * _BOLT_SCALAR], + radius=self._radius * _TUBE_SCALAR, + color=_CMAP(group)[:3], + )[0] + if self._toggle_show_mip_button.text() == "Hide Max Intensity Proj": # add 2D lines on each slice plot if in max intensity projection target_vox = apply_trans(self._ras_vox_t, pos[target_idx]) - insert_vox = apply_trans(self._ras_vox_t, - pos[insert_idx] + elec_v * _BOLT_SCALAR) + insert_vox = apply_trans( + self._ras_vox_t, pos[insert_idx] + elec_v * _BOLT_SCALAR + ) lines_2D = list() for axis in range(3): x, y = self._xy_idx[axis] - lines_2D.append(self._figs[axis].axes[0].plot( - [target_vox[x], insert_vox[x]], - [target_vox[y], insert_vox[y]], - color=_CMAP(group), linewidth=0.25, 
zorder=7)[0]) + lines_2D.append( + self._figs[axis] + .axes[0] + .plot( + [target_vox[x], insert_vox[x]], + [target_vox[y], insert_vox[y]], + color=_CMAP(group), + linewidth=0.25, + zorder=7, + )[0] + ) self._lines_2D[group] = lines_2D def _select_group(self): @@ -467,14 +539,14 @@ def _update_group(self): group = self._group_selector.currentIndex() rgb = (255 * np.array(_CMAP(group))).round().astype(int) self._group_selector.setStyleSheet( - 'background-color: rgb({:d},{:d},{:d})'.format(*rgb)) + "background-color: rgb({:d},{:d},{:d})".format(*rgb) + ) self._group_selector.update() def _update_ch_selection(self): """Update which channel is selected.""" name = self._ch_names[self._ch_index] - self._ch_list.setCurrentIndex( - self._ch_list_model.index(self._ch_index, 0)) + self._ch_list.setCurrentIndex(self._ch_list_model.index(self._ch_index, 0)) self._group_selector.setCurrentIndex(self._groups[name]) self._update_group() if not np.isnan(self._chs[name]).any(): @@ -496,7 +568,7 @@ def _next_ch(self): def _color_list_item(self, name=None): """Color the item in the view list for easy id of marked channels.""" name = self._ch_names[self._ch_index] if name is None else name - color = QtGui.QColor('white') + color = QtGui.QColor("white") if not np.isnan(self._chs[name]).any(): group = self._groups[name] color.setRgb(*[int(c * 255) for c in _CMAP(group)]) @@ -504,23 +576,27 @@ def _color_list_item(self, name=None): brush.setStyle(QtCore.Qt.SolidPattern) self._ch_list_model.setData( self._ch_list_model.index(self._ch_names.index(name), 0), - brush, QtCore.Qt.BackgroundRole) + brush, + QtCore.Qt.BackgroundRole, + ) # color text black - color = QtGui.QColor('black') + color = QtGui.QColor("black") brush = QtGui.QBrush(color) brush.setStyle(QtCore.Qt.SolidPattern) self._ch_list_model.setData( self._ch_list_model.index(self._ch_names.index(name), 0), - brush, QtCore.Qt.ForegroundRole) + brush, + QtCore.Qt.ForegroundRole, + ) @Slot() def _toggle_snap(self): """Toggle snapping the contact location to the center of mass.""" - if self._snap_button.text() == 'Off': - self._snap_button.setText('On') + if self._snap_button.text() == "Off": + self._snap_button.setText("On") self._snap_button.setStyleSheet("background-color: green") else: # text == 'On', turn off - self._snap_button.setText('Off') + self._snap_button.setText("Off") self._snap_button.setStyleSheet("background-color: red") @Slot() @@ -534,20 +610,27 @@ def mark_channel(self, ch=None): is marked. 
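

         Notes
         -----
         With "Snap to Center" off, the contact is stored at the current
         crosshair RAS position; with it on, the position snaps to the center
         of mass of the bright CT voxels neighboring the crosshair (relative
         threshold 0.5, capped near the volume of a sphere of the current
         plot radius).

         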
""" if ch is not None and ch not in self._ch_names: - raise ValueError(f'Channel {ch} not found') - name = self._ch_names[self._ch_index if ch is None else - self._ch_names.index(ch)] - if self._snap_button.text() == 'Off': + raise ValueError(f"Channel {ch} not found") + name = self._ch_names[ + self._ch_index if ch is None else self._ch_names.index(ch) + ] + if self._snap_button.text() == "Off": self._chs[name][:] = self._ras else: shape = np.mean(self._voxel_sizes) # Freesurfer shape (256) voxels_max = int( - 4 / 3 * np.pi * (shape * self._radius / _CH_PLOT_SIZE)**3) + 4 / 3 * np.pi * (shape * self._radius / _CH_PLOT_SIZE) ** 3 + ) neighbors = _voxel_neighbors( - self._vox, self._ct_data, thresh=0.5, - voxels_max=voxels_max, use_relative=True) + self._vox, + self._ct_data, + thresh=0.5, + voxels_max=voxels_max, + use_relative=True, + ) self._chs[name][:] = apply_trans( # to surface RAS - self._vox_ras_t, np.array(list(neighbors)).mean(axis=0)) + self._vox_ras_t, np.array(list(neighbors)).mean(axis=0) + ) self._color_list_item() self._update_lines(self._groups[name]) self._update_ch_images(draw=True) @@ -567,9 +650,10 @@ def remove_channel(self, ch=None): is removed. """ if ch is not None and ch not in self._ch_names: - raise ValueError(f'Channel {ch} not found') - name = self._ch_names[self._ch_index if ch is None else - self._ch_names.index(ch)] + raise ValueError(f"Channel {ch} not found") + name = self._ch_names[ + self._ch_index if ch is None else self._ch_names.index(ch) + ] self._chs[name] *= np.nan self._color_list_item() self._save_ch_coords() @@ -582,38 +666,37 @@ def remove_channel(self, ch=None): def _update_ch_images(self, axis=None, draw=False): """Update the channel image(s).""" for axis in range(3) if axis is None else [axis]: - self._images['chs'][axis].set_data( - self._make_ch_image(axis)) - if self._toggle_show_mip_button.text() == \ - 'Hide Max Intensity Proj': - self._images['mip_chs'][axis].set_data( - self._make_ch_image(axis, proj=True)) + self._images["chs"][axis].set_data(self._make_ch_image(axis)) + if self._toggle_show_mip_button.text() == "Hide Max Intensity Proj": + self._images["mip_chs"][axis].set_data( + self._make_ch_image(axis, proj=True) + ) if draw: self._draw(axis) def _update_ct_images(self, axis=None, draw=False): """Update the CT image(s).""" for axis in range(3) if axis is None else [axis]: - ct_data = np.take(self._ct_data, self._current_slice[axis], - axis=axis).T + ct_data = np.take(self._ct_data, self._current_slice[axis], axis=axis).T # Threshold the CT so only bright objects (electrodes) are visible ct_data[ct_data < self._ct_min_slider.value()] = np.nan ct_data[ct_data > self._ct_max_slider.value()] = np.nan - self._images['ct'][axis].set_data(ct_data) - if 'local_max' in self._images: + self._images["ct"][axis].set_data(ct_data) + if "local_max" in self._images: ct_max_data = np.take( - self._ct_maxima, self._current_slice[axis], axis=axis).T - self._images['local_max'][axis].set_data(ct_max_data) + self._ct_maxima, self._current_slice[axis], axis=axis + ).T + self._images["local_max"][axis].set_data(ct_max_data) if draw: self._draw(axis) def _update_mri_images(self, axis=None, draw=False): """Update the CT image(s).""" - if 'mri' in self._images: + if "mri" in self._images: for axis in range(3) if axis is None else [axis]: - self._images['mri'][axis].set_data( - np.take(self._mri_data, self._current_slice[axis], - axis=axis).T) + self._images["mri"][axis].set_data( + np.take(self._mri_data, self._current_slice[axis], axis=axis).T + ) if 
draw: self._draw(axis) @@ -635,7 +718,7 @@ def _update_ct_scale(self): def _update_radius(self): """Update channel plot radius.""" self._radius = np.round(self._radius_slider.value()).astype(int) - if self._toggle_show_max_button.text() == 'Hide Maxima': + if self._toggle_show_max_button.text() == "Hide Maxima": self._update_ct_maxima() self._update_ct_images() else: @@ -652,7 +735,7 @@ def _update_ch_alpha(self): """Update channel plot alpha.""" self._ch_alpha = self._alpha_slider.value() / 100 for axis in range(3): - self._images['chs'][axis].set_alpha(self._ch_alpha) + self._images["chs"][axis].set_alpha(self._ch_alpha) self._draw() for actor in self._3d_chs.values(): actor.GetProperty().SetOpacity(self._ch_alpha) @@ -662,35 +745,45 @@ def _update_ch_alpha(self): def _show_help(self): """Show the help menu.""" QMessageBox.information( - self, 'Help', + self, + "Help", "Help:\n'm': mark channel location\n" "'r': remove channel location\n" "'b': toggle viewing of brain in T1\n" "'+'/'-': zoom\nleft/right arrow: left/right\n" "up/down arrow: superior/inferior\n" - "left angle bracket/right angle bracket: anterior/posterior") + "left angle bracket/right angle bracket: anterior/posterior", + ) def _update_ct_maxima(self): """Compute the maximum voxels based on the current radius.""" - self._ct_maxima = maximum_filter( - self._ct_data, (self._radius,) * 3) == self._ct_data - self._ct_maxima[self._ct_data <= np.median(self._ct_data)] = \ - False + self._ct_maxima = ( + maximum_filter(self._ct_data, (self._radius,) * 3) == self._ct_data + ) + self._ct_maxima[self._ct_data <= np.median(self._ct_data)] = False self._ct_maxima = np.where(self._ct_maxima, 1, np.nan) # transparent def _toggle_show_mip(self): """Toggle whether the maximum-intensity projection is shown.""" - if self._toggle_show_mip_button.text() == 'Show Max Intensity Proj': - self._toggle_show_mip_button.setText('Hide Max Intensity Proj') - self._images['mip'] = list() - self._images['mip_chs'] = list() + if self._toggle_show_mip_button.text() == "Show Max Intensity Proj": + self._toggle_show_mip_button.setText("Hide Max Intensity Proj") + self._images["mip"] = list() + self._images["mip_chs"] = list() ct_min, ct_max = np.nanmin(self._ct_data), np.nanmax(self._ct_data) for axis in range(3): ct_mip_data = np.max(self._ct_data, axis=axis).T - self._images['mip'].append( - self._figs[axis].axes[0].imshow( - ct_mip_data, cmap='gray', aspect='auto', - vmin=ct_min, vmax=ct_max, zorder=5)) + self._images["mip"].append( + self._figs[axis] + .axes[0] + .imshow( + ct_mip_data, + cmap="gray", + aspect="auto", + vmin=ct_min, + vmax=ct_max, + zorder=5, + ) + ) # add circles for each channel xs, ys, colors = list(), list(), list() for name, ras in self._chs.items(): @@ -698,71 +791,93 @@ def _toggle_show_mip(self): xs.append(xyz[self._xy_idx[axis][0]]) ys.append(xyz[self._xy_idx[axis][1]]) colors.append(_CMAP(self._groups[name])) - self._images['mip_chs'].append( - self._figs[axis].axes[0].imshow( - self._make_ch_image(axis, proj=True), aspect='auto', - extent=self._ch_extents[axis], zorder=6, - cmap=_CMAP, alpha=1, vmin=0, vmax=_N_COLORS)) + self._images["mip_chs"].append( + self._figs[axis] + .axes[0] + .imshow( + self._make_ch_image(axis, proj=True), + aspect="auto", + extent=self._ch_extents[axis], + zorder=6, + cmap=_CMAP, + alpha=1, + vmin=0, + vmax=_N_COLORS, + ) + ) for group in set(self._groups.values()): self._update_lines(group, only_2D=True) else: - for img in self._images['mip'] + self._images['mip_chs']: + for img in 
self._images["mip"] + self._images["mip_chs"]: img.remove() - self._images.pop('mip') - self._images.pop('mip_chs') - self._toggle_show_mip_button.setText('Show Max Intensity Proj') + self._images.pop("mip") + self._images.pop("mip_chs") + self._toggle_show_mip_button.setText("Show Max Intensity Proj") for group in set(self._groups.values()): # remove lines self._update_lines(group, only_2D=True) self._draw() def _toggle_show_max(self): """Toggle whether to color local maxima differently.""" - if self._toggle_show_max_button.text() == 'Show Maxima': - self._toggle_show_max_button.setText('Hide Maxima') + if self._toggle_show_max_button.text() == "Show Maxima": + self._toggle_show_max_button.setText("Hide Maxima") # happens on initiation or if the radius is changed with it off if self._ct_maxima is None: # otherwise don't recompute self._update_ct_maxima() - self._images['local_max'] = list() + self._images["local_max"] = list() for axis in range(3): - ct_max_data = np.take(self._ct_maxima, - self._current_slice[axis], axis=axis).T - self._images['local_max'].append( - self._figs[axis].axes[0].imshow( - ct_max_data, cmap='autumn', aspect='auto', - vmin=0, vmax=1, zorder=4)) + ct_max_data = np.take( + self._ct_maxima, self._current_slice[axis], axis=axis + ).T + self._images["local_max"].append( + self._figs[axis] + .axes[0] + .imshow( + ct_max_data, + cmap="autumn", + aspect="auto", + vmin=0, + vmax=1, + zorder=4, + ) + ) else: - for img in self._images['local_max']: + for img in self._images["local_max"]: img.remove() - self._images.pop('local_max') - self._toggle_show_max_button.setText('Show Maxima') + self._images.pop("local_max") + self._toggle_show_max_button.setText("Show Maxima") self._draw() def _toggle_show_brain(self): """Toggle whether the brain/MRI is being shown.""" - if 'mri' in self._images: - for img in self._images['mri']: + if "mri" in self._images: + for img in self._images["mri"]: img.remove() - self._images.pop('mri') - self._toggle_brain_button.setText('Show Brain') + self._images.pop("mri") + self._toggle_brain_button.setText("Show Brain") else: - self._images['mri'] = list() + self._images["mri"] = list() for axis in range(3): - mri_data = np.take(self._mri_data, - self._current_slice[axis], axis=axis).T - self._images['mri'].append(self._figs[axis].axes[0].imshow( - mri_data, cmap='hot', aspect='auto', alpha=0.25, zorder=2)) - self._toggle_brain_button.setText('Hide Brain') + mri_data = np.take( + self._mri_data, self._current_slice[axis], axis=axis + ).T + self._images["mri"].append( + self._figs[axis] + .axes[0] + .imshow(mri_data, cmap="hot", aspect="auto", alpha=0.25, zorder=2) + ) + self._toggle_brain_button.setText("Hide Brain") self._draw() def keyPressEvent(self, event): """Execute functions when the user presses a key.""" super(IntracranialElectrodeLocator, self).keyPressEvent(event) - if event.text() == 'm': + if event.text() == "m": self.mark_channel() - if event.text() == 'r': + if event.text() == "r": self.remove_channel() - if event.text() == 'b': + if event.text() == "b": self._toggle_show_brain() diff --git a/mne/gui/tests/test_core.py b/mne/gui/tests/test_core.py index 013bca7eed5..7c94afd67e0 100644 --- a/mne/gui/tests/test_core.py +++ b/mne/gui/tests/test_core.py @@ -19,12 +19,15 @@ @testing.requires_testing_data def test_slice_browser_io(renderer_interactive_pyvistaqt): """Test the input/output of the slice browser GUI.""" - nib = pytest.importorskip('nibabel') + nib = pytest.importorskip("nibabel") from mne.gui._core import SliceBrowser - with 
pytest.raises(ValueError, match='Base image is not aligned to MRI'): - SliceBrowser(nib.MGHImage( - np.ones((96, 96, 96), dtype=np.float32), np.eye(4)), - subject=subject, subjects_dir=subjects_dir) + + with pytest.raises(ValueError, match="Base image is not aligned to MRI"): + SliceBrowser( + nib.MGHImage(np.ones((96, 96, 96), dtype=np.float32), np.eye(4)), + subject=subject, + subjects_dir=subjects_dir, + ) # TODO: For some reason this leaves some stuff un-closed, we should fix it @@ -32,35 +35,37 @@ def test_slice_browser_io(renderer_interactive_pyvistaqt): @testing.requires_testing_data def test_slice_browser_display(renderer_interactive_pyvistaqt): """Test that the slice browser GUI displays properly.""" - pytest.importorskip('nibabel') + pytest.importorskip("nibabel") from mne.gui._core import SliceBrowser + # test no seghead, fsaverage doesn't have seghead - with pytest.warns(RuntimeWarning, match='`seghead` not found'): + with pytest.warns(RuntimeWarning, match="`seghead` not found"): with catch_logging() as log: gui = SliceBrowser( - subject='fsaverage', subjects_dir=subjects_dir, - verbose=True) + subject="fsaverage", subjects_dir=subjects_dir, verbose=True + ) log = log.getvalue() - assert 'using marching cubes' in log + assert "using marching cubes" in log gui.close() # test functions - with pytest.warns(RuntimeWarning, match='`pial` surface not found'): + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): gui = SliceBrowser(subject=subject, subjects_dir=subjects_dir) # test RAS - gui._RAS_textbox.setText('10 10 10') + gui._RAS_textbox.setText("10 10 10") gui._RAS_textbox.focusOutEvent(event=None) assert_allclose(gui._ras, [10, 10, 10]) # test vox - gui._VOX_textbox.setText('150, 150, 150') + gui._VOX_textbox.setText("150, 150, 150") gui._VOX_textbox.focusOutEvent(event=None) assert_allclose(gui._ras, [23, 22, 23]) # test click - with use_log_level('debug'): - _fake_click(gui._figs[2], gui._figs[2].axes[0], - [137, 140], xform='data', kind='release') + with use_log_level("debug"): + _fake_click( + gui._figs[2], gui._figs[2].axes[0], [137, 140], xform="data", kind="release" + ) assert_allclose(gui._ras, [10, 12, 23]) gui.close() diff --git a/mne/gui/tests/test_coreg.py b/mne/gui/tests/test_coreg.py index 5c705c6dcb5..3f217ef275a 100644 --- a/mne/gui/tests/test_coreg.py +++ b/mne/gui/tests/test_coreg.py @@ -29,30 +29,17 @@ fid_fname = subjects_dir / "sample" / "bem" / "sample-fiducials.fif" ctf_raw_path = data_path / "CTF" / "catch-alp-good-f.ds" nirx_15_0_raw_path = ( - data_path - / "NIRx" - / "nirscout" - / "nirx_15_0_recording" - / "NIRS-2019-10-27_003.hdr" + data_path / "NIRx" / "nirscout" / "nirx_15_0_recording" / "NIRS-2019-10-27_003.hdr" ) nirsport2_raw_path = ( - data_path - / "NIRx" - / "nirsport_v2" - / "aurora_2021_9" - / "2021-10-01_002_config.hdr" + data_path / "NIRx" / "nirsport_v2" / "aurora_2021_9" / "2021-10-01_002_config.hdr" ) snirf_nirsport2_raw_path = ( - data_path - / "SNIRF" - / "NIRx" - / "NIRSport2" - / "1.0.3" - / "2021-05-05_001.snirf" + data_path / "SNIRF" / "NIRx" / "NIRSport2" / "1.0.3" / "2021-05-05_001.snirf" ) -pytest.importorskip('nibabel') +pytest.importorskip("nibabel") class TstVTKPicker: @@ -75,8 +62,10 @@ def GetDataSet(self): def GetPickPosition(self): """Return the picked position.""" vtk_cell = self.mesh.GetCell(self.cell_id) - cell = [vtk_cell.GetPointId(point_id) for point_id - in range(vtk_cell.GetNumberOfPoints())] + cell = [ + vtk_cell.GetPointId(point_id) + for point_id in range(vtk_cell.GetNumberOfPoints()) + ] 
self.point_id = cell[0] return self.mesh.points[self.point_id] @@ -88,62 +77,77 @@ def GetEventPosition(self): @pytest.mark.slowtest @testing.requires_testing_data @pytest.mark.parametrize( - 'inst_path', (raw_path, 'gen_montage', ctf_raw_path, nirx_15_0_raw_path, - nirsport2_raw_path, snirf_nirsport2_raw_path)) -def test_coreg_gui_pyvista_file_support(inst_path, tmp_path, - renderer_interactive_pyvistaqt): + "inst_path", + ( + raw_path, + "gen_montage", + ctf_raw_path, + nirx_15_0_raw_path, + nirsport2_raw_path, + snirf_nirsport2_raw_path, + ), +) +def test_coreg_gui_pyvista_file_support( + inst_path, tmp_path, renderer_interactive_pyvistaqt +): """Test reading supported files.""" from mne.gui import coregistration - if inst_path == 'gen_montage': + if inst_path == "gen_montage": # generate a montage fig to use as inst. tmp_info = read_info(raw_path) eeg_chans = [] - for pt in tmp_info['dig']: - if pt['kind'] == FIFF.FIFFV_POINT_EEG: + for pt in tmp_info["dig"]: + if pt["kind"] == FIFF.FIFFV_POINT_EEG: eeg_chans.append(f"EEG {pt['ident']:03d}") dig = DigMontage(dig=tmp_info["dig"], ch_names=eeg_chans) - inst_path = tmp_path / 'tmp-dig.fif' + inst_path = tmp_path / "tmp-dig.fif" dig.save(inst_path) if inst_path == ctf_raw_path: - ctx = pytest.warns(RuntimeWarning, match='MEG ref channel RMSP') + ctx = pytest.warns(RuntimeWarning, match="MEG ref channel RMSP") elif inst_path == snirf_nirsport2_raw_path: # TODO: This is maybe a bug? ctx = pytest.warns(RuntimeWarning, match='assuming "head"') else: ctx = nullcontext() with ctx: coreg = coregistration( - inst=inst_path, subject='sample', subjects_dir=subjects_dir) + inst=inst_path, subject="sample", subjects_dir=subjects_dir + ) coreg._accept_close_event = True coreg.close() @pytest.mark.slowtest @testing.requires_testing_data -def test_coreg_gui_pyvista_basic(tmp_path, monkeypatch, - renderer_interactive_pyvistaqt): +def test_coreg_gui_pyvista_basic(tmp_path, monkeypatch, renderer_interactive_pyvistaqt): """Test that using CoregistrationUI matches mne coreg.""" from mne.gui import coregistration + config = get_config() # the sample subject in testing has MRI fids assert (subjects_dir / "sample" / "bem" / "sample-fiducials.fif").is_file() - coreg = coregistration(subject='sample', subjects_dir=subjects_dir, - trans=fname_trans) + coreg = coregistration( + subject="sample", subjects_dir=subjects_dir, trans=fname_trans + ) assert coreg._lock_fids coreg._reset_fiducials() coreg.close() # make it always log the distances - monkeypatch.setattr(_3d.logger, 'info', _3d.logger.warning) + monkeypatch.setattr(_3d.logger, "info", _3d.logger.warning) with catch_logging() as log: - coreg = coregistration(inst=raw_path, subject='sample', - head_high_res=False, # for speed - subjects_dir=subjects_dir, verbose='debug') + coreg = coregistration( + inst=raw_path, + subject="sample", + head_high_res=False, # for speed + subjects_dir=subjects_dir, + verbose="debug", + ) log = log.getvalue() - assert 'Total 16/78 points inside the surface' in log + assert "Total 16/78 points inside the surface" in log coreg._set_fiducials_file(fid_fname) assert coreg._fiducials_file == str(fid_fname) @@ -153,18 +157,18 @@ def test_coreg_gui_pyvista_basic(tmp_path, monkeypatch, coreg._reset_fitting_parameters() coreg._set_scale_mode("uniform") coreg._fits_fiducials() - assert_allclose(coreg.coreg._scale, - np.array([97.46, 97.46, 97.46]) * 1e-2, - atol=1e-3) - shown_scale = [coreg._widgets[f's{x}'].get_value() for x in 'XYZ'] + assert_allclose( + coreg.coreg._scale, 
np.array([97.46, 97.46, 97.46]) * 1e-2, atol=1e-3 + ) + shown_scale = [coreg._widgets[f"s{x}"].get_value() for x in "XYZ"] assert_allclose(shown_scale, coreg.coreg._scale * 100, atol=1e-2) coreg._set_icp_fid_match("nearest") coreg._set_scale_mode("3-axis") coreg._fits_icp() - assert_allclose(coreg.coreg._scale, - np.array([104.43, 101.47, 125.78]) * 1e-2, - atol=1e-3) - shown_scale = [coreg._widgets[f's{x}'].get_value() for x in 'XYZ'] + assert_allclose( + coreg.coreg._scale, np.array([104.43, 101.47, 125.78]) * 1e-2, atol=1e-3 + ) + shown_scale = [coreg._widgets[f"s{x}"].get_value() for x in "XYZ"] assert_allclose(shown_scale, coreg.coreg._scale * 100, atol=1e-2) coreg._set_scale_mode("None") coreg._set_icp_fid_match("matched") @@ -177,7 +181,7 @@ def test_coreg_gui_pyvista_basic(tmp_path, monkeypatch, # picking assert not coreg._mri_fids_modified - vtk_picker = TstVTKPicker(coreg._surfaces['head'], 0, (0, 0)) + vtk_picker = TstVTKPicker(coreg._surfaces["head"], 0, (0, 0)) coreg._on_mouse_move(vtk_picker, None) coreg._on_button_press(vtk_picker, None) coreg._on_pick(vtk_picker, None) @@ -190,31 +194,31 @@ def test_coreg_gui_pyvista_basic(tmp_path, monkeypatch, assert coreg._lock_fids # fitting (no scaling) - assert coreg._nasion_weight == 10. - coreg._set_point_weight(11., 'nasion') - assert coreg._nasion_weight == 11. + assert coreg._nasion_weight == 10.0 + coreg._set_point_weight(11.0, "nasion") + assert coreg._nasion_weight == 11.0 coreg._fit_fiducials() with catch_logging() as log: coreg._redraw() # actually emit the log log = log.getvalue() - assert 'Total 6/78 points inside the surface' in log + assert "Total 6/78 points inside the surface" in log with catch_logging() as log: coreg._fit_icp() coreg._redraw() log = log.getvalue() - assert 'Total 38/78 points inside the surface' in log + assert "Total 38/78 points inside the surface" in log assert coreg.coreg._extra_points_filter is None coreg._omit_hsp() with catch_logging() as log: coreg._redraw() log = log.getvalue() - assert 'Total 29/53 points inside the surface' in log + assert "Total 29/53 points inside the surface" in log assert coreg.coreg._extra_points_filter is not None coreg._reset_omit_hsp_filter() with catch_logging() as log: coreg._redraw() log = log.getvalue() - assert 'Total 38/78 points inside the surface' in log + assert "Total 38/78 points inside the surface" in log assert coreg.coreg._extra_points_filter is None assert coreg._grow_hair == 0 @@ -222,48 +226,48 @@ def test_coreg_gui_pyvista_basic(tmp_path, monkeypatch, with catch_logging() as log: coreg._redraw() log = log.getvalue() - assert 'Total 6/78 points inside the surface' in log - norm = np.linalg.norm(coreg._head_geo['rr']) # what's used for inside + assert "Total 6/78 points inside the surface" in log + norm = np.linalg.norm(coreg._head_geo["rr"]) # what's used for inside assert_allclose(norm, 5.949288, atol=1e-3) coreg._set_grow_hair(20.0) with catch_logging() as log: coreg._redraw() assert coreg._grow_hair == 20.0 - norm = np.linalg.norm(coreg._head_geo['rr']) + norm = np.linalg.norm(coreg._head_geo["rr"]) assert_allclose(norm, 6.555220, atol=1e-3) # outward log = log.getvalue() - assert 'Total 8/78 points inside the surface' in log # more outside now + assert "Total 8/78 points inside the surface" in log # more outside now # visualization assert not coreg._helmet - assert coreg._actors['helmet'] is None + assert coreg._actors["helmet"] is None coreg._set_helmet(True) assert coreg._helmet with catch_logging() as log: - coreg._redraw(verbose='debug') + 
coreg._redraw(verbose="debug")
     log = log.getvalue()
-    assert 'Drawing helmet' in log
-    coreg._set_point_weight(1., 'nasion')
+    assert "Drawing helmet" in log
+    coreg._set_point_weight(1.0, "nasion")
     coreg._fit_fiducials()
     with catch_logging() as log:
-        coreg._redraw(verbose='debug')
+        coreg._redraw(verbose="debug")
     log = log.getvalue()
-    assert 'Drawing helmet' in log
+    assert "Drawing helmet" in log
     assert coreg._orient_glyphs
     assert coreg._scale_by_distance
     assert coreg._mark_inside
     assert_allclose(
-        coreg._head_opacity,
-        float(config.get('MNE_COREG_HEAD_OPACITY', '0.8')))
+        coreg._head_opacity, float(config.get("MNE_COREG_HEAD_OPACITY", "0.8"))
+    )
     assert coreg._hpi_coils
     assert coreg._eeg_channels
     assert coreg._head_shape_points
-    assert coreg._scale_mode == 'None'
-    assert coreg._icp_fid_match == 'matched'
+    assert coreg._scale_mode == "None"
+    assert coreg._icp_fid_match == "matched"
     assert coreg._head_resolution is False

     assert coreg._trans_modified
-    tmp_trans = tmp_path / 'tmp-trans.fif'
+    tmp_trans = tmp_path / "tmp-trans.fif"
     coreg._save_trans(tmp_trans)
     assert not coreg._trans_modified
     assert tmp_trans.is_file()
@@ -274,14 +278,14 @@ def test_coreg_gui_pyvista_basic(tmp_path, monkeypatch,
     coreg._renderer._process_events()
     assert coreg._mri_fids_modified  # should prompt
     assert coreg._renderer.plotter.app_window.children() is not None
-    assert 'close_dialog' not in coreg._widgets
+    assert "close_dialog" not in coreg._widgets
     assert not coreg._renderer.plotter._closed
     assert coreg._accept_close_event
     # make sure it's ignored (PySide6 causes problems here and doesn't wait)
     coreg._accept_close_event = False
     coreg.close()
     assert not coreg._renderer.plotter._closed
-    coreg._widgets['close_dialog'].trigger('Discard')  # do not save
+    coreg._widgets["close_dialog"].trigger("Discard")  # do not save
     coreg.close()
     assert coreg._renderer.plotter._closed
     coreg._clean()  # finally, cleanup internal structures
@@ -296,30 +300,31 @@ def test_coreg_gui_pyvista_basic(tmp_path, monkeypatch,
 def test_fullscreen(renderer_interactive_pyvistaqt):
     """Test fullscreen mode."""
     from mne.gui import coregistration
+
     # Fullscreen mode
-    coreg = coregistration(
-        subject='sample', subjects_dir=subjects_dir, fullscreen=True
-    )
+    coreg = coregistration(subject="sample", subjects_dir=subjects_dir, fullscreen=True)
     coreg._accept_close_event = True
     coreg.close()


 @pytest.mark.slowtest
-@requires_version('sphinx_gallery')
+@requires_version("sphinx_gallery")
 @testing.requires_testing_data
 def test_coreg_gui_scraper(tmp_path, renderer_interactive_pyvistaqt):
     """Test the scraper for the coregistration GUI."""
     from mne.gui import coregistration
-    coreg = coregistration(subject='sample', subjects_dir=subjects_dir,
-                           trans=fname_trans)
-    (tmp_path / '_images').mkdir()
-    image_path = tmp_path / '_images' / 'temp.png'
-    gallery_conf = dict(builder_name='html', src_dir=tmp_path)
+
+    coreg = coregistration(
+        subject="sample", subjects_dir=subjects_dir, trans=fname_trans
+    )
+    (tmp_path / "_images").mkdir()
+    image_path = tmp_path / "_images" / "temp.png"
+    gallery_conf = dict(builder_name="html", src_dir=tmp_path)
     block_vars = dict(
-        example_globals=dict(gui=coreg),
-        image_path_iterator=iter([str(image_path)]))
+        example_globals=dict(gui=coreg), image_path_iterator=iter([str(image_path)])
+    )
     assert not image_path.is_file()
-    assert not getattr(coreg, '_scraped', False)
+    assert not getattr(coreg, "_scraped", False)
     mne.gui._GUIScraper()(None, block_vars, gallery_conf)
     assert image_path.is_file()
     assert coreg._scraped
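Note for reviewers: the hunks in these test files are mechanical. black normalizes string quotes to double quotes and re-wraps any statement that exceeds its default 88-column limit, exploding long call sites one argument per line with a trailing comma. A minimal sketch of that rewrite, assuming only that the `black` package added to requirements_testing.txt in this patch is importable (the snippet is illustrative and is not itself part of the diff):

    import black

    # A call written in the pre-black style used throughout these tests.
    src = (
        "coreg = coregistration(subject='sample', subjects_dir=subjects_dir,\n"
        "                       trans=fname_trans)\n"
    )
    # format_str() applies the same rules this patch applies file-wide:
    # single quotes become double quotes and the call is re-wrapped to the
    # default 88-character line length. At top level this call fits on one
    # line; indented inside a test function (as above) it no longer fits,
    # so black explodes it across multiple lines instead.
    print(black.format_str(src, mode=black.Mode()))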
@@ -334,39 +339,40 @@ def test_coreg_gui_notebook(renderer_notebook, nbexec): from mne.datasets import testing from mne.gui import coregistration - mne.viz.set_3d_backend('notebook') # set the 3d backend + mne.viz.set_3d_backend("notebook") # set the 3d backend with pytest.MonkeyPatch().context() as mp: - mp.delenv('_MNE_FAKE_HOME_DIR') + mp.delenv("_MNE_FAKE_HOME_DIR") data_path = testing.data_path(download=False) subjects_dir = data_path / "subjects" - coregistration(subject='sample', subjects_dir=subjects_dir) + coregistration(subject="sample", subjects_dir=subjects_dir) @pytest.mark.slowtest -def test_no_sparse_head(subjects_dir_tmp, renderer_interactive_pyvistaqt, - monkeypatch): +def test_no_sparse_head(subjects_dir_tmp, renderer_interactive_pyvistaqt, monkeypatch): """Test mne.gui.coregistration with no sparse head.""" from mne.gui import coregistration subjects_dir_tmp = Path(subjects_dir_tmp) - subject = 'sample' + subject = "sample" out_rr, out_tris = mne.read_surface( subjects_dir_tmp / subject / "bem" / "outer_skin.surf" ) - for head in ('sample-head.fif', 'outer_skin.surf'): + for head in ("sample-head.fif", "outer_skin.surf"): os.remove(subjects_dir_tmp / subject / "bem" / head) # Avoid actually doing the decimation (it's slow) monkeypatch.setattr( - mne.coreg, 'decimate_surface', - lambda rr, tris, n_triangles: (out_rr, out_tris)) - with pytest.warns(RuntimeWarning, match='No low-resolution head found'): + mne.coreg, "decimate_surface", lambda rr, tris, n_triangles: (out_rr, out_tris) + ) + with pytest.warns(RuntimeWarning, match="No low-resolution head found"): coreg = coregistration( - inst=raw_path, subject=subject, subjects_dir=subjects_dir_tmp) + inst=raw_path, subject=subject, subjects_dir=subjects_dir_tmp + ) coreg.close() def test_splash_closed(tmp_path, renderer_interactive_pyvistaqt): """Test that the splash closes on error.""" from mne.gui import coregistration - with pytest.raises(RuntimeError, match='No standard head model'): - coregistration(subjects_dir=tmp_path, subject='fsaverage') + + with pytest.raises(RuntimeError, match="No standard head model"): + coregistration(subjects_dir=tmp_path, subject="fsaverage") diff --git a/mne/gui/tests/test_gui_api.py b/mne/gui/tests/test_gui_api.py index 8e693cf65ff..e4d1c7887dd 100644 --- a/mne/gui/tests/test_gui_api.py +++ b/mne/gui/tests/test_gui_api.py @@ -9,59 +9,62 @@ # These will skip all tests in this scope pytestmark = pytest.mark.skipif( - sys.platform.startswith('win'), reason='nbexec does not work on Windows') -pytest.importorskip('nibabel') + sys.platform.startswith("win"), reason="nbexec does not work on Windows" +) +pytest.importorskip("nibabel") -def test_gui_api(renderer_notebook, nbexec, *, n_warn=0, backend='qt'): +def test_gui_api(renderer_notebook, nbexec, *, n_warn=0, backend="qt"): """Test GUI API.""" import contextlib import mne import warnings import sys + try: # Function backend # noqa except Exception: # Notebook standalone mode - backend = 'notebook' + backend = "notebook" n_warn = 0 # nbexec does not expose renderer_notebook so I use a # temporary variable to synchronize the tests - if backend == 'notebook': - mne.viz.set_3d_backend('notebook') + if backend == "notebook": + mne.viz.set_3d_backend("notebook") renderer = mne.viz.backends.renderer._get_renderer(size=(300, 300)) # theme with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - renderer._window_set_theme('/does/not/exist') - if backend == 'qt': + warnings.simplefilter("always") + 
renderer._window_set_theme("/does/not/exist") + if backend == "qt": assert len(w) == 1 - assert 'not found' in str(w[0].message), str(w[0].message) + assert "not found" in str(w[0].message), str(w[0].message) else: assert len(w) == 0 with mne.utils._record_warnings() as w: - renderer._window_set_theme('dark') - w = [ww for ww in w if 'is not yet supported' in str(ww.message)] - if sys.platform != 'darwin': # sometimes this is fine + renderer._window_set_theme("dark") + w = [ww for ww in w if "is not yet supported" in str(ww.message)] + if sys.platform != "darwin": # sometimes this is fine assert len(w) == n_warn, [ww.message for ww in w] # window without 3d plotter - if backend == 'qt': + if backend == "qt": window = renderer._window_create() widget = renderer._window_create() - central_layout = renderer._layout_create(orientation='grid') + central_layout = renderer._layout_create(orientation="grid") renderer._layout_add_widget(central_layout, widget, row=0, col=0) - renderer._window_initialize(window=window, - central_layout=central_layout) + renderer._window_initialize(window=window, central_layout=central_layout) from unittest.mock import Mock + mock = Mock() @contextlib.contextmanager - def _check_widget_trigger(widget, mock, before, after, call_count=True, - get_value=True): + def _check_widget_trigger( + widget, mock, before, after, call_count=True, get_value=True + ): if get_value: assert widget.get_value() == before old_call_count = mock.call_count @@ -74,16 +77,16 @@ def _check_widget_trigger(widget, mock, before, after, call_count=True, assert mock.call_count == old_call_count + 1 # --- BEGIN: dock --- - renderer._dock_initialize(name='', area='left') + renderer._dock_initialize(name="", area="left") # label (not interactive) widget = renderer._dock_add_label( - value='', + value="", align=False, selectable=True, ) widget = renderer._dock_add_label( - value='', + value="", align=True, ) widget.update() @@ -95,17 +98,17 @@ def _check_widget_trigger(widget, mock, before, after, call_count=True, # ToolButton widget = renderer._dock_add_button( - name='', + name="", callback=mock, - style='toolbutton', - tooltip='button', + style="toolbutton", + tooltip="button", ) with _check_widget_trigger(widget, mock, None, None, get_value=False): widget.set_value(True) # PushButton widget = renderer._dock_add_button( - name='', + name="", callback=mock, ) with _check_widget_trigger(widget, mock, None, None, get_value=False): @@ -113,36 +116,36 @@ def _check_widget_trigger(widget, mock, before, after, call_count=True, # slider widget = renderer._dock_add_slider( - name='', + name="", value=0, rng=[0, 10], callback=mock, - tooltip='slider', + tooltip="slider", ) with _check_widget_trigger(widget, mock, 0, 5): widget.set_value(5) # check box widget = renderer._dock_add_check_box( - name='', + name="", value=False, callback=mock, - tooltip='check box', + tooltip="check box", ) with _check_widget_trigger(widget, mock, False, True): widget.set_value(True) # spin box renderer._dock_add_spin_box( - name='', + name="", value=0, rng=[0, 1], callback=mock, step=0.1, - tooltip='spin box', + tooltip="spin box", ) widget = renderer._dock_add_spin_box( - name='', + name="", value=0, rng=[0, 1], callback=mock, @@ -153,71 +156,66 @@ def _check_widget_trigger(widget, mock, before, after, call_count=True, # combo box widget = renderer._dock_add_combo_box( - name='', - value='foo', - rng=['foo', 'bar'], + name="", + value="foo", + rng=["foo", "bar"], callback=mock, - tooltip='combo box', + tooltip="combo box", ) - 
with _check_widget_trigger(widget, mock, 'foo', 'bar'): - widget.set_value('bar') + with _check_widget_trigger(widget, mock, "foo", "bar"): + widget.set_value("bar") # radio buttons widget = renderer._dock_add_radio_buttons( - value='foo', - rng=['foo', 'bar'], + value="foo", + rng=["foo", "bar"], callback=mock, ) with _check_widget_trigger(widget, mock, None, None, get_value=False): - widget.set_value(1, 'bar') - assert widget.get_value(0) == 'foo' - assert widget.get_value(1) == 'bar' + widget.set_value(1, "bar") + assert widget.get_value(0) == "foo" + assert widget.get_value(1) == "bar" widget.set_enabled(False) # text field widget = renderer._dock_add_text( - name='', - value='foo', - placeholder='', + name="", + value="foo", + placeholder="", callback=mock, ) - with _check_widget_trigger(widget, mock, 'foo', 'bar'): - widget.set_value('bar') + with _check_widget_trigger(widget, mock, "foo", "bar"): + widget.set_value("bar") widget.set_style(dict(border="2px solid #ff0000")) # file button renderer._dock_add_file_button( - name='', - desc='', + name="", + desc="", func=mock, is_directory=True, - tooltip='file button', + tooltip="file button", ) renderer._dock_add_file_button( - name='', - desc='', + name="", + desc="", func=mock, - initial_directory='', + initial_directory="", ) renderer._dock_add_file_button( - name='', - desc='', - func=mock, - ) - widget = renderer._dock_add_file_button( - name='', - desc='', + name="", + desc="", func=mock, - save=True ) + widget = renderer._dock_add_file_button(name="", desc="", func=mock, save=True) # XXX: the internal file dialogs may hang without signals widget.set_enabled(False) - renderer._dock_initialize(name='', area='right') - renderer._dock_named_layout(name='') + renderer._dock_initialize(name="", area="right") + renderer._dock_named_layout(name="") for collapse in (None, True, False): - renderer._dock_add_group_box(name='', collapse=collapse) + renderer._dock_add_group_box(name="", collapse=collapse) renderer._dock_add_stretch() renderer._dock_add_layout() renderer._dock_finalize() @@ -232,75 +230,75 @@ def _check_widget_trigger(widget, mock, before, after, call_count=True, ) # button - assert 'reset' not in renderer.actions + assert "reset" not in renderer.actions renderer._tool_bar_add_button( - name='reset', - desc='', + name="reset", + desc="", func=mock, - icon_name='help', + icon_name="help", ) - assert 'reset' in renderer.actions + assert "reset" in renderer.actions # icon renderer._tool_bar_update_button_icon( - name='reset', - icon_name='reset', + name="reset", + icon_name="reset", ) # text renderer._tool_bar_add_text( - name='', - value='', - placeholder='', + name="", + value="", + placeholder="", ) # spacer renderer._tool_bar_add_spacer() # file button - assert 'help' not in renderer.actions + assert "help" not in renderer.actions renderer._tool_bar_add_file_button( - name='help', - desc='', + name="help", + desc="", func=mock, shortcut=None, ) - renderer.actions['help'].trigger() - if renderer._kind == 'qt': + renderer.actions["help"].trigger() + if renderer._kind == "qt": dialog = renderer._window.children()[-1] - assert 'FileDialog' in repr(dialog) + assert "FileDialog" in repr(dialog) dialog.close() dialog.deleteLater() # play button - assert 'play' not in renderer.actions + assert "play" not in renderer.actions renderer._tool_bar_add_play_button( - name='play', - desc='', + name="play", + desc="", func=mock, shortcut=None, ) - assert 'play' in renderer.actions + assert "play" in renderer.actions # --- END: tool bar --- # 
--- BEGIN: menu bar --- renderer._menu_initialize() # submenu - renderer._menu_add_submenu(name='foo', desc='foo') - assert 'foo' in renderer._menus - assert 'foo' in renderer._menu_actions + renderer._menu_add_submenu(name="foo", desc="foo") + assert "foo" in renderer._menus + assert "foo" in renderer._menu_actions # button renderer._menu_add_button( - menu_name='foo', - name='bar', - desc='bar', + menu_name="foo", + name="bar", + desc="bar", func=mock, ) - assert 'bar' in renderer._menu_actions['foo'] - with _check_widget_trigger(None, mock, '', '', get_value=False): - renderer._menu_actions['foo']['bar'].trigger() + assert "bar" in renderer._menu_actions["foo"] + with _check_widget_trigger(None, mock, "", "", get_value=False): + renderer._menu_actions["foo"]["bar"].trigger() # --- END: menu bar --- @@ -309,8 +307,8 @@ def _check_widget_trigger(widget, mock, before, after, call_count=True, renderer._status_bar_update() # label - widget = renderer._status_bar_add_label(value='foo', stretch=0) - assert widget.get_value() == 'foo' + widget = renderer._status_bar_add_label(value="foo", stretch=0) + assert widget.get_value() == "foo" # progress bar widget = renderer._status_bar_add_progress_bar(stretch=0) @@ -320,74 +318,70 @@ def _check_widget_trigger(widget, mock, before, after, call_count=True, # --- END: status bar --- # --- BEGIN: tooltips --- - widget = renderer._dock_add_button( - name='', - callback=mock, - tooltip='foo' - ) - assert widget.get_tooltip() == 'foo' + widget = renderer._dock_add_button(name="", callback=mock, tooltip="foo") + assert widget.get_tooltip() == "foo" # Change it … - widget.set_tooltip('bar') - assert widget.get_tooltip() == 'bar' + widget.set_tooltip("bar") + assert widget.get_tooltip() == "bar" # --- END: tooltips --- # --- BEGIN: dialog --- # dialogs are not supported yet on notebook - if renderer._kind == 'qt': + if renderer._kind == "qt": # warning buttons = ["Save", "Cancel"] widget = renderer._dialog_create( - title='', - text='', - info_text='', + title="", + text="", + info_text="", callback=mock, buttons=buttons, modal=False, ) widget.show() for button in buttons: - with _check_widget_trigger(None, mock, '', '', get_value=False): + with _check_widget_trigger(None, mock, "", "", get_value=False): widget.trigger(button=button) assert mock.call_args.args == (button,) assert not widget._widget.isVisible() # buttons list empty means OK button (default) - button = 'Ok' + button = "Ok" widget = renderer._dialog_create( - title='', - text='', - info_text='', + title="", + text="", + info_text="", callback=mock, - icon='NoIcon', + icon="NoIcon", modal=False, ) widget.show() - with _check_widget_trigger(None, mock, '', '', get_value=False): + with _check_widget_trigger(None, mock, "", "", get_value=False): widget.trigger(button=button) assert mock.call_args.args == (button,) - widget.trigger(button='Ok') + widget.trigger(button="Ok") # --- END: dialog --- # --- BEGIN: keypress --- renderer._keypress_initialize() - renderer._keypress_add('a', mock) + renderer._keypress_add("a", mock) # keypress is not supported yet on notebook - if renderer._kind == 'qt': - with _check_widget_trigger(None, mock, '', '', get_value=False): - renderer._keypress_trigger('a') + if renderer._kind == "qt": + with _check_widget_trigger(None, mock, "", "", get_value=False): + renderer._keypress_trigger("a") # --- END: keypress --- renderer.show() - renderer._window_close_connect(lambda: mock('first'), after=False) - renderer._window_close_connect(lambda: mock('last')) + 
renderer._window_close_connect(lambda: mock("first"), after=False) + renderer._window_close_connect(lambda: mock("last")) old_call_count = mock.call_count renderer.close() - if renderer._kind == 'qt': + if renderer._kind == "qt": assert mock.call_count == old_call_count + 2 - assert mock.call_args_list[-1].args == ('last',) - assert mock.call_args_list[-2].args == ('first',) + assert mock.call_args_list[-1].args == ("last",) + assert mock.call_args_list[-2].args == ("first",) assert renderer._window.isVisible() is False del renderer @@ -395,10 +389,10 @@ def _check_widget_trigger(widget, mock, before, after, call_count=True, def test_gui_api_qt(renderer_interactive_pyvistaqt): """Test GUI API with the Qt backend.""" _, api = _check_qt_version(return_api=True) - n_warn = int(api in ('PySide6', 'PyQt6')) + n_warn = int(api in ("PySide6", "PyQt6")) # TODO: After merging https://github.com/mne-tools/mne-python/pull/11567 # The Qt CI run started failing about 50% of the time, so let's skip this # for now. - if api == 'PySide6': - pytest.skip('PySide6 causes segfaults on CIs sometimes') - test_gui_api(None, None, n_warn=n_warn, backend='qt') + if api == "PySide6": + pytest.skip("PySide6 causes segfaults on CIs sometimes") + test_gui_api(None, None, n_warn=n_warn, backend="qt") diff --git a/mne/gui/tests/test_ieeg_locate.py b/mne/gui/tests/test_ieeg_locate.py index 6c086e73260..2ad39d0b320 100644 --- a/mne/gui/tests/test_ieeg_locate.py +++ b/mne/gui/tests/test_ieeg_locate.py @@ -22,21 +22,24 @@ # Module-level ignore pytestmark = pytest.mark.filterwarnings( - 'ignore:.*locate_ieeg.*deprecated.*:FutureWarning') + "ignore:.*locate_ieeg.*deprecated.*:FutureWarning" +) @pytest.fixture def _fake_CT_coords(skull_size=5, contact_size=2): """Make somewhat realistic CT data with contacts.""" - nib = pytest.importorskip('nibabel') + nib = pytest.importorskip("nibabel") brain = nib.load(subjects_dir / subject / "mri" / "brain.mgz") - verts = mne.read_surface( - subjects_dir / subject / "bem" / "outer_skull.surf" - )[0] + verts = mne.read_surface(subjects_dir / subject / "bem" / "outer_skull.surf")[0] verts = apply_trans(np.linalg.inv(brain.header.get_vox2ras_tkr()), verts) x, y, z = np.array(brain.shape).astype(int) // 2 - coords = [(x, y - 14, z), (x - 10, y - 15, z), - (x - 20, y - 16, z + 1), (x - 30, y - 16, z + 1)] + coords = [ + (x, y - 14, z), + (x - 10, y - 15, z), + (x - 20, y - 16, z + 1), + (x - 30, y - 16, z + 1), + ] center = np.array(brain.shape) / 2 # make image np.random.seed(99) @@ -44,18 +47,22 @@ def _fake_CT_coords(skull_size=5, contact_size=2): # make skull for vert in verts: x, y, z = np.round(vert).astype(int) - ct_data[slice(x - skull_size, x + skull_size + 1), - slice(y - skull_size, y + skull_size + 1), - slice(z - skull_size, z + skull_size + 1)] = 1000 + ct_data[ + slice(x - skull_size, x + skull_size + 1), + slice(y - skull_size, y + skull_size + 1), + slice(z - skull_size, z + skull_size + 1), + ] = 1000 # add electrode with contacts - for (x, y, z) in coords: + for x, y, z in coords: # make sure not in skull assert np.linalg.norm(center - np.array((x, y, z))) < 50 - ct_data[slice(x - contact_size, x + contact_size + 1), - slice(y - contact_size, y + contact_size + 1), - slice(z - contact_size, z + contact_size + 1)] = \ - 1000 - np.linalg.norm(np.array(np.meshgrid( - *[range(-contact_size, contact_size + 1)] * 3)), axis=0) + ct_data[ + slice(x - contact_size, x + contact_size + 1), + slice(y - contact_size, y + contact_size + 1), + slice(z - contact_size, z + contact_size + 1), 
+ ] = 1000 - np.linalg.norm( + np.array(np.meshgrid(*[range(-contact_size, contact_size + 1)] * 3)), axis=0 + ) ct = nib.MGHImage(ct_data, brain.affine) coords = apply_trans(ct.header.get_vox2ras_tkr(), np.array(coords)) return ct, coords @@ -63,56 +70,58 @@ def _fake_CT_coords(skull_size=5, contact_size=2): def test_ieeg_elec_locate_io(renderer_interactive_pyvistaqt): """Test the input/output of the intracranial location GUI.""" - nib = pytest.importorskip('nibabel') + nib = pytest.importorskip("nibabel") import mne.gui + info = mne.create_info([], 1000) # fake as T1 so that aligned aligned_ct = nib.load(subjects_dir / subject / "mri" / "brain.mgz") - trans = mne.transforms.Transform('head', 'mri') - with pytest.raises(ValueError, - match='No channels found in `info` to locate'): + trans = mne.transforms.Transform("head", "mri") + with pytest.raises(ValueError, match="No channels found in `info` to locate"): mne.gui.locate_ieeg(info, trans, aligned_ct, subject, subjects_dir) - info = mne.create_info(['test'], 1000, 'seeg') - montage = mne.channels.make_dig_montage( - {'test': [0, 0, 0]}, coord_frame='mri') - with pytest.warns(RuntimeWarning, match='nasion not found'): + info = mne.create_info(["test"], 1000, "seeg") + montage = mne.channels.make_dig_montage({"test": [0, 0, 0]}, coord_frame="mri") + with pytest.warns(RuntimeWarning, match="nasion not found"): info.set_montage(montage) - with pytest.raises(RuntimeError, - match='must be in the "head" coordinate frame'): - with pytest.warns(RuntimeWarning, match='`pial` surface not found'): + with pytest.raises(RuntimeError, match='must be in the "head" coordinate frame'): + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): mne.gui.locate_ieeg(info, trans, aligned_ct, subject, subjects_dir) -@requires_version('sphinx_gallery') +@requires_version("sphinx_gallery") @testing.requires_testing_data -def test_locate_scraper(renderer_interactive_pyvistaqt, _fake_CT_coords, - tmp_path): +def test_locate_scraper(renderer_interactive_pyvistaqt, _fake_CT_coords, tmp_path): """Test sphinx-gallery scraping of the GUI.""" import mne.gui + raw = mne.io.read_raw_fif(raw_path) raw.pick_types(eeg=True) - ch_dict = {'EEG 001': 'LAMY 1', 'EEG 002': 'LAMY 2', - 'EEG 003': 'LSTN 1', 'EEG 004': 'LSTN 2'} + ch_dict = { + "EEG 001": "LAMY 1", + "EEG 002": "LAMY 2", + "EEG 003": "LSTN 1", + "EEG 004": "LSTN 2", + } raw.pick_channels(list(ch_dict.keys())) raw.rename_channels(ch_dict) raw.set_montage(None) aligned_ct, _ = _fake_CT_coords trans = mne.read_trans(fname_trans) - with pytest.warns(RuntimeWarning, match='`pial` surface not found'): + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): gui = mne.gui.locate_ieeg( - raw.info, trans, aligned_ct, - subject=subject, subjects_dir=subjects_dir) - (tmp_path / '_images').mkdir() - image_path = tmp_path / '_images' / 'temp.png' - gallery_conf = dict(builder_name='html', src_dir=tmp_path) + raw.info, trans, aligned_ct, subject=subject, subjects_dir=subjects_dir + ) + (tmp_path / "_images").mkdir() + image_path = tmp_path / "_images" / "temp.png" + gallery_conf = dict(builder_name="html", src_dir=tmp_path) block_vars = dict( - example_globals=dict(gui=gui), - image_path_iterator=iter([str(image_path)])) + example_globals=dict(gui=gui), image_path_iterator=iter([str(image_path)]) + ) assert not image_path.is_file() - assert not getattr(gui, '_scraped', False) + assert not getattr(gui, "_scraped", False) mne.gui._GUIScraper()(None, block_vars, gallery_conf) assert image_path.is_file() 
assert gui._scraped @@ -120,45 +129,55 @@ def test_locate_scraper(renderer_interactive_pyvistaqt, _fake_CT_coords, @testing.requires_testing_data -def test_ieeg_elec_locate_display(renderer_interactive_pyvistaqt, - _fake_CT_coords): +def test_ieeg_elec_locate_display(renderer_interactive_pyvistaqt, _fake_CT_coords): """Test that the intracranial location GUI displays properly.""" raw = mne.io.read_raw_fif(raw_path, preload=True) raw.pick_types(eeg=True) - ch_dict = {'EEG 001': 'LAMY 1', 'EEG 002': 'LAMY 2', - 'EEG 003': 'LSTN 1', 'EEG 004': 'LSTN 2'} + ch_dict = { + "EEG 001": "LAMY 1", + "EEG 002": "LAMY 2", + "EEG 003": "LSTN 1", + "EEG 004": "LSTN 2", + } raw.pick_channels(list(ch_dict.keys())) raw.rename_channels(ch_dict) - raw.set_eeg_reference('average') - raw.set_channel_types({name: 'seeg' for name in raw.ch_names}) + raw.set_eeg_reference("average") + raw.set_channel_types({name: "seeg" for name in raw.ch_names}) raw.set_montage(None) aligned_ct, coords = _fake_CT_coords trans = mne.read_trans(fname_trans) - with pytest.warns(RuntimeWarning, match='`pial` surface not found'): + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): gui = mne.gui.locate_ieeg( - raw.info, trans, aligned_ct, - subject=subject, subjects_dir=subjects_dir, - verbose=True) - - with pytest.raises(ValueError, match='read-only'): + raw.info, + trans, + aligned_ct, + subject=subject, + subjects_dir=subjects_dir, + verbose=True, + ) + + with pytest.raises(ValueError, match="read-only"): gui._ras[:] = coords[0] # start in the right position gui.set_RAS(coords[0]) gui.mark_channel() - with pytest.raises(ValueError, match='not found'): - gui.mark_channel('foo') + with pytest.raises(ValueError, match="not found"): + gui.mark_channel("foo") assert not gui._lines and not gui._lines_2D # no lines for one contact for ci, coord in enumerate(coords[1:], 1): coord_vox = apply_trans(gui._ras_vox_t, coord) - with use_log_level('debug'): - _fake_click(gui._figs[2], gui._figs[2].axes[0], - coord_vox[:-1], xform='data', kind='release') - assert_allclose(coord[:2], gui._ras[:2], atol=0.1, - err_msg=f'coords[{ci}][:2]') - assert_allclose(coord[2], gui._ras[2], atol=2, - err_msg=f'coords[{ci}][2]') + with use_log_level("debug"): + _fake_click( + gui._figs[2], + gui._figs[2].axes[0], + coord_vox[:-1], + xform="data", + kind="release", + ) + assert_allclose(coord[:2], gui._ras[:2], atol=0.1, err_msg=f"coords[{ci}][:2]") + assert_allclose(coord[2], gui._ras[2], atol=2, err_msg=f"coords[{ci}][2]") gui.mark_channel() # ensure a 3D line was made for each group @@ -168,53 +187,56 @@ def test_ieeg_elec_locate_display(renderer_interactive_pyvistaqt, gui._ch_index = 0 gui.set_RAS(coords[0]) # move to first position gui.mark_channel() - assert_allclose(coords[0], gui._chs['LAMY 1'], atol=0.2) + assert_allclose(coords[0], gui._chs["LAMY 1"], atol=0.2) gui._snap_button.click() - assert gui._snap_button.text() == 'Off' + assert gui._snap_button.text() == "Off" # now make sure no snap happens gui._ch_index = 0 gui.set_RAS(coords[1] + 1) gui.mark_channel() - assert_allclose(coords[1] + 1, gui._chs['LAMY 1'], atol=0.01) + assert_allclose(coords[1] + 1, gui._chs["LAMY 1"], atol=0.01) # check that it turns back on gui._snap_button.click() - assert gui._snap_button.text() == 'On' + assert gui._snap_button.text() == "On" # test remove - gui.remove_channel('LAMY 2') - assert np.isnan(gui._chs['LAMY 2']).all() + gui.remove_channel("LAMY 2") + assert np.isnan(gui._chs["LAMY 2"]).all() - with pytest.raises(ValueError, match='not found'): 
- gui.remove_channel('foo') + with pytest.raises(ValueError, match="not found"): + gui.remove_channel("foo") # check that raw object saved - assert not np.isnan(raw.info['chs'][0]['loc'][:3]).any() # LAMY 1 - assert np.isnan(raw.info['chs'][1]['loc'][:3]).all() # LAMY 2 (removed) + assert not np.isnan(raw.info["chs"][0]["loc"][:3]).any() # LAMY 1 + assert np.isnan(raw.info["chs"][1]["loc"][:3]).all() # LAMY 2 (removed) # move sliders gui._alpha_slider.setValue(75) assert gui._ch_alpha == 0.75 gui._radius_slider.setValue(5) assert gui._radius == 5 - ct_sum_before = np.nansum(gui._images['ct'][0].get_array().data) + ct_sum_before = np.nansum(gui._images["ct"][0].get_array().data) gui._ct_min_slider.setValue(500) - assert np.nansum(gui._images['ct'][0].get_array().data) < ct_sum_before + assert np.nansum(gui._images["ct"][0].get_array().data) < ct_sum_before # test buttons gui._toggle_show_brain() - assert 'mri' in gui._images - assert 'local_max' not in gui._images + assert "mri" in gui._images + assert "local_max" not in gui._images gui._toggle_show_max() - assert 'local_max' in gui._images - assert 'mip' not in gui._images + assert "local_max" in gui._images + assert "mip" not in gui._images gui._toggle_show_mip() - assert 'mip' in gui._images - assert 'mip_chs' in gui._images + assert "mip" in gui._images + assert "mip_chs" in gui._images assert len(gui._lines_2D) == 1 # LAMY only has one contact # check montage montage = raw.get_montage() assert montage is not None - assert_allclose(montage.get_positions()['ch_pos']['LAMY 1'], - [0.00726235, 0.01713514, 0.04167233], atol=0.01) + assert_allclose( + montage.get_positions()["ch_pos"]["LAMY 1"], + [0.00726235, 0.01713514, 0.04167233], + atol=0.01, + ) gui.close() diff --git a/mne/html_templates/_templates.py b/mne/html_templates/_templates.py index 28fd93617e5..5204cad131f 100644 --- a/mne/html_templates/_templates.py +++ b/mne/html_templates/_templates.py @@ -1,25 +1,18 @@ import jinja2 -autoescape = jinja2.select_autoescape( - default=True, - default_for_string=True -) +autoescape = jinja2.select_autoescape(default=True, default_for_string=True) # For _html_repr_() repr_templates_env = jinja2.Environment( - loader=jinja2.PackageLoader( - package_name='mne.html_templates', - package_path='repr' - ), - autoescape=autoescape + loader=jinja2.PackageLoader(package_name="mne.html_templates", package_path="repr"), + autoescape=autoescape, ) # For mne.Report report_templates_env = jinja2.Environment( loader=jinja2.PackageLoader( - package_name='mne.html_templates', - package_path='report' + package_name="mne.html_templates", package_path="report" ), - autoescape=autoescape + autoescape=autoescape, ) -report_templates_env.filters['zip'] = zip +report_templates_env.filters["zip"] = zip diff --git a/mne/inverse_sparse/__init__.py b/mne/inverse_sparse/__init__.py index 867becd38a5..a90b27f7ab1 100644 --- a/mne/inverse_sparse/__init__.py +++ b/mne/inverse_sparse/__init__.py @@ -4,6 +4,5 @@ # # License: Simplified BSD -from .mxne_inverse import (mixed_norm, tf_mixed_norm, - make_stc_from_dipoles) +from .mxne_inverse import mixed_norm, tf_mixed_norm, make_stc_from_dipoles from ._gamma_map import gamma_map diff --git a/mne/inverse_sparse/_gamma_map.py b/mne/inverse_sparse/_gamma_map.py index 6f71cbedae7..8e864df3837 100644 --- a/mne/inverse_sparse/_gamma_map.py +++ b/mne/inverse_sparse/_gamma_map.py @@ -7,14 +7,28 @@ from ..forward import is_fixed_orient from ..minimum_norm.inverse import _check_reference, _log_exp_var from ..utils import logger, verbose, 
warn -from .mxne_inverse import (_check_ori, _make_sparse_stc, _prepare_gain, - _reapply_source_weighting, _compute_residual, - _make_dipoles_sparse) +from .mxne_inverse import ( + _check_ori, + _make_sparse_stc, + _prepare_gain, + _reapply_source_weighting, + _compute_residual, + _make_dipoles_sparse, +) @verbose -def _gamma_map_opt(M, G, alpha, maxit=10000, tol=1e-6, update_mode=1, - group_size=1, gammas=None, verbose=None): +def _gamma_map_opt( + M, + G, + alpha, + maxit=10000, + tol=1e-6, + update_mode=1, + group_size=1, + gammas=None, + verbose=None, +): """Hierarchical Bayes (Gamma-MAP). Parameters @@ -46,6 +60,7 @@ def _gamma_map_opt(M, G, alpha, maxit=10000, tol=1e-6, update_mode=1, Indices of active sources. """ from scipy import linalg + G = G.copy() M = M.copy() @@ -58,15 +73,16 @@ def _gamma_map_opt(M, G, alpha, maxit=10000, tol=1e-6, update_mode=1, n_sensors, n_times = M.shape # apply normalization so the numerical values are sane - M_normalize_constant = np.linalg.norm(np.dot(M, M.T), ord='fro') + M_normalize_constant = np.linalg.norm(np.dot(M, M.T), ord="fro") M /= np.sqrt(M_normalize_constant) alpha /= M_normalize_constant G_normalize_constant = np.linalg.norm(G, ord=np.inf) G /= G_normalize_constant if n_sources % group_size != 0: - raise ValueError('Number of sources has to be evenly dividable by the ' - 'group size') + raise ValueError( + "Number of sources has to be evenly dividable by the " "group size" + ) n_active = n_sources active_set = np.arange(n_sources) @@ -84,7 +100,7 @@ def denom_fun(x): for itno in range(maxit): gammas[np.isnan(gammas)] = 0.0 - gidx = (np.abs(gammas) > eps) + gidx = np.abs(gammas) > eps active_set = active_set[gidx] gammas = gammas[gidx] @@ -94,7 +110,7 @@ def denom_fun(x): G = G[:, gidx] CM = np.dot(G * gammas[np.newaxis, :], G.T) - CM.flat[::n_sensors + 1] += alpha + CM.flat[:: n_sensors + 1] += alpha # Invert CM keeping symmetry U, S, _ = linalg.svd(CM, full_matrices=False) S = S[np.newaxis, :] @@ -105,21 +121,20 @@ def denom_fun(x): if update_mode == 1: # MacKay fixed point update (10) in [1] - numer = gammas ** 2 * np.mean((A * A.conj()).real, axis=1) + numer = gammas**2 * np.mean((A * A.conj()).real, axis=1) denom = gammas * np.sum(G * CMinvG, axis=0) elif update_mode == 2: # modified MacKay fixed point update (11) in [1] numer = gammas * np.sqrt(np.mean((A * A.conj()).real, axis=1)) denom = np.sum(G * CMinvG, axis=0) # sqrt is applied below else: - raise ValueError('Invalid value for update_mode') + raise ValueError("Invalid value for update_mode") if group_size == 1: if denom is None: gammas = numer else: - gammas = numer / np.maximum(denom_fun(denom), - np.finfo('float').eps) + gammas = numer / np.maximum(denom_fun(denom), np.finfo("float").eps) else: numer_comb = np.sum(numer.reshape(-1, group_size), axis=1) if denom is None: @@ -134,24 +149,27 @@ def denom_fun(x): gammas_full = np.zeros(n_sources, dtype=np.float64) gammas_full[active_set] = gammas - err = (np.sum(np.abs(gammas_full - gammas_full_old)) / - np.sum(np.abs(gammas_full_old))) + err = np.sum(np.abs(gammas_full - gammas_full_old)) / np.sum( + np.abs(gammas_full_old) + ) gammas_full_old = gammas_full - breaking = (err < tol or n_active == 0) + breaking = err < tol or n_active == 0 if len(gammas) != last_size or breaking: - logger.info('Iteration: %d\t active set size: %d\t convergence: ' - '%0.3e' % (itno, len(gammas), err)) + logger.info( + "Iteration: %d\t active set size: %d\t convergence: " + "%0.3e" % (itno, len(gammas), err) + ) last_size = len(gammas) if breaking: 
break if itno < maxit - 1: - logger.info('\nConvergence reached !\n') + logger.info("\nConvergence reached !\n") else: - warn('\nConvergence NOT reached !\n') + warn("\nConvergence NOT reached !\n") # undo normalization and compute final posterior mean n_const = np.sqrt(M_normalize_constant) / G_normalize_constant @@ -161,10 +179,25 @@ def denom_fun(x): @verbose -def gamma_map(evoked, forward, noise_cov, alpha, loose="auto", depth=0.8, - xyz_same_gamma=True, maxit=10000, tol=1e-6, update_mode=1, - gammas=None, pca=True, return_residual=False, - return_as_dipoles=False, rank=None, pick_ori=None, verbose=None): +def gamma_map( + evoked, + forward, + noise_cov, + alpha, + loose="auto", + depth=0.8, + xyz_same_gamma=True, + maxit=10000, + tol=1e-6, + update_mode=1, + gammas=None, + pca=True, + return_residual=False, + return_as_dipoles=False, + rank=None, + pick_ori=None, + verbose=None, +): """Hierarchical Bayes (Gamma-MAP) sparse source localization method. Models each source time course using a zero-mean Gaussian prior with an @@ -228,23 +261,32 @@ def gamma_map(evoked, forward, noise_cov, alpha, loose="auto", depth=0.8, _check_reference(evoked) forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain( - forward, evoked.info, noise_cov, pca, depth, loose, rank) + forward, evoked.info, noise_cov, pca, depth, loose, rank + ) _check_ori(pick_ori, forward) group_size = 1 if (is_fixed_orient(forward) or not xyz_same_gamma) else 3 # get the data - sel = [evoked.ch_names.index(name) for name in gain_info['ch_names']] + sel = [evoked.ch_names.index(name) for name in gain_info["ch_names"]] M = evoked.data[sel] # whiten the data - logger.info('Whitening data matrix.') + logger.info("Whitening data matrix.") M = np.dot(whitener, M) # run the optimization - X, active_set = _gamma_map_opt(M, gain, alpha, maxit=maxit, tol=tol, - update_mode=update_mode, gammas=gammas, - group_size=group_size, verbose=verbose) + X, active_set = _gamma_map_opt( + M, + gain, + alpha, + maxit=maxit, + tol=tol, + update_mode=update_mode, + gammas=gammas, + group_size=group_size, + verbose=verbose, + ) if len(active_set) == 0: raise Exception("No active dipoles found. 
alpha is too big.") @@ -255,8 +297,7 @@ def gamma_map(evoked, forward, noise_cov, alpha, loose="auto", depth=0.8, X = _reapply_source_weighting(X, source_weighting, active_set) if return_residual: - residual = _compute_residual(forward, evoked, X, active_set, - gain_info) + residual = _compute_residual(forward, evoked, X, active_set, gain_info) if group_size == 1 and not is_fixed_orient(forward): # make sure each source has 3 components @@ -274,18 +315,26 @@ def gamma_map(evoked, forward, noise_cov, alpha, loose="auto", depth=0.8, del source_weighting tmin = evoked.times[0] - tstep = 1.0 / evoked.info['sfreq'] + tstep = 1.0 / evoked.info["sfreq"] if return_as_dipoles: - out = _make_dipoles_sparse(X, active_set, forward, tmin, tstep, M, - gain_active, active_is_idx=True) + out = _make_dipoles_sparse( + X, active_set, forward, tmin, tstep, M, gain_active, active_is_idx=True + ) else: - out = _make_sparse_stc(X, active_set, forward, tmin, tstep, - active_is_idx=True, pick_ori=pick_ori, - verbose=verbose) - - _log_exp_var(M, M_estimate, prefix='') - logger.info('[done]') + out = _make_sparse_stc( + X, + active_set, + forward, + tmin, + tstep, + active_is_idx=True, + pick_ori=pick_ori, + verbose=verbose, + ) + + _log_exp_var(M, M_estimate, prefix="") + logger.info("[done]") if return_residual: out = out, residual diff --git a/mne/inverse_sparse/mxne_debiasing.py b/mne/inverse_sparse/mxne_debiasing.py index 1ea3ca6f95d..472afc1242c 100644 --- a/mne/inverse_sparse/mxne_debiasing.py +++ b/mne/inverse_sparse/mxne_debiasing.py @@ -37,13 +37,13 @@ def power_iteration_kron(A, C, max_iter=1000, tol=1e-3, random_state=0): AS_size = C.shape[0] rng = check_random_state(random_state) B = rng.randn(AS_size, AS_size) - B /= np.linalg.norm(B, 'fro') + B /= np.linalg.norm(B, "fro") ATA = np.dot(A.T, A) CCT = np.dot(C, C.T) L0 = np.inf for _ in range(max_iter): Y = np.dot(np.dot(ATA, B), CCT) - L = np.linalg.norm(Y, 'fro') + L = np.linalg.norm(Y, "fro") if abs(L - L0) < tol: break @@ -115,7 +115,7 @@ def compute_bias(M, G, X, max_iter=1000, tol=1e-6, n_orient=1, verbose=None): D = np.maximum(D, 1.0) t0 = t - t = 0.5 * (1.0 + sqrt(1.0 + 4.0 * t ** 2)) + t = 0.5 * (1.0 + sqrt(1.0 + 4.0 * t**2)) Y.fill(0.0) dt = (t0 - 1.0) / t Y = D + dt * (D - D0) @@ -123,11 +123,15 @@ def compute_bias(M, G, X, max_iter=1000, tol=1e-6, n_orient=1, verbose=None): Ddiff = np.linalg.norm(D - D0, np.inf) if Ddiff < tol: - logger.info("Debiasing converged after %d iterations " - "max(|D - D0| = %e < %e)" % (i, Ddiff, tol)) + logger.info( + "Debiasing converged after %d iterations " + "max(|D - D0| = %e < %e)" % (i, Ddiff, tol) + ) break else: Ddiff = np.linalg.norm(D - D0, np.inf) - logger.info("Debiasing did not converge after %d iterations! " - "max(|D - D0| = %e >= %e)" % (max_iter, Ddiff, tol)) + logger.info( + "Debiasing did not converge after %d iterations! 
" + "max(|D - D0| = %e >= %e)" % (max_iter, Ddiff, tol) + ) return D diff --git a/mne/inverse_sparse/mxne_inverse.py b/mne/inverse_sparse/mxne_inverse.py index ac2cbc5f488..96f596cb3b5 100644 --- a/mne/inverse_sparse/mxne_inverse.py +++ b/mne/inverse_sparse/mxne_inverse.py @@ -6,25 +6,46 @@ import numpy as np from ..source_estimate import SourceEstimate, _BaseSourceEstimate, _make_stc -from ..minimum_norm.inverse import (combine_xyz, _prepare_forward, - _check_reference, _log_exp_var) +from ..minimum_norm.inverse import ( + combine_xyz, + _prepare_forward, + _check_reference, + _log_exp_var, +) from ..forward import is_fixed_orient from ..io.proj import deactivate_proj -from ..utils import (logger, verbose, _check_depth, _check_option, sum_squared, - _validate_type, check_random_state, warn) +from ..utils import ( + logger, + verbose, + _check_depth, + _check_option, + sum_squared, + _validate_type, + check_random_state, + warn, +) from ..dipole import Dipole -from .mxne_optim import (mixed_norm_solver, iterative_mixed_norm_solver, _Phi, - tf_mixed_norm_solver, iterative_tf_mixed_norm_solver, - norm_l2inf, norm_epsilon_inf, groups_norm2) +from .mxne_optim import ( + mixed_norm_solver, + iterative_mixed_norm_solver, + _Phi, + tf_mixed_norm_solver, + iterative_tf_mixed_norm_solver, + norm_l2inf, + norm_epsilon_inf, + groups_norm2, +) def _check_ori(pick_ori, forward): """Check pick_ori.""" - _check_option('pick_ori', pick_ori, [None, 'vector']) - if pick_ori == 'vector' and is_fixed_orient(forward): - raise ValueError('pick_ori="vector" cannot be combined with a fixed ' - 'orientation forward solution.') + _check_option("pick_ori", pick_ori, [None, "vector"]) + if pick_ori == "vector" and is_fixed_orient(forward): + raise ValueError( + 'pick_ori="vector" cannot be combined with a fixed ' + "orientation forward solution." 
+ ) def _prepare_weights(forward, gain, source_weighting, weights, weights_min): @@ -33,15 +54,18 @@ def _prepare_weights(forward, gain, source_weighting, weights, weights_min): weights = np.max(np.abs(weights.data), axis=1) weights_max = np.max(weights) if weights_min > weights_max: - raise ValueError('weights_min > weights_max (%s > %s)' % - (weights_min, weights_max)) + raise ValueError( + "weights_min > weights_max (%s > %s)" % (weights_min, weights_max) + ) weights_min = weights_min / weights_max weights = weights / weights_max n_dip_per_pos = 1 if is_fixed_orient(forward) else 3 weights = np.ravel(np.tile(weights, [n_dip_per_pos, 1]).T) if len(weights) != gain.shape[1]: - raise ValueError('weights do not have the correct dimension ' - ' (%d != %d)' % (len(weights), gain.shape[1])) + raise ValueError( + "weights do not have the correct dimension " + " (%d != %d)" % (len(weights), gain.shape[1]) + ) if len(source_weighting.shape) == 1: source_weighting *= weights else: @@ -49,7 +73,7 @@ def _prepare_weights(forward, gain, source_weighting, weights, weights_min): gain *= weights[None, :] if weights_min is not None: - mask = (weights > weights_min) + mask = weights > weights_min gain = gain[:, mask] n_sources = np.sum(mask) // n_dip_per_pos logger.info("Reducing source space to %d sources" % n_sources) @@ -57,18 +81,20 @@ def _prepare_weights(forward, gain, source_weighting, weights, weights_min): return gain, source_weighting, mask -def _prepare_gain(forward, info, noise_cov, pca, depth, loose, rank, - weights=None, weights_min=None): - depth = _check_depth(depth, 'depth_sparse') - forward, gain_info, gain, _, _, source_weighting, _, _, whitener = \ - _prepare_forward(forward, info, noise_cov, 'auto', loose, rank, pca, - use_cps=True, **depth) +def _prepare_gain( + forward, info, noise_cov, pca, depth, loose, rank, weights=None, weights_min=None +): + depth = _check_depth(depth, "depth_sparse") + forward, gain_info, gain, _, _, source_weighting, _, _, whitener = _prepare_forward( + forward, info, noise_cov, "auto", loose, rank, pca, use_cps=True, **depth + ) if weights is None: mask = None else: gain, source_weighting, mask = _prepare_weights( - forward, gain, source_weighting, weights, weights_min) + forward, gain, source_weighting, weights, weights_min + ) return forward, gain, gain_info, whitener, source_weighting, mask @@ -80,25 +106,26 @@ def _reapply_source_weighting(X, source_weighting, active_set): def _compute_residual(forward, evoked, X, active_set, info): # OK, picking based on row_names is safe - sel = [forward['sol']['row_names'].index(c) for c in info['ch_names']] - residual = evoked.copy().pick(info['ch_names']) + sel = [forward["sol"]["row_names"].index(c) for c in info["ch_names"]] + residual = evoked.copy().pick(info["ch_names"]) r_tmp = residual.copy() - r_tmp.data = np.dot(forward['sol']['data'][sel, :][:, active_set], X) + r_tmp.data = np.dot(forward["sol"]["data"][sel, :][:, active_set], X) # Take care of proj active_projs = list() non_active_projs = list() - for p in evoked.info['projs']: - if p['active']: + for p in evoked.info["projs"]: + if p["active"]: active_projs.append(p) else: non_active_projs.append(p) if len(active_projs) > 0: with r_tmp.info._unlock(): - r_tmp.info['projs'] = deactivate_proj(active_projs, copy=True, - verbose=False) + r_tmp.info["projs"] = deactivate_proj( + active_projs, copy=True, verbose=False + ) r_tmp.apply_proj(verbose=False) r_tmp.add_proj(non_active_projs, remove_existing=False, verbose=False) @@ -108,13 +135,21 @@ def 
_compute_residual(forward, evoked, X, active_set, info):


 @verbose
-def _make_sparse_stc(X, active_set, forward, tmin, tstep,
-                     active_is_idx=False, pick_ori=None, verbose=None):
-    source_nn = forward['source_nn']
+def _make_sparse_stc(
+    X,
+    active_set,
+    forward,
+    tmin,
+    tstep,
+    active_is_idx=False,
+    pick_ori=None,
+    verbose=None,
+):
+    source_nn = forward["source_nn"]
     vector = False
     if not is_fixed_orient(forward):
-        if pick_ori != 'vector':
-            logger.info('combining the current components...')
+        if pick_ori != "vector":
+            logger.info("combining the current components...")
             X = combine_xyz(X)
         else:
             vector = True
@@ -129,21 +164,29 @@ def _make_sparse_stc(X, active_set, forward, tmin, tstep,
     if n_dip_per_pos > 1:
         active_idx = np.unique(active_idx // n_dip_per_pos)

-    src = forward['src']
+    src = forward["src"]
     vertices = []
     n_points_so_far = 0
     for this_src in src:
-        this_n_points_so_far = n_points_so_far + len(this_src['vertno'])
-        this_active_idx = active_idx[(n_points_so_far <= active_idx) &
-                                     (active_idx < this_n_points_so_far)]
+        this_n_points_so_far = n_points_so_far + len(this_src["vertno"])
+        this_active_idx = active_idx[
+            (n_points_so_far <= active_idx) & (active_idx < this_n_points_so_far)
+        ]
         this_active_idx -= n_points_so_far
-        this_vertno = this_src['vertno'][this_active_idx]
+        this_vertno = this_src["vertno"][this_active_idx]
         n_points_so_far = this_n_points_so_far
         vertices.append(this_vertno)
     source_nn = source_nn[active_idx]
     return _make_stc(
-        X, vertices, src.kind, tmin, tstep, src[0]['subject_his_id'],
-        vector=vector, source_nn=source_nn)
+        X,
+        vertices,
+        src.kind,
+        tmin,
+        tstep,
+        src[0]["subject_his_id"],
+        vector=vector,
+        source_nn=source_nn,
+    )


 def _split_gof(M, X, gain):
@@ -170,7 +213,7 @@ def _split_gof(M, X, gain):
     # determine the weights by projecting each one onto this basis
     w = (U.T @ gain)[:, :, np.newaxis] * X
     w_norm = np.linalg.norm(w, axis=1, keepdims=True)
-    w_norm[w_norm == 0] = 1.
+    w_norm[w_norm == 0] = 1.0
     w /= w_norm  # our weights are now unit-norm positive (will preserve power)
     fit_back = np.linalg.norm(fit_orth[:, np.newaxis] * w, axis=0) ** 2
@@ -182,9 +225,17 @@ def _split_gof(M, X, gain):


 @verbose
-def _make_dipoles_sparse(X, active_set, forward, tmin, tstep, M,
-                         gain_active, active_is_idx=False,
-                         verbose=None):
+def _make_dipoles_sparse(
+    X,
+    active_set,
+    forward,
+    tmin,
+    tstep,
+    M,
+    gain_active,
+    active_is_idx=False,
+    verbose=None,
+):
     times = tmin + tstep * np.arange(X.shape[1])

     if not active_is_idx:
@@ -212,21 +263,26 @@ def _make_dipoles_sparse(X, active_set, forward, tmin, tstep, M,

     dipoles = []
     for k, i_dip in enumerate(active_idx):
-        i_pos = forward['source_rr'][i_dip][np.newaxis, :]
+        i_pos = forward["source_rr"][i_dip][np.newaxis, :]
         i_pos = i_pos.repeat(len(times), axis=0)
-        X_ = X[k * n_dip_per_pos: (k + 1) * n_dip_per_pos]
+        X_ = X[k * n_dip_per_pos : (k + 1) * n_dip_per_pos]

         if n_dip_per_pos == 1:
             amplitude = X_[0]
-            i_ori = forward['source_nn'][i_dip][np.newaxis, :]
+            i_ori = forward["source_nn"][i_dip][np.newaxis, :]
             i_ori = i_ori.repeat(len(times), axis=0)
         else:
-            if forward['surf_ori']:
-                X_ = np.dot(forward['source_nn'][
-                    i_dip * n_dip_per_pos:(i_dip + 1) * n_dip_per_pos].T, X_)
+            if forward["surf_ori"]:
+                X_ = np.dot(
+                    forward["source_nn"][
+                        i_dip * n_dip_per_pos : (i_dip + 1) * n_dip_per_pos
+                    ].T,
+                    X_,
+                )

             amplitude = np.linalg.norm(X_, axis=0)
             i_ori = np.zeros((len(times), 3))
-            i_ori[amplitude > 0.] = (X_[:, amplitude > 0.] 
/ - amplitude[amplitude > 0.]).T + i_ori[amplitude > 0.0] = ( + X_[:, amplitude > 0.0] / amplitude[amplitude > 0.0] + ).T dipoles.append(Dipole(times, i_pos, amplitude, i_ori, gof_split[k])) @@ -250,47 +306,68 @@ def make_stc_from_dipoles(dipoles, src, verbose=None): stc : SourceEstimate The source estimate. """ - logger.info('Converting dipoles into a SourceEstimate.') + logger.info("Converting dipoles into a SourceEstimate.") if isinstance(dipoles, Dipole): dipoles = [dipoles] if not isinstance(dipoles, list): - raise ValueError('Dipoles must be an instance of Dipole or ' - 'a list of instances of Dipole. ' - 'Got %s!' % type(dipoles)) + raise ValueError( + "Dipoles must be an instance of Dipole or " + "a list of instances of Dipole. " + "Got %s!" % type(dipoles) + ) tmin = dipoles[0].times[0] tstep = dipoles[0].times[1] - tmin X = np.zeros((len(dipoles), len(dipoles[0].times))) - source_rr = np.concatenate([_src['rr'][_src['vertno'], :] for _src in src], - axis=0) - n_lh_points = len(src[0]['vertno']) + source_rr = np.concatenate([_src["rr"][_src["vertno"], :] for _src in src], axis=0) + n_lh_points = len(src[0]["vertno"]) lh_vertno = list() rh_vertno = list() for i in range(len(dipoles)): if not np.all(dipoles[i].pos == dipoles[i].pos[0]): - raise ValueError('Only dipoles with fixed position over time ' - 'are supported!') + raise ValueError( + "Only dipoles with fixed position over time " "are supported!" + ) X[i] = dipoles[i].amplitude idx = np.all(source_rr == dipoles[i].pos[0], axis=1) idx = np.where(idx)[0][0] if idx < n_lh_points: - lh_vertno.append(src[0]['vertno'][idx]) + lh_vertno.append(src[0]["vertno"][idx]) else: - rh_vertno.append(src[1]['vertno'][idx - n_lh_points]) - vertices = [np.array(lh_vertno).astype(int), - np.array(rh_vertno).astype(int)] - stc = SourceEstimate(X, vertices=vertices, tmin=tmin, tstep=tstep, - subject=src._subject) - logger.info('[done]') + rh_vertno.append(src[1]["vertno"][idx - n_lh_points]) + vertices = [np.array(lh_vertno).astype(int), np.array(rh_vertno).astype(int)] + stc = SourceEstimate( + X, vertices=vertices, tmin=tmin, tstep=tstep, subject=src._subject + ) + logger.info("[done]") return stc @verbose -def mixed_norm(evoked, forward, noise_cov, alpha='sure', loose='auto', - depth=0.8, maxit=3000, tol=1e-4, active_set_size=10, - debias=True, time_pca=True, weights=None, weights_min=0., - solver='auto', n_mxne_iter=1, return_residual=False, - return_as_dipoles=False, dgap_freq=10, rank=None, pick_ori=None, - sure_alpha_grid="auto", random_state=None, verbose=None): +def mixed_norm( + evoked, + forward, + noise_cov, + alpha="sure", + loose="auto", + depth=0.8, + maxit=3000, + tol=1e-4, + active_set_size=10, + debias=True, + time_pca=True, + weights=None, + weights_min=0.0, + solver="auto", + n_mxne_iter=1, + return_residual=False, + return_as_dipoles=False, + dgap_freq=10, + rank=None, + pick_ori=None, + sure_alpha_grid="auto", + random_state=None, + verbose=None, +): """Mixed-norm estimate (MxNE) and iterative reweighted MxNE (irMxNE). Compute L1/L2 mixed-norm solution :footcite:`GramfortEtAl2012` or L0.5/L2 @@ -380,26 +457,38 @@ def mixed_norm(evoked, forward, noise_cov, alpha='sure', loose='auto', .. footbibliography:: """ from scipy import linalg - _validate_type(alpha, ('numeric', str), 'alpha') + + _validate_type(alpha, ("numeric", str), "alpha") if isinstance(alpha, str): - _check_option('alpha', alpha, ('sure',)) - elif not 0. <= alpha < 100: - raise ValueError('If not equal to "sure" alpha must be in [0, 100). 
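For orientation, a hypothetical usage sketch of the `mixed_norm` entry point above; the file names and parameter values are illustrative, not part of this patch:

import mne
from mne.inverse_sparse import mixed_norm

evoked = mne.read_evokeds("sample-ave.fif", condition=0, baseline=(None, 0))
forward = mne.read_forward_solution("sample-fwd.fif")
noise_cov = mne.read_cov("sample-cov.fif")

# alpha="sure" picks the regularization on a grid of percentages of
# alpha_max; n_mxne_iter > 1 runs the iterative reweighted (irMxNE) variant
stc, residual = mixed_norm(
    evoked, forward, noise_cov, alpha="sure", n_mxne_iter=5, return_residual=True
)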
' - 'Got alpha = %s' % alpha) + _check_option("alpha", alpha, ("sure",)) + elif not 0.0 <= alpha < 100: + raise ValueError( + 'If not equal to "sure" alpha must be in [0, 100). ' + "Got alpha = %s" % alpha + ) if n_mxne_iter < 1: - raise ValueError('MxNE has to be computed at least 1 time. ' - 'Requires n_mxne_iter >= 1, got %d' % n_mxne_iter) - if dgap_freq <= 0.: - raise ValueError('dgap_freq must be a positive integer.' - ' Got dgap_freq = %s' % dgap_freq) - if not (isinstance(sure_alpha_grid, (np.ndarray, list)) or - sure_alpha_grid == "auto"): - raise ValueError('If not equal to "auto" sure_alpha_grid must be an ' - 'array. Got %s' % type(sure_alpha_grid)) - if ((isinstance(sure_alpha_grid, str) and sure_alpha_grid != "auto") - and (isinstance(alpha, str) and alpha != "sure")): - raise Exception('If sure_alpha_grid is manually specified, alpha must ' - 'be "sure". Got %s' % alpha) + raise ValueError( + "MxNE has to be computed at least 1 time. " + "Requires n_mxne_iter >= 1, got %d" % n_mxne_iter + ) + if dgap_freq <= 0.0: + raise ValueError( + "dgap_freq must be a positive integer." " Got dgap_freq = %s" % dgap_freq + ) + if not ( + isinstance(sure_alpha_grid, (np.ndarray, list)) or sure_alpha_grid == "auto" + ): + raise ValueError( + 'If not equal to "auto" sure_alpha_grid must be an ' + "array. Got %s" % type(sure_alpha_grid) + ) + if (isinstance(sure_alpha_grid, str) and sure_alpha_grid != "auto") and ( + isinstance(alpha, str) and alpha != "sure" + ): + raise Exception( + "If sure_alpha_grid is manually specified, alpha must " + 'be "sure". Got %s' % alpha + ) pca = True if not isinstance(evoked, list): evoked = [evoked] @@ -407,20 +496,27 @@ def mixed_norm(evoked, forward, noise_cov, alpha='sure', loose='auto', _check_reference(evoked[0]) all_ch_names = evoked[0].ch_names - if not all(all_ch_names == evoked[i].ch_names - for i in range(1, len(evoked))): - raise Exception('All the datasets must have the same good channels.') + if not all(all_ch_names == evoked[i].ch_names for i in range(1, len(evoked))): + raise Exception("All the datasets must have the same good channels.") forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain( - forward, evoked[0].info, noise_cov, pca, depth, loose, rank, - weights, weights_min) + forward, + evoked[0].info, + noise_cov, + pca, + depth, + loose, + rank, + weights, + weights_min, + ) _check_ori(pick_ori, forward) - sel = [all_ch_names.index(name) for name in gain_info['ch_names']] + sel = [all_ch_names.index(name) for name in gain_info["ch_names"]] M = np.concatenate([e.data[sel] for e in evoked], axis=1) # Whiten data - logger.info('Whitening data matrix.') + logger.info("Whitening data matrix.") M = np.dot(whitener, M) if time_pca: @@ -445,24 +541,52 @@ def mixed_norm(evoked, forward, noise_cov, alpha='sure', loose='auto', if isinstance(sure_alpha_grid, str) and sure_alpha_grid == "auto": alpha_grid = np.geomspace(100, 10, num=15) X, active_set, best_alpha_ = _compute_mxne_sure( - M, gain, alpha_grid, sigma=1, random_state=random_state, - n_mxne_iter=n_mxne_iter, maxit=maxit, tol=tol, - n_orient=n_dip_per_pos, active_set_size=active_set_size, - debias=debias, solver=solver, dgap_freq=dgap_freq, verbose=verbose) - logger.info('Selected alpha: %s' % best_alpha_) + M, + gain, + alpha_grid, + sigma=1, + random_state=random_state, + n_mxne_iter=n_mxne_iter, + maxit=maxit, + tol=tol, + n_orient=n_dip_per_pos, + active_set_size=active_set_size, + debias=debias, + solver=solver, + dgap_freq=dgap_freq, + verbose=verbose, + ) + 
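As a side note on the grid just built: sure_alpha_grid="auto" uses 15 log-spaced candidates from 100 down to 10 (percentages of alpha_max); a manual grid is a plain array (the values below are illustrative):

import numpy as np

auto_grid = np.geomspace(100, 10, num=15)  # what "auto" produces
custom_grid = np.linspace(60, 20, num=9)   # hypothetical, passed as sure_alpha_grid=custom_grid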
logger.info("Selected alpha: %s" % best_alpha_) else: if n_mxne_iter == 1: X, active_set, E = mixed_norm_solver( - M, gain, alpha, maxit=maxit, tol=tol, - active_set_size=active_set_size, n_orient=n_dip_per_pos, - debias=debias, solver=solver, dgap_freq=dgap_freq, - verbose=verbose) + M, + gain, + alpha, + maxit=maxit, + tol=tol, + active_set_size=active_set_size, + n_orient=n_dip_per_pos, + debias=debias, + solver=solver, + dgap_freq=dgap_freq, + verbose=verbose, + ) else: X, active_set, E = iterative_mixed_norm_solver( - M, gain, alpha, n_mxne_iter, maxit=maxit, tol=tol, - n_orient=n_dip_per_pos, active_set_size=active_set_size, - debias=debias, solver=solver, dgap_freq=dgap_freq, - verbose=verbose) + M, + gain, + alpha, + n_mxne_iter, + maxit=maxit, + tol=tol, + n_orient=n_dip_per_pos, + active_set_size=active_set_size, + debias=debias, + solver=solver, + dgap_freq=dgap_freq, + verbose=verbose, + ) if time_pca: X = np.dot(X, Vh) @@ -491,25 +615,30 @@ def mixed_norm(evoked, forward, noise_cov, alpha='sure', loose='auto', cnt = 0 for e in evoked: tmin = e.times[0] - tstep = 1.0 / e.info['sfreq'] - Xe = X[:, cnt:(cnt + len(e.times))] + tstep = 1.0 / e.info["sfreq"] + Xe = X[:, cnt : (cnt + len(e.times))] if return_as_dipoles: out = _make_dipoles_sparse( - Xe, active_set, forward, tmin, tstep, - M[:, cnt:(cnt + len(e.times))], - gain_active) + Xe, + active_set, + forward, + tmin, + tstep, + M[:, cnt : (cnt + len(e.times))], + gain_active, + ) else: out = _make_sparse_stc( - Xe, active_set, forward, tmin, tstep, pick_ori=pick_ori) + Xe, active_set, forward, tmin, tstep, pick_ori=pick_ori + ) outs.append(out) cnt += len(e.times) if return_residual: - residual.append(_compute_residual(forward, e, Xe, active_set, - gain_info)) + residual.append(_compute_residual(forward, e, Xe, active_set, gain_info)) - _log_exp_var(M, M_estimate, prefix='') - logger.info('[done]') + _log_exp_var(M, M_estimate, prefix="") + logger.info("[done]") if len(outs) == 1: out = outs[0] @@ -531,7 +660,7 @@ def _window_evoked(evoked, size): else: lsize, rsize = size evoked = evoked.copy() - sfreq = float(evoked.info['sfreq']) + sfreq = float(evoked.info["sfreq"]) lsize = int(lsize * sfreq) rsize = int(rsize * sfreq) lhann = np.hanning(lsize * 2)[:lsize] @@ -542,13 +671,31 @@ def _window_evoked(evoked, size): @verbose -def tf_mixed_norm(evoked, forward, noise_cov, - loose='auto', depth=0.8, maxit=3000, - tol=1e-4, weights=None, weights_min=0., pca=True, - debias=True, wsize=64, tstep=4, window=0.02, - return_residual=False, return_as_dipoles=False, alpha=None, - l1_ratio=None, dgap_freq=10, rank=None, pick_ori=None, - n_tfmxne_iter=1, verbose=None): +def tf_mixed_norm( + evoked, + forward, + noise_cov, + loose="auto", + depth=0.8, + maxit=3000, + tol=1e-4, + weights=None, + weights_min=0.0, + pca=True, + debias=True, + wsize=64, + tstep=4, + window=0.02, + return_residual=False, + return_as_dipoles=False, + alpha=None, + l1_ratio=None, + dgap_freq=10, + rank=None, + pick_ori=None, + n_tfmxne_iter=1, + verbose=None, +): """Time-Frequency Mixed-norm estimate (TF-MxNE). Compute L1/L2 + L1 mixed-norm solution on time-frequency @@ -641,34 +788,38 @@ def tf_mixed_norm(evoked, forward, noise_cov, all_ch_names = evoked.ch_names info = evoked.info - if not (0. <= alpha < 100.): - raise ValueError('alpha must be in [0, 100). ' - 'Got alpha = %s' % alpha) + if not (0.0 <= alpha < 100.0): + raise ValueError("alpha must be in [0, 100). " "Got alpha = %s" % alpha) - if not (0. 
<= l1_ratio <= 1.): - raise ValueError('l1_ratio must be in range [0, 1].' - ' Got l1_ratio = %s' % l1_ratio) - alpha_space = alpha * (1. - l1_ratio) + if not (0.0 <= l1_ratio <= 1.0): + raise ValueError( + "l1_ratio must be in range [0, 1]." " Got l1_ratio = %s" % l1_ratio + ) + alpha_space = alpha * (1.0 - l1_ratio) alpha_time = alpha * l1_ratio if n_tfmxne_iter < 1: - raise ValueError('TF-MxNE has to be computed at least 1 time. ' - 'Requires n_tfmxne_iter >= 1, got %s' % n_tfmxne_iter) + raise ValueError( + "TF-MxNE has to be computed at least 1 time. " + "Requires n_tfmxne_iter >= 1, got %s" % n_tfmxne_iter + ) - if dgap_freq <= 0.: - raise ValueError('dgap_freq must be a positive integer.' - ' Got dgap_freq = %s' % dgap_freq) + if dgap_freq <= 0.0: + raise ValueError( + "dgap_freq must be a positive integer." " Got dgap_freq = %s" % dgap_freq + ) tstep = np.atleast_1d(tstep) wsize = np.atleast_1d(wsize) if len(tstep) != len(wsize): - raise ValueError('The same number of window sizes and steps must be ' - 'passed. Got tstep = %s and wsize = %s' % - (tstep, wsize)) + raise ValueError( + "The same number of window sizes and steps must be " + "passed. Got tstep = %s and wsize = %s" % (tstep, wsize) + ) forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain( - forward, evoked.info, noise_cov, pca, depth, loose, rank, - weights, weights_min) + forward, evoked.info, noise_cov, pca, depth, loose, rank, weights, weights_min + ) _check_ori(pick_ori, forward) n_dip_per_pos = 1 if is_fixed_orient(forward) else 3 @@ -680,7 +831,7 @@ def tf_mixed_norm(evoked, forward, noise_cov, M = evoked.data[sel] # Whiten data - logger.info('Whitening data matrix.') + logger.info("Whitening data matrix.") M = np.dot(whitener, M) n_steps = np.ceil(M.shape[1] / tstep.astype(float)).astype(int) @@ -697,18 +848,40 @@ def tf_mixed_norm(evoked, forward, noise_cov, if n_tfmxne_iter == 1: X, active_set, E = tf_mixed_norm_solver( - M, gain, alpha_space, alpha_time, wsize=wsize, tstep=tstep, - maxit=maxit, tol=tol, verbose=verbose, n_orient=n_dip_per_pos, - dgap_freq=dgap_freq, debias=debias) + M, + gain, + alpha_space, + alpha_time, + wsize=wsize, + tstep=tstep, + maxit=maxit, + tol=tol, + verbose=verbose, + n_orient=n_dip_per_pos, + dgap_freq=dgap_freq, + debias=debias, + ) else: X, active_set, E = iterative_tf_mixed_norm_solver( - M, gain, alpha_space, alpha_time, wsize=wsize, tstep=tstep, - n_tfmxne_iter=n_tfmxne_iter, maxit=maxit, tol=tol, verbose=verbose, - n_orient=n_dip_per_pos, dgap_freq=dgap_freq, debias=debias) + M, + gain, + alpha_space, + alpha_time, + wsize=wsize, + tstep=tstep, + n_tfmxne_iter=n_tfmxne_iter, + maxit=maxit, + tol=tol, + verbose=verbose, + n_orient=n_dip_per_pos, + dgap_freq=dgap_freq, + debias=debias, + ) if active_set.sum() == 0: - raise Exception("No active dipoles found. " - "alpha_space/alpha_time are too big.") + raise Exception( + "No active dipoles found. " "alpha_space/alpha_time are too big." 
+ ) # Compute estimated whitened sensor data for each dipole (dip, ch, time) gain_active = gain[:, active_set] @@ -723,19 +896,23 @@ def tf_mixed_norm(evoked, forward, noise_cov, gain_active /= source_weighting[active_set] if return_residual: - residual = _compute_residual( - forward, evoked, X, active_set, gain_info) + residual = _compute_residual(forward, evoked, X, active_set, gain_info) if return_as_dipoles: out = _make_dipoles_sparse( - X, active_set, forward, evoked.times[0], 1.0 / info['sfreq'], - M, gain_active) + X, active_set, forward, evoked.times[0], 1.0 / info["sfreq"], M, gain_active + ) else: out = _make_sparse_stc( - X, active_set, forward, evoked.times[0], 1.0 / info['sfreq'], - pick_ori=pick_ori) + X, + active_set, + forward, + evoked.times[0], + 1.0 / info["sfreq"], + pick_ori=pick_ori, + ) - logger.info('[done]') + logger.info("[done]") if return_residual: out = out, residual @@ -744,9 +921,22 @@ def tf_mixed_norm(evoked, forward, noise_cov, @verbose -def _compute_mxne_sure(M, gain, alpha_grid, sigma, n_mxne_iter, maxit, tol, - n_orient, active_set_size, debias, solver, dgap_freq, - random_state, verbose): +def _compute_mxne_sure( + M, + gain, + alpha_grid, + sigma, + n_mxne_iter, + maxit, + tol, + n_orient, + active_set_size, + debias, + solver, + dgap_freq, + random_state, + verbose, +): """Stein Unbiased Risk Estimator (SURE). Implements the finite-difference Monte-Carlo approximation @@ -799,26 +989,46 @@ def _compute_mxne_sure(M, gain, alpha_grid, sigma, n_mxne_iter, maxit, tol, ---------- .. footbibliography:: """ + def g(w): return np.sqrt(np.sqrt(groups_norm2(w.copy(), n_orient))) def gprime(w): - return 2. * np.repeat(g(w), n_orient).ravel() + return 2.0 * np.repeat(g(w), n_orient).ravel() - def _run_solver(alpha, M, n_mxne_iter, as_init=None, X_init=None, - w_init=None): + def _run_solver(alpha, M, n_mxne_iter, as_init=None, X_init=None, w_init=None): if n_mxne_iter == 1: X, active_set, _ = mixed_norm_solver( - M, gain, alpha, maxit=maxit, tol=tol, - active_set_size=active_set_size, n_orient=n_orient, - debias=debias, solver=solver, dgap_freq=dgap_freq, - active_set_init=as_init, X_init=X_init, verbose=False) + M, + gain, + alpha, + maxit=maxit, + tol=tol, + active_set_size=active_set_size, + n_orient=n_orient, + debias=debias, + solver=solver, + dgap_freq=dgap_freq, + active_set_init=as_init, + X_init=X_init, + verbose=False, + ) else: X, active_set, _ = iterative_mixed_norm_solver( - M, gain, alpha, n_mxne_iter, maxit=maxit, tol=tol, - n_orient=n_orient, active_set_size=active_set_size, - debias=debias, solver=solver, dgap_freq=dgap_freq, - weight_init=w_init, verbose=False) + M, + gain, + alpha, + n_mxne_iter, + maxit=maxit, + tol=tol, + n_orient=n_orient, + active_set_size=active_set_size, + debias=debias, + solver=solver, + dgap_freq=dgap_freq, + weight_init=w_init, + verbose=False, + ) return X, active_set def _fit_on_grid(gain, M, eps, delta): @@ -827,9 +1037,9 @@ def _fit_on_grid(gain, M, eps, delta): active_sets, active_sets_eps = [], [] M_eps = M + eps * delta # warm start - first iteration (leverages convexity) - logger.info('Warm starting...') + logger.info("Warm starting...") for j, alpha in enumerate(alpha_grid): - logger.info('alpha: %s' % alpha) + logger.info("alpha: %s" % alpha) X, a_set = _run_solver(alpha, M, 1) X_eps, a_set_eps = _run_solver(alpha, M_eps, 1) coefs_grid_1_0[j][a_set, :] = X @@ -842,20 +1052,19 @@ def _fit_on_grid(gain, M, eps, delta): else: coefs_grid_1 = coefs_grid_1_0.copy() coefs_grid_2 = coefs_grid_2_0.copy() - 
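Before the grid fit below, a self-contained numpy sketch of the quantity `_compute_sure_val` evaluates (argument names are illustrative):

import numpy as np

def sure_value(M, gain, coef, coef_eps, delta, eps, sigma):
    # SURE = ||M - G X||^2 - n*sigma^2 + 2*sigma^2*dof, with the degrees of
    # freedom approximated by a finite difference between the solution on M
    # and the solution on the perturbed data M + eps * delta
    n = M.shape[0] * M.shape[1]
    dof = (gain @ (coef_eps - coef) * delta).sum() / eps
    return np.linalg.norm(M - gain @ coef) ** 2 - n * sigma**2 + 2 * sigma**2 * dof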
logger.info('Fitting SURE on grid.') + logger.info("Fitting SURE on grid.") for j, alpha in enumerate(alpha_grid): - logger.info('alpha: %s' % alpha) + logger.info("alpha: %s" % alpha) if active_sets[j].sum() > 0: w = gprime(coefs_grid_1[j]) - X, a_set = _run_solver(alpha, M, n_mxne_iter - 1, - w_init=w) + X, a_set = _run_solver(alpha, M, n_mxne_iter - 1, w_init=w) coefs_grid_1[j][a_set, :] = X active_sets[j] = a_set if active_sets_eps[j].sum() > 0: w_eps = gprime(coefs_grid_2[j]) - X_eps, a_set_eps = _run_solver(alpha, M_eps, - n_mxne_iter - 1, - w_init=w_eps) + X_eps, a_set_eps = _run_solver( + alpha, M_eps, n_mxne_iter - 1, w_init=w_eps + ) coefs_grid_2[j][a_set_eps, :] = X_eps active_sets_eps[j] = a_set_eps @@ -865,8 +1074,8 @@ def _compute_sure_val(coef1, coef2, gain, M, sigma, delta, eps): n_sensors, n_times = gain.shape[0], M.shape[1] dof = (gain @ (coef2 - coef1) * delta).sum() / eps df_term = np.linalg.norm(M - gain @ coef1) ** 2 - sure = df_term - n_sensors * n_times * sigma ** 2 - sure += 2 * dof * sigma ** 2 + sure = df_term - n_sensors * n_times * sigma**2 + sure += 2 * dof * sigma**2 return sure sure_path = np.empty(len(alpha_grid)) @@ -880,8 +1089,7 @@ def _compute_sure_val(coef1, coef2, gain, M, sigma, delta, eps): logger.info("Computing SURE values on grid.") for i, (coef1, coef2) in enumerate(zip(coefs_grid_1, coefs_grid_2)): - sure_path[i] = _compute_sure_val( - coef1, coef2, gain, M, sigma, delta, eps) + sure_path[i] = _compute_sure_val(coef1, coef2, gain, M, sigma, delta, eps) if verbose: logger.info("alpha %s :: sure %s" % (alpha_grid[i], sure_path[i])) best_alpha_ = alpha_grid[np.argmin(sure_path)] diff --git a/mne/inverse_sparse/mxne_optim.py b/mne/inverse_sparse/mxne_optim.py index bff7a909781..e4e29912b68 100644 --- a/mne/inverse_sparse/mxne_optim.py +++ b/mne/inverse_sparse/mxne_optim.py @@ -9,14 +9,21 @@ import numpy as np from .mxne_debiasing import compute_bias -from ..utils import (logger, verbose, sum_squared, warn, _get_blas_funcs, - _validate_type, _check_option) +from ..utils import ( + logger, + verbose, + sum_squared, + warn, + _get_blas_funcs, + _validate_type, + _check_option, +) from ..time_frequency._stft import stft_norm1, stft_norm2, stft, istft @functools.lru_cache(None) def _get_dgemm(): - return _get_blas_funcs(np.float64, 'gemm') + return _get_blas_funcs(np.float64, "gemm") def groups_norm2(A, n_orient): @@ -121,23 +128,37 @@ def dgap_l21(M, G, X, active_set, alpha, n_orient): dual_norm = norm_l2inf(np.dot(G.T, R), n_orient, copy=False) scaling = alpha / dual_norm scaling = min(scaling, 1.0) - d_obj = (scaling - 0.5 * (scaling ** 2)) * nR2 + scaling * np.sum(R * GX) + d_obj = (scaling - 0.5 * (scaling**2)) * nR2 + scaling * np.sum(R * GX) gap = p_obj - d_obj return gap, p_obj, d_obj, R @verbose -def _mixed_norm_solver_cd(M, G, alpha, lipschitz_constant, maxit=10000, - tol=1e-8, verbose=None, init=None, n_orient=1, - dgap_freq=10): +def _mixed_norm_solver_cd( + M, + G, + alpha, + lipschitz_constant, + maxit=10000, + tol=1e-8, + verbose=None, + init=None, + n_orient=1, + dgap_freq=10, +): """Solve L21 inverse problem with coordinate descent.""" from sklearn.linear_model import MultiTaskLasso assert M.ndim == G.ndim and M.shape[0] == G.shape[0] - clf = MultiTaskLasso(alpha=alpha / len(M), tol=tol / sum_squared(M), - fit_intercept=False, max_iter=maxit, warm_start=True) + clf = MultiTaskLasso( + alpha=alpha / len(M), + tol=tol / sum_squared(M), + fit_intercept=False, + max_iter=maxit, + warm_start=True, + ) if init is not None: clf.coef_ = init.T 
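For reference, the duality gap that `dgap_l21` above evaluates, as a standalone numpy sketch for the n_orient=1 case (the primal expression follows the standard L21 objective; names are illustrative):

import numpy as np

def dgap_l21_sketch(M, G, X, alpha):
    # primal: 0.5 * ||M - G X||_F^2 + alpha * sum_i ||X_i||_2
    R = M - G @ X
    nR2 = (R**2).sum()
    p_obj = 0.5 * nR2 + alpha * np.linalg.norm(X, axis=1).sum()
    # dual point obtained by rescaling R into the L2-inf feasible ball
    dual_norm = np.max(np.linalg.norm(G.T @ R, axis=1))
    s = min(alpha / dual_norm, 1.0)
    d_obj = (s - 0.5 * s**2) * nR2 + s * (R * (G @ X)).sum()
    return p_obj - d_obj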
else: @@ -152,9 +173,20 @@ def _mixed_norm_solver_cd(M, G, alpha, lipschitz_constant, maxit=10000, @verbose -def _mixed_norm_solver_bcd(M, G, alpha, lipschitz_constant, maxit=200, - tol=1e-8, verbose=None, init=None, n_orient=1, - dgap_freq=10, use_accel=True, K=5): +def _mixed_norm_solver_bcd( + M, + G, + alpha, + lipschitz_constant, + maxit=200, + tol=1e-8, + verbose=None, + init=None, + n_orient=1, + dgap_freq=10, + use_accel=True, + K=5, +): """Solve L21 inverse problem with block coordinate descent.""" _, n_times = M.shape _, n_sources = G.shape @@ -168,7 +200,7 @@ def _mixed_norm_solver_bcd(M, G, alpha, lipschitz_constant, maxit=200, R = M - np.dot(G, X) E = [] # track primal objective function - highest_d_obj = - np.inf + highest_d_obj = -np.inf active_set = np.zeros(n_sources, dtype=bool) # start with full AS alpha_lc = alpha / lipschitz_constant @@ -182,7 +214,7 @@ def _mixed_norm_solver_bcd(M, G, alpha, lipschitz_constant, maxit=200, # Ensure these are correct for dgemm assert R.dtype == np.float64 assert G.dtype == np.float64 - one_ovr_lc = 1. / lipschitz_constant + one_ovr_lc = 1.0 / lipschitz_constant # assert that all the multiplied matrices are fortran contiguous assert X.T.flags.f_contiguous @@ -198,17 +230,19 @@ def _mixed_norm_solver_bcd(M, G, alpha, lipschitz_constant, maxit=200, _bcd(G, X, R, active_set, one_ovr_lc, n_orient, alpha_lc, list_G_j_c) if (i + 1) % dgap_freq == 0: - _, p_obj, d_obj, _ = dgap_l21(M, G, X[active_set], active_set, - alpha, n_orient) + _, p_obj, d_obj, _ = dgap_l21( + M, G, X[active_set], active_set, alpha, n_orient + ) highest_d_obj = max(d_obj, highest_d_obj) gap = p_obj - highest_d_obj E.append(p_obj) - logger.debug("Iteration %d :: p_obj %f :: dgap %f :: n_active %d" % - (i + 1, p_obj, gap, np.sum(active_set) / n_orient)) + logger.debug( + "Iteration %d :: p_obj %f :: dgap %f :: n_active %d" + % (i + 1, p_obj, gap, np.sum(active_set) / n_orient) + ) if gap < tol: - logger.debug('Convergence reached ! (gap: %s < %s)' - % (gap, tol)) + logger.debug("Convergence reached ! (gap: %s < %s)" % (gap, tol)) break # using Anderson acceleration of the primal variable for faster @@ -230,19 +264,17 @@ def _mixed_norm_solver_bcd(M, G, alpha, lipschitz_constant, maxit=200, continue z = ((u * 1 / s) @ u.T).sum(0) c = z / z.sum() - X_acc = np.sum( - last_K_X[:-1] * c[:, None, None], axis=0 - ) + X_acc = np.sum(last_K_X[:-1] * c[:, None, None], axis=0) _grp_norm2_acc = groups_norm2(X_acc, n_orient) active_set_acc = _grp_norm2_acc != 0 if n_orient > 1: active_set_acc = np.kron( active_set_acc, np.ones(n_orient, dtype=bool) ) - p_obj = _primal_l21(M, G, X[active_set], active_set, alpha, - n_orient)[0] - p_obj_acc = _primal_l21(M, G, X_acc[active_set_acc], - active_set_acc, alpha, n_orient)[0] + p_obj = _primal_l21(M, G, X[active_set], active_set, alpha, n_orient)[0] + p_obj_acc = _primal_l21( + M, G, X_acc[active_set_acc], active_set_acc, alpha, n_orient + )[0] if p_obj_acc < p_obj: X = X_acc active_set = active_set_acc @@ -278,43 +310,54 @@ def _bcd(G, X, R, active_set, one_ovr_lc, n_orient, alpha_lc, list_G_j_c): alpha_lc: array, shape (n_positions, ) alpha * (Lipschitz constants). 
""" - X_j_new = np.zeros_like(X[:n_orient, :], order='C') + X_j_new = np.zeros_like(X[:n_orient, :], order="C") dgemm = _get_dgemm() for j, G_j_c in enumerate(list_G_j_c): idx = slice(j * n_orient, (j + 1) * n_orient) G_j = G[:, idx] X_j = X[idx] - dgemm(alpha=one_ovr_lc[j], beta=0., a=R.T, b=G_j, c=X_j_new.T, - overwrite_c=True) + dgemm( + alpha=one_ovr_lc[j], beta=0.0, a=R.T, b=G_j, c=X_j_new.T, overwrite_c=True + ) # X_j_new = G_j.T @ R # Mathurin's trick to avoid checking all the entries was_non_zero = X_j[0, 0] != 0 # was_non_zero = np.any(X_j) if was_non_zero: - dgemm(alpha=1., beta=1., a=X_j.T, b=G_j_c.T, c=R.T, - overwrite_c=True) + dgemm(alpha=1.0, beta=1.0, a=X_j.T, b=G_j_c.T, c=R.T, overwrite_c=True) # R += np.dot(G_j, X_j) X_j_new += X_j block_norm = sqrt(sum_squared(X_j_new)) if block_norm <= alpha_lc[j]: - X_j.fill(0.) + X_j.fill(0.0) active_set[idx] = False else: shrink = max(1.0 - alpha_lc[j] / block_norm, 0.0) X_j_new *= shrink - dgemm(alpha=-1., beta=1., a=X_j_new.T, b=G_j_c.T, c=R.T, - overwrite_c=True) + dgemm(alpha=-1.0, beta=1.0, a=X_j_new.T, b=G_j_c.T, c=R.T, overwrite_c=True) # R -= np.dot(G_j, X_j_new) X_j[:] = X_j_new active_set[idx] = True @verbose -def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None, - active_set_size=50, debias=True, n_orient=1, - solver='auto', return_gap=False, dgap_freq=10, - active_set_init=None, X_init=None): +def mixed_norm_solver( + M, + G, + alpha, + maxit=3000, + tol=1e-8, + verbose=None, + active_set_size=50, + debias=True, + n_orient=1, + solver="auto", + return_gap=False, + dgap_freq=10, + active_set_init=None, + X_init=None, +): """Solve L1/L2 mixed-norm inverse problem with active set strategy. See references :footcite:`GramfortEtAl2012,StrohmeierEtAl2016, @@ -385,31 +428,35 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None, except ImportError: has_sklearn = False - _validate_type(solver, str, 'solver') - _check_option('solver', solver, ('cd', 'bcd', 'auto')) - if solver == 'auto': + _validate_type(solver, str, "solver") + _check_option("solver", solver, ("cd", "bcd", "auto")) + if solver == "auto": if has_sklearn and (n_orient == 1): - solver = 'cd' + solver = "cd" else: - solver = 'bcd' + solver = "bcd" - if solver == 'cd': + if solver == "cd": if n_orient == 1 and not has_sklearn: - warn('Scikit-learn >= 0.12 cannot be found. Using block coordinate' - ' descent instead of coordinate descent.') - solver = 'bcd' + warn( + "Scikit-learn >= 0.12 cannot be found. Using block coordinate" + " descent instead of coordinate descent." + ) + solver = "bcd" if n_orient > 1: - warn('Coordinate descent is only available for fixed orientation. ' - 'Using block coordinate descent instead of coordinate ' - 'descent') - solver = 'bcd' - - if solver == 'cd': + warn( + "Coordinate descent is only available for fixed orientation. 
" + "Using block coordinate descent instead of coordinate " + "descent" + ) + solver = "bcd" + + if solver == "cd": logger.info("Using coordinate descent") l21_solver = _mixed_norm_solver_cd lc = None else: - assert solver == 'bcd' + assert solver == "bcd" logger.info("Using block coordinate descent") l21_solver = _mixed_norm_solver_bcd G = np.asfortranarray(G) @@ -418,59 +465,77 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None, else: lc = np.empty(n_positions) for j in range(n_positions): - G_tmp = G[:, (j * n_orient):((j + 1) * n_orient)] + G_tmp = G[:, (j * n_orient) : ((j + 1) * n_orient)] lc[j] = np.linalg.norm(np.dot(G_tmp.T, G_tmp), ord=2) if active_set_size is not None: E = list() - highest_d_obj = - np.inf + highest_d_obj = -np.inf if X_init is not None and X_init.shape != (n_dipoles, n_times): - raise ValueError('Wrong dim for initialized coefficients.') - active_set = (active_set_init if active_set_init is not None else - np.zeros(n_dipoles, dtype=bool)) + raise ValueError("Wrong dim for initialized coefficients.") + active_set = ( + active_set_init + if active_set_init is not None + else np.zeros(n_dipoles, dtype=bool) + ) idx_large_corr = np.argsort(groups_norm2(np.dot(G.T, M), n_orient)) new_active_idx = idx_large_corr[-active_set_size:] if n_orient > 1: - new_active_idx = (n_orient * new_active_idx[:, None] + - np.arange(n_orient)[None, :]).ravel() + new_active_idx = ( + n_orient * new_active_idx[:, None] + np.arange(n_orient)[None, :] + ).ravel() active_set[new_active_idx] = True as_size = np.sum(active_set) gap = np.inf for k in range(maxit): - if solver == 'bcd': + if solver == "bcd": lc_tmp = lc[active_set[::n_orient]] - elif solver == 'cd': + elif solver == "cd": lc_tmp = None else: lc_tmp = 1.01 * np.linalg.norm(G[:, active_set], ord=2) ** 2 - X, as_, _ = l21_solver(M, G[:, active_set], alpha, lc_tmp, - maxit=maxit, tol=tol, init=X_init, - n_orient=n_orient, dgap_freq=dgap_freq) + X, as_, _ = l21_solver( + M, + G[:, active_set], + alpha, + lc_tmp, + maxit=maxit, + tol=tol, + init=X_init, + n_orient=n_orient, + dgap_freq=dgap_freq, + ) active_set[active_set] = as_.copy() idx_old_active_set = np.where(active_set)[0] - _, p_obj, d_obj, R = dgap_l21(M, G, X, active_set, alpha, - n_orient) + _, p_obj, d_obj, R = dgap_l21(M, G, X, active_set, alpha, n_orient) highest_d_obj = max(d_obj, highest_d_obj) gap = p_obj - highest_d_obj E.append(p_obj) - logger.info("Iteration %d :: p_obj %f :: dgap %f :: " - "n_active_start %d :: n_active_end %d" % ( - k + 1, p_obj, gap, as_size // n_orient, - np.sum(active_set) // n_orient)) + logger.info( + "Iteration %d :: p_obj %f :: dgap %f :: " + "n_active_start %d :: n_active_end %d" + % ( + k + 1, + p_obj, + gap, + as_size // n_orient, + np.sum(active_set) // n_orient, + ) + ) if gap < tol: - logger.info('Convergence reached ! (gap: %s < %s)' - % (gap, tol)) + logger.info("Convergence reached ! 
(gap: %s < %s)" % (gap, tol)) break # add sources if not last iteration if k < (maxit - 1): - idx_large_corr = np.argsort(groups_norm2(np.dot(G.T, R), - n_orient)) + idx_large_corr = np.argsort(groups_norm2(np.dot(G.T, R), n_orient)) new_active_idx = idx_large_corr[-active_set_size:] if n_orient > 1: - new_active_idx = (n_orient * new_active_idx[:, None] + - np.arange(n_orient)[None, :]) + new_active_idx = ( + n_orient * new_active_idx[:, None] + + np.arange(n_orient)[None, :] + ) new_active_idx = new_active_idx.ravel() active_set[new_active_idx] = True idx_active_set = np.where(active_set)[0] @@ -479,10 +544,11 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None, idx = np.searchsorted(idx_active_set, idx_old_active_set) X_init[idx] = X else: - warn('Did NOT converge ! (gap: %s > %s)' % (gap, tol)) + warn("Did NOT converge ! (gap: %s > %s)" % (gap, tol)) else: - X, active_set, E = l21_solver(M, G, alpha, lc, maxit=maxit, - tol=tol, n_orient=n_orient, init=None) + X, active_set, E = l21_solver( + M, G, alpha, lc, maxit=maxit, tol=tol, n_orient=n_orient, init=None + ) if return_gap: gap = dgap_l21(M, G, X, active_set, alpha, n_orient)[0] @@ -490,7 +556,7 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None, bias = compute_bias(M, G[:, active_set], X, n_orient=n_orient) X *= bias[:, np.newaxis] - logger.info('Final active set size: %s' % (np.sum(active_set) // n_orient)) + logger.info("Final active set size: %s" % (np.sum(active_set) // n_orient)) if return_gap: return X, active_set, E, gap @@ -499,10 +565,21 @@ def mixed_norm_solver(M, G, alpha, maxit=3000, tol=1e-8, verbose=None, @verbose -def iterative_mixed_norm_solver(M, G, alpha, n_mxne_iter, maxit=3000, - tol=1e-8, verbose=None, active_set_size=50, - debias=True, n_orient=1, dgap_freq=10, - solver='auto', weight_init=None): +def iterative_mixed_norm_solver( + M, + G, + alpha, + n_mxne_iter, + maxit=3000, + tol=1e-8, + verbose=None, + active_set_size=50, + debias=True, + n_orient=1, + dgap_freq=10, + solver="auto", + weight_init=None, +): """Solve L0.5/L2 mixed-norm inverse problem with active set strategy. See reference :footcite:`StrohmeierEtAl2016`. @@ -551,20 +628,23 @@ def iterative_mixed_norm_solver(M, G, alpha, n_mxne_iter, maxit=3000, ---------- .. footbibliography:: """ + def g(w): return np.sqrt(np.sqrt(groups_norm2(w.copy(), n_orient))) def gprime(w): - return 2. * np.repeat(g(w), n_orient).ravel() + return 2.0 * np.repeat(g(w), n_orient).ravel() E = list() if weight_init is not None and weight_init.shape != (G.shape[1],): - raise ValueError('Wrong dimension for weight initialization. Got %s. ' - 'Expected %s.' % (weight_init.shape, (G.shape[1],))) + raise ValueError( + "Wrong dimension for weight initialization. Got %s. " + "Expected %s." 
% (weight_init.shape, (G.shape[1],)) + ) weights = weight_init if weight_init is not None else np.ones(G.shape[1]) - active_set = (weights != 0) + active_set = weights != 0 weights = weights[active_set] X = np.zeros((G.shape[1], M.shape[1])) @@ -576,39 +656,70 @@ def gprime(w): if active_set_size is not None: if np.sum(active_set) > (active_set_size * n_orient): X, _active_set, _ = mixed_norm_solver( - M, G_tmp, alpha, debias=False, n_orient=n_orient, - maxit=maxit, tol=tol, active_set_size=active_set_size, - dgap_freq=dgap_freq, solver=solver, verbose=verbose) + M, + G_tmp, + alpha, + debias=False, + n_orient=n_orient, + maxit=maxit, + tol=tol, + active_set_size=active_set_size, + dgap_freq=dgap_freq, + solver=solver, + verbose=verbose, + ) else: X, _active_set, _ = mixed_norm_solver( - M, G_tmp, alpha, debias=False, n_orient=n_orient, - maxit=maxit, tol=tol, active_set_size=None, - dgap_freq=dgap_freq, solver=solver, verbose=verbose) + M, + G_tmp, + alpha, + debias=False, + n_orient=n_orient, + maxit=maxit, + tol=tol, + active_set_size=None, + dgap_freq=dgap_freq, + solver=solver, + verbose=verbose, + ) else: X, _active_set, _ = mixed_norm_solver( - M, G_tmp, alpha, debias=False, n_orient=n_orient, - maxit=maxit, tol=tol, active_set_size=None, - dgap_freq=dgap_freq, solver=solver, verbose=verbose) - - logger.info('active set size %d' % (_active_set.sum() / n_orient)) + M, + G_tmp, + alpha, + debias=False, + n_orient=n_orient, + maxit=maxit, + tol=tol, + active_set_size=None, + dgap_freq=dgap_freq, + solver=solver, + verbose=verbose, + ) + + logger.info("active set size %d" % (_active_set.sum() / n_orient)) if _active_set.sum() > 0: active_set[active_set] = _active_set # Reapply weights to have correct unit X *= weights[_active_set][:, np.newaxis] weights = gprime(X) - p_obj = 0.5 * np.linalg.norm(M - np.dot(G[:, active_set], X), - 'fro') ** 2. + alpha * np.sum(g(X)) + p_obj = 0.5 * np.linalg.norm( + M - np.dot(G[:, active_set], X), "fro" + ) ** 2.0 + alpha * np.sum(g(X)) E.append(p_obj) # Check convergence - if ((k >= 1) and np.all(active_set == active_set_0) and - np.all(np.abs(X - X0) < tol)): - print('Convergence reached after %d reweightings!' % k) + if ( + (k >= 1) + and np.all(active_set == active_set_0) + and np.all(np.abs(X - X0) < tol) + ): + print("Convergence reached after %d reweightings!" % k) break else: active_set = np.zeros_like(active_set) - p_obj = 0.5 * np.linalg.norm(M) ** 2. + p_obj = 0.5 * np.linalg.norm(M) ** 2.0 E.append(p_obj) break @@ -622,6 +733,7 @@ def gprime(w): ############################################################################### # TF-MxNE + @verbose def tf_lipschitz_constant(M, G, phi, phiT, tol=1e-3, verbose=None): """Compute lipschitz constant for FISTA. @@ -635,7 +747,7 @@ def tf_lipschitz_constant(M, G, phi, phiT, tol=1e-3, verbose=None): L = 1e100 for it in range(100): L_old = L - logger.info('Lipschitz estimation: iteration = %d' % it) + logger.info("Lipschitz estimation: iteration = %d" % it) iv = np.real(phiT(v)) Gv = np.dot(G, iv) GtGv = np.dot(G.T, Gv) @@ -652,7 +764,7 @@ def safe_max_abs(A, ia): if np.sum(ia): # ia is not empty return np.max(np.abs(A[ia])) else: - return 0. 
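Each reweighting pass above ends by recomputing the weights as gprime of the rescaled solution; equivalently, in standalone numpy (a sketch consistent with g and gprime as defined):

import numpy as np

def next_weights(X, n_orient=1):
    # g(X) = sqrt(sqrt(squared group norm)) per position; the new weights
    # are 2 * g(X), repeated once per orientation within each group
    grp = (X**2).reshape(-1, n_orient, X.shape[1]).sum(axis=(1, 2))
    return np.repeat(2.0 * np.sqrt(np.sqrt(grp)), n_orient)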
+ return 0.0 def safe_max_abs_diff(A, ia, B, ib): @@ -677,31 +789,29 @@ def __init__(self, wsize, tstep, n_coefs, n_times): # noqa: D102 self.ops = list() for ws, ts in zip(self.wsize, self.tstep): self.ops.append( - stft(np.eye(n_times), ws, ts, - verbose=False).reshape(n_times, -1)) + stft(np.eye(n_times), ws, ts, verbose=False).reshape(n_times, -1) + ) def __call__(self, x): # noqa: D105 if self.n_dicts == 1: return x @ self.ops[0] else: - return np.hstack( - [x @ op for op in self.ops]) / np.sqrt(self.n_dicts) + return np.hstack([x @ op for op in self.ops]) / np.sqrt(self.n_dicts) def norm(self, z, ord=2): """Squared L2 norm if ord == 2 and L1 norm if order == 1.""" if ord not in (1, 2): - raise ValueError('Only supported norm order are 1 and 2. ' - 'Got ord = %s' % ord) + raise ValueError( + "Only supported norm order are 1 and 2. " "Got ord = %s" % ord + ) stft_norm = stft_norm1 if ord == 1 else stft_norm2 - norm = 0. + norm = 0.0 if len(self.n_coefs) > 1: - z_ = np.array_split(np.atleast_2d(z), np.cumsum(self.n_coefs)[:-1], - axis=1) + z_ = np.array_split(np.atleast_2d(z), np.cumsum(self.n_coefs)[:-1], axis=1) else: z_ = [np.atleast_2d(z)] for i in range(len(z_)): - norm += stft_norm( - z_[i].reshape(-1, self.n_freqs[i], self.n_steps[i])) + norm += stft_norm(z_[i].reshape(-1, self.n_freqs[i], self.n_steps[i])) return norm @@ -721,10 +831,8 @@ def __init__(self, tstep, n_freqs, n_steps, n_times): # noqa: D102 nc = nf * ns self.n_coefs.append(nc) eye = np.eye(nc).reshape(nf, ns, nf, ns) - self.op_re.append(istft( - eye, ts, n_times).reshape(nc, n_times)) - self.op_im.append(istft( - eye * 1j, ts, n_times).reshape(nc, n_times)) + self.op_re.append(istft(eye, ts, n_times).reshape(nc, n_times)) + self.op_im.append(istft(eye * 1j, ts, n_times).reshape(nc, n_times)) def __call__(self, z): # noqa: D105 if self.n_dicts == 1: @@ -740,13 +848,12 @@ def __call__(self, z): # noqa: D105 def norm_l21_tf(Z, phi, n_orient, w_space=None): """L21 norm for TF.""" if Z.shape[0]: - l21_norm = np.sqrt( - phi.norm(Z, ord=2).reshape(-1, n_orient).sum(axis=1)) + l21_norm = np.sqrt(phi.norm(Z, ord=2).reshape(-1, n_orient).sum(axis=1)) if w_space is not None: l21_norm *= w_space l21_norm = l21_norm.sum() else: - l21_norm = 0. + l21_norm = 0.0 return l21_norm @@ -754,18 +861,19 @@ def norm_l1_tf(Z, phi, n_orient, w_time): """L1 norm for TF.""" if Z.shape[0]: n_positions = Z.shape[0] // n_orient - Z_ = np.sqrt(np.sum( - (np.abs(Z) ** 2.).reshape((n_orient, -1), order='F'), axis=0)) - Z_ = Z_.reshape((n_positions, -1), order='F') + Z_ = np.sqrt( + np.sum((np.abs(Z) ** 2.0).reshape((n_orient, -1), order="F"), axis=0) + ) + Z_ = Z_.reshape((n_positions, -1), order="F") if w_time is not None: Z_ *= w_time l1_norm = phi.norm(Z_, ord=1).sum() else: - l1_norm = 0. + l1_norm = 0.0 return l1_norm -def norm_epsilon(Y, l1_ratio, phi, w_space=1., w_time=None): +def norm_epsilon(Y, l1_ratio, phi, w_space=1.0, w_time=None): """Weighted epsilon norm. 
The weighted epsilon norm is the dual norm of:: @@ -810,35 +918,35 @@ def norm_epsilon(Y, l1_ratio, phi, w_space=1., w_time=None): # Add negative freqs: count all freqs twice except first and last: freqs_count = np.full(len(Y), 2) - for i, fc in enumerate(np.array_split(freqs_count, - np.cumsum(phi.n_coefs)[:-1])): - fc[:phi.n_steps[i]] = 1 - fc[-phi.n_steps[i]:] = 1 + for i, fc in enumerate(np.array_split(freqs_count, np.cumsum(phi.n_coefs)[:-1])): + fc[: phi.n_steps[i]] = 1 + fc[-phi.n_steps[i] :] = 1 # exclude 0 weights: if w_time is not None: - nonzero_weights = (w_time != 0.0) + nonzero_weights = w_time != 0.0 Y = Y[nonzero_weights] freqs_count = freqs_count[nonzero_weights] w_time = w_time[nonzero_weights] norm_inf_Y = np.max(Y / w_time) if w_time is not None else np.max(Y) - if l1_ratio == 1.: + if l1_ratio == 1.0: # dual norm of L1 weighted is Linf with inverse weights return norm_inf_Y - elif l1_ratio == 0.: + elif l1_ratio == 0.0: # dual norm of L2 is L2 return np.sqrt(phi.norm(Y[None, :], ord=2).sum()) - if norm_inf_Y == 0.: - return 0. + if norm_inf_Y == 0.0: + return 0.0 # ignore some values of Y by lower bound on dual norm: if w_time is None: idx = Y > l1_ratio * norm_inf_Y else: - idx = Y > l1_ratio * np.max(Y / (w_space * (1. - l1_ratio) + - l1_ratio * w_time)) + idx = Y > l1_ratio * np.max( + Y / (w_space * (1.0 - l1_ratio) + l1_ratio * w_time) + ) if idx.sum() == 1: return norm_inf_Y @@ -859,18 +967,18 @@ def norm_epsilon(Y, l1_ratio, phi, w_space=1., w_time=None): K = Y.shape[0] if w_time is None: - p_sum_Y2 = np.cumsum(Y ** 2) + p_sum_Y2 = np.cumsum(Y**2) p_sum_w2 = np.arange(1, K + 1) p_sum_Yw = np.cumsum(Y) - upper = p_sum_Y2 / Y ** 2 - 2. * p_sum_Yw / Y + p_sum_w2 + upper = p_sum_Y2 / Y**2 - 2.0 * p_sum_Yw / Y + p_sum_w2 else: - p_sum_Y2 = np.cumsum(Y ** 2) - p_sum_w2 = np.cumsum(w_time ** 2) + p_sum_Y2 = np.cumsum(Y**2) + p_sum_w2 = np.cumsum(w_time**2) p_sum_Yw = np.cumsum(Y * w_time) - upper = (p_sum_Y2 / (Y / w_time) ** 2 - - 2. * p_sum_Yw / (Y / w_time) + p_sum_w2) - upper_greater = np.where(upper > w_space ** 2 * (1. - l1_ratio) ** 2 / - l1_ratio ** 2)[0] + upper = p_sum_Y2 / (Y / w_time) ** 2 - 2.0 * p_sum_Yw / (Y / w_time) + p_sum_w2 + upper_greater = np.where( + upper > w_space**2 * (1.0 - l1_ratio) ** 2 / l1_ratio**2 + )[0] i0 = upper_greater[0] - 1 if upper_greater.size else K - 1 @@ -878,9 +986,9 @@ def norm_epsilon(Y, l1_ratio, phi, w_space=1., w_time=None): p_sum_w2 = p_sum_w2[i0] p_sum_Yw = p_sum_Yw[i0] - denom = l1_ratio ** 2 * p_sum_w2 - w_space ** 2 * (1. - l1_ratio) ** 2 + denom = l1_ratio**2 * p_sum_w2 - w_space**2 * (1.0 - l1_ratio) ** 2 if np.abs(denom) < 1e-10: - return p_sum_Y2 / (2. * l1_ratio * p_sum_Yw) + return p_sum_Y2 / (2.0 * l1_ratio * p_sum_Yw) else: delta = (l1_ratio * p_sum_Yw) ** 2 - p_sum_Y2 * denom return (l1_ratio * p_sum_Yw - np.sqrt(delta)) / denom @@ -918,24 +1026,35 @@ def norm_epsilon_inf(G, R, phi, l1_ratio, n_orient, w_space=None, w_time=None): n_positions = G.shape[1] // n_orient GTRPhi = np.abs(phi(np.dot(G.T, R))) # norm over orientations: - GTRPhi = GTRPhi.reshape((n_orient, -1), order='F') + GTRPhi = GTRPhi.reshape((n_orient, -1), order="F") GTRPhi = np.linalg.norm(GTRPhi, axis=0) - GTRPhi = GTRPhi.reshape((n_positions, -1), order='F') - nu = 0. + GTRPhi = GTRPhi.reshape((n_positions, -1), order="F") + nu = 0.0 for idx in range(n_positions): GTRPhi_ = GTRPhi[idx] w_t = w_time[idx] if w_time is not None else None - w_s = w_space[idx] if w_space is not None else 1. 
- norm_eps = norm_epsilon(GTRPhi_, l1_ratio, phi, w_space=w_s, - w_time=w_t) + w_s = w_space[idx] if w_space is not None else 1.0 + norm_eps = norm_epsilon(GTRPhi_, l1_ratio, phi, w_space=w_s, w_time=w_t) if norm_eps > nu: nu = norm_eps return nu -def dgap_l21l1(M, G, Z, active_set, alpha_space, alpha_time, phi, phiT, - n_orient, highest_d_obj, w_space=None, w_time=None): +def dgap_l21l1( + M, + G, + Z, + active_set, + alpha_space, + alpha_time, + phi, + phiT, + n_orient, + highest_d_obj, + w_space=None, + w_time=None, +): """Duality gap for the time-frequency mixed norm inverse problem. See :footcite:`GramfortEtAl2012,NdiayeEtAl2016` @@ -1003,29 +1122,45 @@ def dgap_l21l1(M, G, Z, active_set, alpha_space, alpha_time, phi, phiT, p_obj = 0.5 * nR2 + alpha_space * penaltyl21 + alpha_time * penaltyl1 l1_ratio = alpha_time / (alpha_space + alpha_time) - dual_norm = norm_epsilon_inf(G, R, phi, l1_ratio, n_orient, - w_space=w_space, w_time=w_time) - scaling = min(1., (alpha_space + alpha_time) / dual_norm) + dual_norm = norm_epsilon_inf( + G, R, phi, l1_ratio, n_orient, w_space=w_space, w_time=w_time + ) + scaling = min(1.0, (alpha_space + alpha_time) / dual_norm) - d_obj = (scaling - 0.5 * (scaling ** 2)) * nR2 + scaling * np.sum(R * GX) + d_obj = (scaling - 0.5 * (scaling**2)) * nR2 + scaling * np.sum(R * GX) d_obj = max(d_obj, highest_d_obj) gap = p_obj - d_obj return gap, p_obj, d_obj, R -def _tf_mixed_norm_solver_bcd_(M, G, Z, active_set, candidates, alpha_space, - alpha_time, lipschitz_constant, phi, phiT, - w_space=None, w_time=None, n_orient=1, - maxit=200, tol=1e-8, dgap_freq=10, perc=None, - timeit=True, verbose=None): +def _tf_mixed_norm_solver_bcd_( + M, + G, + Z, + active_set, + candidates, + alpha_space, + alpha_time, + lipschitz_constant, + phi, + phiT, + w_space=None, + w_time=None, + n_orient=1, + maxit=200, + tol=1e-8, + dgap_freq=10, + perc=None, + timeit=True, + verbose=None, +): n_sources = G.shape[1] n_positions = n_sources // n_orient # First make G fortran for faster access to blocks of columns Gd = np.asfortranarray(G) - G = np.ascontiguousarray( - Gd.T.reshape(n_positions, n_orient, -1).transpose(0, 2, 1)) + G = np.ascontiguousarray(Gd.T.reshape(n_positions, n_orient, -1).transpose(0, 2, 1)) R = M.copy() # residual active = np.where(active_set[::n_orient])[0] @@ -1044,7 +1179,7 @@ def _tf_mixed_norm_solver_bcd_(M, G, Z, active_set, candidates, alpha_space, alpha_space_lc = alpha_space * w_space / lipschitz_constant converged = False - d_obj = - np.inf + d_obj = -np.inf for i in range(maxit): for jj in candidates: @@ -1066,7 +1201,7 @@ def _tf_mixed_norm_solver_bcd_(M, G, Z, active_set, candidates, alpha_space, R += np.dot(G_j, X_j) X_j_new += X_j - rows_norm = np.linalg.norm(X_j_new, 'fro') + rows_norm = np.linalg.norm(X_j_new, "fro") if rows_norm <= alpha_space_lc[jj]: if was_active: Z[jj] = 0.0 @@ -1084,8 +1219,11 @@ def _tf_mixed_norm_solver_bcd_(M, G, Z, active_set, candidates, alpha_space, active_set_j[:] = False else: # l1 - shrink = np.maximum(1.0 - alpha_time_lc[jj] / np.maximum( - col_norm, alpha_time_lc[jj]), 0.0) + shrink = np.maximum( + 1.0 + - alpha_time_lc[jj] / np.maximum(col_norm, alpha_time_lc[jj]), + 0.0, + ) if w_time is not None: shrink[w_time[jj] == 0.0] = 0.0 Z_j_new *= shrink[np.newaxis, :] @@ -1098,8 +1236,11 @@ def _tf_mixed_norm_solver_bcd_(M, G, Z, active_set, candidates, alpha_space, active_set_j[:] = False else: shrink = np.maximum( - 1.0 - alpha_space_lc[jj] / - np.maximum(row_norm, alpha_space_lc[jj]), 0.0) + 1.0 + - alpha_space_lc[jj] + / 
np.maximum(row_norm, alpha_space_lc[jj]), + 0.0, + ) Z_j_new *= shrink Z[jj] = Z_j_new.reshape(-1, *shape_init[1:]).copy() active_set_j[:] = True @@ -1107,17 +1248,28 @@ def _tf_mixed_norm_solver_bcd_(M, G, Z, active_set, candidates, alpha_space, R -= np.dot(G_j, Z_j_phi_T) if (i + 1) % dgap_freq == 0: - Zd = np.vstack([Z[pos] for pos in range(n_positions) - if np.any(Z[pos])]) + Zd = np.vstack([Z[pos] for pos in range(n_positions) if np.any(Z[pos])]) gap, p_obj, d_obj, _ = dgap_l21l1( - M, Gd, Zd, active_set, alpha_space, alpha_time, phi, phiT, - n_orient, d_obj, w_space=w_space, w_time=w_time) - converged = (gap < tol) + M, + Gd, + Zd, + active_set, + alpha_space, + alpha_time, + phi, + phiT, + n_orient, + d_obj, + w_space=w_space, + w_time=w_time, + ) + converged = gap < tol E.append(p_obj) - logger.info("\n Iteration %d :: n_active %d" % ( - i + 1, np.sum(active_set) / n_orient)) - logger.info(" dgap %.2e :: p_obj %f :: d_obj %f" % ( - gap, p_obj, d_obj)) + logger.info( + "\n Iteration %d :: n_active %d" + % (i + 1, np.sum(active_set) / n_orient) + ) + logger.info(" dgap %.2e :: p_obj %f :: d_obj %f" % (gap, p_obj, d_obj)) if converged: break @@ -1130,13 +1282,23 @@ def _tf_mixed_norm_solver_bcd_(M, G, Z, active_set, candidates, alpha_space, @verbose -def _tf_mixed_norm_solver_bcd_active_set(M, G, alpha_space, alpha_time, - lipschitz_constant, phi, phiT, - Z_init=None, w_space=None, - w_time=None, n_orient=1, maxit=200, - tol=1e-8, dgap_freq=10, - verbose=None): - +def _tf_mixed_norm_solver_bcd_active_set( + M, + G, + alpha_space, + alpha_time, + lipschitz_constant, + phi, + phiT, + Z_init=None, + w_space=None, + w_time=None, + n_orient=1, + maxit=200, + tol=1e-8, + dgap_freq=10, + verbose=None, +): n_sensors, n_times = M.shape n_sources = G.shape[1] n_positions = n_sources // n_orient @@ -1146,15 +1308,15 @@ def _tf_mixed_norm_solver_bcd_active_set(M, G, alpha_space, alpha_time, active = [] if Z_init is not None: if Z_init.shape != (n_sources, phi.n_coefs.sum()): - raise Exception('Z_init must be None or an array with shape ' - '(n_sources, n_coefs).') + raise Exception( + "Z_init must be None or an array with shape " "(n_sources, n_coefs)." 
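The two shrinkage steps above compose the proximal operator applied per position in TF-MxNE's BCD: an L1 soft-threshold on each TF column, then an L21 shrink of the whole block. A numpy sketch (illustrative names; Z holds one position's complex STFT coefficients):

import numpy as np

def prox_l1_l21(Z, mu_time, mu_space):
    col_norm = np.linalg.norm(Z, axis=0)                  # per TF coefficient
    Z = Z * np.maximum(1.0 - mu_time / np.maximum(col_norm, mu_time), 0.0)
    row_norm = np.sqrt((np.abs(Z) ** 2).sum())            # whole block
    return Z * np.maximum(1.0 - mu_space / np.maximum(row_norm, mu_space), 0.0)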
+ ) for ii in range(n_positions): - if np.any(Z_init[ii * n_orient:(ii + 1) * n_orient]): - active_set[ii * n_orient:(ii + 1) * n_orient] = True + if np.any(Z_init[ii * n_orient : (ii + 1) * n_orient]): + active_set[ii * n_orient : (ii + 1) * n_orient] = True active.append(ii) if len(active): - Z.update(dict(zip(active, - np.vsplit(Z_init[active_set], len(active))))) + Z.update(dict(zip(active, np.vsplit(Z_init[active_set], len(active))))) E = [] candidates = range(n_positions) @@ -1165,9 +1327,24 @@ def _tf_mixed_norm_solver_bcd_active_set(M, G, alpha_space, alpha_time, Z_init = dict.fromkeys(np.arange(n_positions), 0.0) Z_init.update(dict(zip(active, Z.values()))) Z, active_set, E_tmp, _ = _tf_mixed_norm_solver_bcd_( - M, G, Z_init, active_set, candidates, alpha_space, alpha_time, - lipschitz_constant, phi, phiT, w_space=w_space, w_time=w_time, - n_orient=n_orient, maxit=1, tol=tol, perc=None, verbose=verbose) + M, + G, + Z_init, + active_set, + candidates, + alpha_space, + alpha_time, + lipschitz_constant, + phi, + phiT, + w_space=w_space, + w_time=w_time, + n_orient=n_orient, + maxit=1, + tol=tol, + perc=None, + verbose=verbose, + ) E += E_tmp @@ -1185,14 +1362,25 @@ def _tf_mixed_norm_solver_bcd_active_set(M, G, alpha_space, alpha_time, w_time_as = None Z, as_, E_tmp, converged = _tf_mixed_norm_solver_bcd_( - M, G[:, active_set], Z_init, + M, + G[:, active_set], + Z_init, np.ones(len(active) * n_orient, dtype=bool), - candidates_, alpha_space, alpha_time, - lipschitz_constant[active_set[::n_orient]], phi, phiT, - w_space=w_space_as, w_time=w_time_as, - n_orient=n_orient, maxit=maxit, tol=tol, - dgap_freq=dgap_freq, perc=0.5, - verbose=verbose) + candidates_, + alpha_space, + alpha_time, + lipschitz_constant[active_set[::n_orient]], + phi, + phiT, + w_space=w_space_as, + w_time=w_time_as, + n_orient=n_orient, + maxit=maxit, + tol=tol, + dgap_freq=dgap_freq, + perc=0.5, + verbose=verbose, + ) active = np.where(active_set[::n_orient])[0] active_set[active_set] = as_.copy() E += E_tmp @@ -1201,10 +1389,23 @@ def _tf_mixed_norm_solver_bcd_active_set(M, G, alpha_space, alpha_time, if converged: Zd = np.vstack([Z[pos] for pos in range(len(Z)) if np.any(Z[pos])]) gap, p_obj, d_obj, _ = dgap_l21l1( - M, G, Zd, active_set, alpha_space, alpha_time, - phi, phiT, n_orient, d_obj, w_space, w_time) - logger.info("\ndgap %.2e :: p_obj %f :: d_obj %f :: n_active %d" - % (gap, p_obj, d_obj, np.sum(active_set) / n_orient)) + M, + G, + Zd, + active_set, + alpha_space, + alpha_time, + phi, + phiT, + n_orient, + d_obj, + w_space, + w_time, + ) + logger.info( + "\ndgap %.2e :: p_obj %f :: d_obj %f :: n_active %d" + % (gap, p_obj, d_obj, np.sum(active_set) / n_orient) + ) if gap < tol: logger.info("\nConvergence reached!\n") break @@ -1220,10 +1421,22 @@ def _tf_mixed_norm_solver_bcd_active_set(M, G, alpha_space, alpha_time, @verbose -def tf_mixed_norm_solver(M, G, alpha_space, alpha_time, wsize=64, tstep=4, - n_orient=1, maxit=200, tol=1e-8, - active_set_size=None, debias=True, return_gap=False, - dgap_freq=10, verbose=None): +def tf_mixed_norm_solver( + M, + G, + alpha_space, + alpha_time, + wsize=64, + tstep=4, + n_orient=1, + maxit=200, + tol=1e-8, + active_set_size=None, + debias=True, + return_gap=False, + dgap_freq=10, + verbose=None, +): """Solve TF L21+L1 inverse solver with BCD and active set approach. See :footcite:`GramfortEtAl2013b,GramfortEtAl2011,BekhtiEtAl2016`. 
@@ -1288,9 +1501,10 @@ def tf_mixed_norm_solver(M, G, alpha_space, alpha_time, wsize=64, tstep=4, tstep = np.atleast_1d(tstep) wsize = np.atleast_1d(wsize) if len(tstep) != len(wsize): - raise ValueError('The same number of window sizes and steps must be ' - 'passed. Got tstep = %s and wsize = %s' % - (tstep, wsize)) + raise ValueError( + "The same number of window sizes and steps must be " + "passed. Got tstep = %s and wsize = %s" % (tstep, wsize) + ) n_steps = np.ceil(M.shape[1] / tstep.astype(float)).astype(int) n_freqs = wsize // 2 + 1 @@ -1303,14 +1517,25 @@ def tf_mixed_norm_solver(M, G, alpha_space, alpha_time, wsize=64, tstep=4, else: lc = np.empty(n_positions) for j in range(n_positions): - G_tmp = G[:, (j * n_orient):((j + 1) * n_orient)] + G_tmp = G[:, (j * n_orient) : ((j + 1) * n_orient)] lc[j] = np.linalg.norm(np.dot(G_tmp.T, G_tmp), ord=2) logger.info("Using block coordinate descent with active set approach") X, Z, active_set, E, gap = _tf_mixed_norm_solver_bcd_active_set( - M, G, alpha_space, alpha_time, lc, phi, phiT, - Z_init=None, n_orient=n_orient, maxit=maxit, tol=tol, - dgap_freq=dgap_freq, verbose=None) + M, + G, + alpha_space, + alpha_time, + lc, + phi, + phiT, + Z_init=None, + n_orient=n_orient, + maxit=maxit, + tol=tol, + dgap_freq=dgap_freq, + verbose=None, + ) if np.any(active_set) and debias: bias = compute_bias(M, G[:, active_set], X, n_orient=n_orient) @@ -1323,10 +1548,21 @@ def tf_mixed_norm_solver(M, G, alpha_space, alpha_time, wsize=64, tstep=4, @verbose -def iterative_tf_mixed_norm_solver(M, G, alpha_space, alpha_time, - n_tfmxne_iter, wsize=64, tstep=4, - maxit=3000, tol=1e-8, debias=True, - n_orient=1, dgap_freq=10, verbose=None): +def iterative_tf_mixed_norm_solver( + M, + G, + alpha_space, + alpha_time, + n_tfmxne_iter, + wsize=64, + tstep=4, + maxit=3000, + tol=1e-8, + debias=True, + n_orient=1, + dgap_freq=10, + verbose=None, +): """Solve TF L0.5/L1 + L0.5 inverse problem with BCD + active set approach. Parameters @@ -1385,9 +1621,10 @@ def iterative_tf_mixed_norm_solver(M, G, alpha_space, alpha_time, tstep = np.atleast_1d(tstep) wsize = np.atleast_1d(wsize) if len(tstep) != len(wsize): - raise ValueError('The same number of window sizes and steps must be ' - 'passed. Got tstep = %s and wsize = %s' % - (tstep, wsize)) + raise ValueError( + "The same number of window sizes and steps must be " + "passed. Got tstep = %s and wsize = %s" % (tstep, wsize) + ) n_steps = np.ceil(n_times / tstep.astype(float)).astype(int) n_freqs = wsize // 2 + 1 @@ -1400,24 +1637,25 @@ def iterative_tf_mixed_norm_solver(M, G, alpha_space, alpha_time, else: lc = np.empty(n_positions) for j in range(n_positions): - G_tmp = G[:, (j * n_orient):((j + 1) * n_orient)] + G_tmp = G[:, (j * n_orient) : ((j + 1) * n_orient)] lc[j] = np.linalg.norm(np.dot(G_tmp.T, G_tmp), ord=2) # space and time penalties, and inverse of their derivatives: def g_space(Z): - return np.sqrt(np.sqrt(phi.norm(Z, ord=2).reshape( - -1, n_orient).sum(axis=1))) + return np.sqrt(np.sqrt(phi.norm(Z, ord=2).reshape(-1, n_orient).sum(axis=1))) def g_space_prime_inv(Z): - return 2. * g_space(Z) + return 2.0 * g_space(Z) def g_time(Z): - return np.sqrt(np.sqrt(np.sum((np.abs(Z) ** 2.).reshape( - (n_orient, -1), order='F'), axis=0)).reshape( - (-1, Z.shape[1]), order='F')) + return np.sqrt( + np.sqrt( + np.sum((np.abs(Z) ** 2.0).reshape((n_orient, -1), order="F"), axis=0) + ).reshape((-1, Z.shape[1]), order="F") + ) def g_time_prime_inv(Z): - return 2. 
* g_time(Z) + return 2.0 * g_time(Z) E = list() @@ -1432,17 +1670,29 @@ def g_time_prime_inv(Z): w_space = None w_time = None else: - w_space = 1. / g_space_prime_inv(Z) + w_space = 1.0 / g_space_prime_inv(Z) w_time = g_time_prime_inv(Z) - w_time[w_time == 0.0] = -1. - w_time = 1. / w_time + w_time[w_time == 0.0] = -1.0 + w_time = 1.0 / w_time w_time[w_time < 0.0] = 0.0 X, Z, active_set_, E_, _ = _tf_mixed_norm_solver_bcd_active_set( - M, G[:, active_set], alpha_space, alpha_time, - lc[active_set[::n_orient]], phi, phiT, - Z_init=Z, w_space=w_space, w_time=w_time, n_orient=n_orient, - maxit=maxit, tol=tol, dgap_freq=dgap_freq, verbose=None) + M, + G[:, active_set], + alpha_space, + alpha_time, + lc[active_set[::n_orient]], + phi, + phiT, + Z_init=Z, + w_space=w_space, + w_time=w_time, + n_orient=n_orient, + maxit=maxit, + tol=tol, + dgap_freq=dgap_freq, + verbose=None, + ) active_set[active_set] = active_set_ @@ -1450,25 +1700,31 @@ def g_time_prime_inv(Z): l21_penalty = np.sum(g_space(Z.copy())) l1_penalty = phi.norm(g_time(Z.copy()), ord=1).sum() - p_obj = (0.5 * np.linalg.norm(M - np.dot(G[:, active_set], X), - 'fro') ** 2. + alpha_space * l21_penalty + - alpha_time * l1_penalty) + p_obj = ( + 0.5 * np.linalg.norm(M - np.dot(G[:, active_set], X), "fro") ** 2.0 + + alpha_space * l21_penalty + + alpha_time * l1_penalty + ) E.append(p_obj) - logger.info('Iteration %d: active set size=%d, E=%f' % ( - k + 1, active_set.sum() / n_orient, p_obj)) + logger.info( + "Iteration %d: active set size=%d, E=%f" + % (k + 1, active_set.sum() / n_orient, p_obj) + ) # Check convergence if np.array_equal(active_set, active_set_0): max_diff = np.amax(np.abs(Z - Z0)) - if (max_diff < tol): - print('Convergence reached after %d reweightings!' % k) + if max_diff < tol: + print("Convergence reached after %d reweightings!" % k) break else: - p_obj = 0.5 * np.linalg.norm(M) ** 2. 
+ p_obj = 0.5 * np.linalg.norm(M) ** 2.0 E.append(p_obj) - logger.info('Iteration %d: as_size=%d, E=%f' % ( - k + 1, active_set.sum() / n_orient, p_obj)) + logger.info( + "Iteration %d: as_size=%d, E=%f" + % (k + 1, active_set.sum() / n_orient, p_obj) + ) break if debias: diff --git a/mne/inverse_sparse/tests/test_gamma_map.py b/mne/inverse_sparse/tests/test_gamma_map.py index 6d8aecaf4ad..c6b94c7d9eb 100644 --- a/mne/inverse_sparse/tests/test_gamma_map.py +++ b/mne/inverse_sparse/tests/test_gamma_map.py @@ -8,13 +8,17 @@ import mne from mne.datasets import testing -from mne import (read_cov, read_forward_solution, read_evokeds, - convert_forward_solution, VectorSourceEstimate) +from mne import ( + read_cov, + read_forward_solution, + read_evokeds, + convert_forward_solution, + VectorSourceEstimate, +) from mne.cov import regularize from mne.inverse_sparse import gamma_map from mne.inverse_sparse.mxne_inverse import make_stc_from_dipoles -from mne.minimum_norm.tests.test_inverse import (assert_stc_res, - assert_var_exp_log) +from mne.minimum_norm.tests.test_inverse import assert_stc_res, assert_var_exp_log from mne import pick_types_forward from mne.utils import assert_stcs_equal, catch_logging from mne.dipole import Dipole @@ -22,29 +26,30 @@ data_path = testing.data_path(download=False) fname_evoked = data_path / "MEG" / "sample" / "sample_audvis-ave.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis-cov.fif" -fname_fwd = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-6-fwd.fif" -) +fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-6-fwd.fif" subjects_dir = data_path / "subjects" -def _check_stc(stc, evoked, idx, hemi, fwd, dist_limit=0., ratio=50., - res=None, atol=1e-20): +def _check_stc( + stc, evoked, idx, hemi, fwd, dist_limit=0.0, ratio=50.0, res=None, atol=1e-20 +): """Check correctness.""" assert_array_almost_equal(stc.times, evoked.times, 5) stc_orig = stc if isinstance(stc, VectorSourceEstimate): assert stc.data.any(1).any(1).all() # all dipoles should have some stc = stc.magnitude() - amps = np.sum(stc.data ** 2, axis=1) + amps = np.sum(stc.data**2, axis=1) order = np.argsort(amps)[::-1] amps = amps[order] verts = np.concatenate(stc.vertices)[order] hemi_idx = int(order[0] >= len(stc.vertices[1])) - hemis = ['lh', 'rh'] + hemis = ["lh", "rh"] assert hemis[hemi_idx] == hemi - dist = np.linalg.norm(np.diff(fwd['src'][hemi_idx]['rr'][[idx, verts[0]]], - axis=0)[0]) * 1000. 
+ dist = ( + np.linalg.norm(np.diff(fwd["src"][hemi_idx]["rr"][[idx, verts[0]]], axis=0)[0]) + * 1000.0 + ) assert dist <= dist_limit assert amps[0] > ratio * amps[1] if res is not None: @@ -59,8 +64,7 @@ def test_gamma_map_standard(): forward = convert_forward_solution(forward, surf_ori=True) forward = pick_types_forward(forward, meg=False, eeg=True) - evoked = read_evokeds(fname_evoked, condition=0, baseline=(None, 0), - proj=False) + evoked = read_evokeds(fname_evoked, condition=0, baseline=(None, 0), proj=False) evoked.resample(50, npad=100) evoked.crop(tmin=0.1, tmax=0.14) # crop to window around peak @@ -69,50 +73,90 @@ def test_gamma_map_standard(): alpha = 0.5 with catch_logging() as log: - stc = gamma_map(evoked, forward, cov, alpha, tol=1e-4, - xyz_same_gamma=True, update_mode=1, verbose=True) - _check_stc(stc, evoked, 68477, 'lh', fwd=forward) + stc = gamma_map( + evoked, + forward, + cov, + alpha, + tol=1e-4, + xyz_same_gamma=True, + update_mode=1, + verbose=True, + ) + _check_stc(stc, evoked, 68477, "lh", fwd=forward) assert_var_exp_log(log.getvalue(), 20, 22) with catch_logging() as log: stc_vec, res = gamma_map( - evoked, forward, cov, alpha, tol=1e-4, xyz_same_gamma=True, - update_mode=1, pick_ori='vector', return_residual=True, - verbose=True) + evoked, + forward, + cov, + alpha, + tol=1e-4, + xyz_same_gamma=True, + update_mode=1, + pick_ori="vector", + return_residual=True, + verbose=True, + ) assert_var_exp_log(log.getvalue(), 20, 22) assert_stcs_equal(stc_vec.magnitude(), stc) - _check_stc(stc_vec, evoked, 68477, 'lh', fwd=forward, res=res) + _check_stc(stc_vec, evoked, 68477, "lh", fwd=forward, res=res) stc, res = gamma_map( - evoked, forward, cov, alpha, tol=1e-4, xyz_same_gamma=False, - update_mode=1, pick_ori='vector', return_residual=True) - _check_stc(stc, evoked, 82010, 'lh', fwd=forward, dist_limit=6., ratio=2., - res=res) + evoked, + forward, + cov, + alpha, + tol=1e-4, + xyz_same_gamma=False, + update_mode=1, + pick_ori="vector", + return_residual=True, + ) + _check_stc( + stc, evoked, 82010, "lh", fwd=forward, dist_limit=6.0, ratio=2.0, res=res + ) with catch_logging() as log: - dips = gamma_map(evoked, forward, cov, alpha, tol=1e-4, - xyz_same_gamma=False, update_mode=1, - return_as_dipoles=True, verbose=True) + dips = gamma_map( + evoked, + forward, + cov, + alpha, + tol=1e-4, + xyz_same_gamma=False, + update_mode=1, + return_as_dipoles=True, + verbose=True, + ) exp_var = assert_var_exp_log(log.getvalue(), 58, 60) dip_exp_var = np.mean(sum(dip.gof for dip in dips)) assert_allclose(exp_var, dip_exp_var, atol=10) # not really equiv, close - assert (isinstance(dips[0], Dipole)) - stc_dip = make_stc_from_dipoles(dips, forward['src']) + assert isinstance(dips[0], Dipole) + stc_dip = make_stc_from_dipoles(dips, forward["src"]) assert_stcs_equal(stc.magnitude(), stc_dip) # force fixed orientation - stc, res = gamma_map(evoked, forward, cov, alpha, tol=1e-4, - xyz_same_gamma=False, update_mode=2, - loose=0, return_residual=True) - _check_stc(stc, evoked, 85739, 'lh', fwd=forward, ratio=20., res=res) + stc, res = gamma_map( + evoked, + forward, + cov, + alpha, + tol=1e-4, + xyz_same_gamma=False, + update_mode=2, + loose=0, + return_residual=True, + ) + _check_stc(stc, evoked, 85739, "lh", fwd=forward, ratio=20.0, res=res) @pytest.mark.slowtest @testing.requires_testing_data def test_gamma_map_vol_sphere(): """Gamma MAP with a sphere forward and volumic source space.""" - evoked = read_evokeds(fname_evoked, condition=0, baseline=(None, 0), - proj=False) + evoked = 
read_evokeds(fname_evoked, condition=0, baseline=(None, 0), proj=False) evoked.resample(50, npad=100) evoked.crop(tmin=0.1, tmax=0.16) # crop to window around peak @@ -120,18 +164,32 @@ def test_gamma_map_vol_sphere(): cov = regularize(cov, evoked.info, rank=None) info = evoked.info - sphere = mne.make_sphere_model(r0=(0., 0., 0.), head_radius=0.080) - src = mne.setup_volume_source_space(subject=None, pos=30., mri=None, - sphere=(0.0, 0.0, 0.0, 0.08), - bem=None, mindist=5.0, - exclude=2.0, sphere_units='m') - fwd = mne.make_forward_solution(info, trans=None, src=src, bem=sphere, - eeg=False, meg=True) + sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.0), head_radius=0.080) + src = mne.setup_volume_source_space( + subject=None, + pos=30.0, + mri=None, + sphere=(0.0, 0.0, 0.0, 0.08), + bem=None, + mindist=5.0, + exclude=2.0, + sphere_units="m", + ) + fwd = mne.make_forward_solution( + info, trans=None, src=src, bem=sphere, eeg=False, meg=True + ) alpha = 0.5 - stc = gamma_map(evoked, fwd, cov, alpha, tol=1e-4, - xyz_same_gamma=False, update_mode=2, - return_residual=False) + stc = gamma_map( + evoked, + fwd, + cov, + alpha, + tol=1e-4, + xyz_same_gamma=False, + update_mode=2, + return_residual=False, + ) assert_array_almost_equal(stc.times, evoked.times, 5) # Computing inverse with restricted orientations should also work, since @@ -141,18 +199,21 @@ def test_gamma_map_vol_sphere(): # Compare orientation obtained using fit_dipole and gamma_map # for a simulated evoked containing a single dipole - stc = mne.VolSourceEstimate(50e-9 * np.random.RandomState(42).randn(1, 4), - vertices=[stc.vertices[0][:1]], - tmin=stc.tmin, - tstep=stc.tstep) - evoked_dip = mne.simulation.simulate_evoked(fwd, stc, info, cov, nave=1e9, - use_cps=True) + stc = mne.VolSourceEstimate( + 50e-9 * np.random.RandomState(42).randn(1, 4), + vertices=[stc.vertices[0][:1]], + tmin=stc.tmin, + tstep=stc.tstep, + ) + evoked_dip = mne.simulation.simulate_evoked( + fwd, stc, info, cov, nave=1e9, use_cps=True + ) dip_gmap = gamma_map(evoked_dip, fwd, cov, 0.1, return_as_dipoles=True) amp_max = [np.max(d.amplitude) for d in dip_gmap] dip_gmap = dip_gmap[np.argmax(amp_max)] - assert (dip_gmap[0].pos[0] in src[0]['rr'][stc.vertices[0]]) + assert dip_gmap[0].pos[0] in src[0]["rr"][stc.vertices[0]] dip_fit = mne.fit_dipole(evoked_dip, cov, sphere)[0] - assert (np.abs(np.dot(dip_fit.ori[0], dip_gmap.ori[0])) > 0.99) + assert np.abs(np.dot(dip_fit.ori[0], dip_gmap.ori[0])) > 0.99 diff --git a/mne/inverse_sparse/tests/test_mxne_inverse.py b/mne/inverse_sparse/tests/test_mxne_inverse.py index c5a8064b608..19b26bc7483 100644 --- a/mne/inverse_sparse/tests/test_mxne_inverse.py +++ b/mne/inverse_sparse/tests/test_mxne_inverse.py @@ -4,22 +4,24 @@ # License: Simplified BSD import numpy as np -from numpy.testing import (assert_array_almost_equal, assert_allclose, - assert_array_less, assert_array_equal) +from numpy.testing import ( + assert_array_almost_equal, + assert_allclose, + assert_array_less, + assert_array_equal, +) import pytest import mne from mne.datasets import testing from mne.label import read_label -from mne import (read_cov, read_forward_solution, read_evokeds, - convert_forward_solution) +from mne import read_cov, read_forward_solution, read_evokeds, convert_forward_solution from mne.inverse_sparse import mixed_norm, tf_mixed_norm from mne.inverse_sparse.mxne_inverse import make_stc_from_dipoles, _split_gof from mne.inverse_sparse.mxne_inverse import _compute_mxne_sure from mne.inverse_sparse.mxne_optim import norm_l2inf from 
mne.minimum_norm import apply_inverse, make_inverse_operator -from mne.minimum_norm.tests.test_inverse import \ - assert_var_exp_log, assert_stc_res +from mne.minimum_norm.tests.test_inverse import assert_var_exp_log, assert_stc_res from mne.utils import assert_stcs_equal, catch_logging, _record_warnings from mne.dipole import Dipole from mne.source_estimate import VolSourceEstimate @@ -31,14 +33,12 @@ fname_data = data_path / "MEG" / "sample" / "sample_audvis-ave.fif" fname_cov = data_path / "MEG" / "sample" / "sample_audvis-cov.fif" fname_raw = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif" -fname_fwd = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-6-fwd.fif" -) +fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-6-fwd.fif" label = "Aud-rh" fname_label = data_path / "MEG" / "sample" / "labels" / ("%s.label" % label) -@pytest.fixture(scope='module', params=[testing._pytest_param]) +@pytest.fixture(scope="module", params=[testing._pytest_param]) def forward(): """Get a forward solution.""" # module scope it for speed (but don't overwrite in use!) @@ -63,32 +63,59 @@ def test_mxne_inverse_standard(forward): evoked_l21 = evoked.copy() evoked_l21.crop(tmin=0.081, tmax=0.1) label = read_label(fname_label) - assert label.hemi == 'rh' + assert label.hemi == "rh" forward = convert_forward_solution(forward, surf_ori=True) # Reduce source space to make test computation faster - inverse_operator = make_inverse_operator(evoked_l21.info, forward, cov, - loose=loose, depth=depth, - fixed=True, use_cps=True) - stc_dspm = apply_inverse(evoked_l21, inverse_operator, lambda2=1. / 9., - method='dSPM') + inverse_operator = make_inverse_operator( + evoked_l21.info, + forward, + cov, + loose=loose, + depth=depth, + fixed=True, + use_cps=True, + ) + stc_dspm = apply_inverse( + evoked_l21, inverse_operator, lambda2=1.0 / 9.0, method="dSPM" + ) stc_dspm.data[np.abs(stc_dspm.data) < 12] = 0.0 - stc_dspm.data[np.abs(stc_dspm.data) >= 12] = 1. 
+ stc_dspm.data[np.abs(stc_dspm.data) >= 12] = 1.0 weights_min = 0.5 # MxNE tests alpha = 70 # spatial regularization parameter with _record_warnings(): # CD - stc_cd = mixed_norm(evoked_l21, forward, cov, alpha, loose=loose, - depth=depth, maxit=300, tol=1e-8, - active_set_size=10, weights=stc_dspm, - weights_min=weights_min, solver='cd') - stc_bcd = mixed_norm(evoked_l21, forward, cov, alpha, loose=loose, - depth=depth, maxit=300, tol=1e-8, active_set_size=10, - weights=stc_dspm, weights_min=weights_min, - solver='bcd') + stc_cd = mixed_norm( + evoked_l21, + forward, + cov, + alpha, + loose=loose, + depth=depth, + maxit=300, + tol=1e-8, + active_set_size=10, + weights=stc_dspm, + weights_min=weights_min, + solver="cd", + ) + stc_bcd = mixed_norm( + evoked_l21, + forward, + cov, + alpha, + loose=loose, + depth=depth, + maxit=300, + tol=1e-8, + active_set_size=10, + weights=stc_dspm, + weights_min=weights_min, + solver="bcd", + ) assert_array_almost_equal(stc_cd.times, evoked_l21.times, 5) assert_array_almost_equal(stc_bcd.times, evoked_l21.times, 5) assert_allclose(stc_cd.data, stc_bcd.data, rtol=1e-3, atol=0.0) @@ -99,20 +126,31 @@ def test_mxne_inverse_standard(forward): with _record_warnings(): # no convergence stc = mixed_norm(evoked_l21, forward, cov, alpha, loose=1, maxit=2) with _record_warnings(): # no convergence - stc_vec = mixed_norm(evoked_l21, forward, cov, alpha, loose=1, maxit=2, - pick_ori='vector') + stc_vec = mixed_norm( + evoked_l21, forward, cov, alpha, loose=1, maxit=2, pick_ori="vector" + ) assert_stcs_equal(stc_vec.magnitude(), stc) - with _record_warnings(), \ - pytest.raises(ValueError, match='pick_ori='): - mixed_norm(evoked_l21, forward, cov, alpha, loose=0, maxit=2, - pick_ori='vector') + with _record_warnings(), pytest.raises(ValueError, match="pick_ori="): + mixed_norm(evoked_l21, forward, cov, alpha, loose=0, maxit=2, pick_ori="vector") with _record_warnings(), catch_logging() as log: # CD - dips = mixed_norm(evoked_l21, forward, cov, alpha, loose=loose, - depth=depth, maxit=300, tol=1e-8, active_set_size=10, - weights=stc_dspm, weights_min=weights_min, - solver='cd', return_as_dipoles=True, verbose=True) - stc_dip = make_stc_from_dipoles(dips, forward['src']) + dips = mixed_norm( + evoked_l21, + forward, + cov, + alpha, + loose=loose, + depth=depth, + maxit=300, + tol=1e-8, + active_set_size=10, + weights=stc_dspm, + weights_min=weights_min, + solver="cd", + return_as_dipoles=True, + verbose=True, + ) + stc_dip = make_stc_from_dipoles(dips, forward["src"]) assert isinstance(dips[0], Dipole) assert stc_dip.subject == "sample" assert_stcs_equal(stc_cd, stc_dip) @@ -120,21 +158,42 @@ def test_mxne_inverse_standard(forward): # Single time point things should match with _record_warnings(), catch_logging() as log: - dips = mixed_norm(evoked_l21.copy().crop(0.081, 0.081), - forward, cov, alpha, loose=loose, - depth=depth, maxit=300, tol=1e-8, active_set_size=10, - weights=stc_dspm, weights_min=weights_min, - solver='cd', return_as_dipoles=True, verbose=True) + dips = mixed_norm( + evoked_l21.copy().crop(0.081, 0.081), + forward, + cov, + alpha, + loose=loose, + depth=depth, + maxit=300, + tol=1e-8, + active_set_size=10, + weights=stc_dspm, + weights_min=weights_min, + solver="cd", + return_as_dipoles=True, + verbose=True, + ) assert_var_exp_log(log.getvalue(), 37.8, 38.0) # 37.9 gof = sum(dip.gof[0] for dip in dips) # these are now partial exp vars assert_allclose(gof, 37.9, atol=0.1) with _record_warnings(), catch_logging() as log: - stc, res = 
mixed_norm(evoked_l21, forward, cov, alpha, loose=loose, - depth=depth, maxit=300, tol=1e-8, - weights=stc_dspm, # gh-6382 - active_set_size=10, return_residual=True, - solver='cd', verbose=True) + stc, res = mixed_norm( + evoked_l21, + forward, + cov, + alpha, + loose=loose, + depth=depth, + maxit=300, + tol=1e-8, + weights=stc_dspm, # gh-6382 + active_set_size=10, + return_residual=True, + solver="cd", + verbose=True, + ) assert_array_almost_equal(stc.times, evoked_l21.times, 5) assert stc.vertices[1][0] in label.vertices assert_var_exp_log(log.getvalue(), 51, 53) # 51.8 @@ -144,9 +203,21 @@ def test_mxne_inverse_standard(forward): # irMxNE tests with _record_warnings(), catch_logging() as log: # CD stc, residual = mixed_norm( - evoked_l21, forward, cov, alpha, n_mxne_iter=5, loose=0.0001, - depth=depth, maxit=300, tol=1e-8, active_set_size=10, - solver='cd', return_residual=True, pick_ori='vector', verbose=True) + evoked_l21, + forward, + cov, + alpha, + n_mxne_iter=5, + loose=0.0001, + depth=depth, + maxit=300, + tol=1e-8, + active_set_size=10, + solver="cd", + return_residual=True, + pick_ori="vector", + verbose=True, + ) assert_array_almost_equal(stc.times, evoked_l21.times, 5) assert stc.vertices[1][0] in label.vertices assert stc.vertices == [[63152], [79017]] @@ -154,33 +225,72 @@ def test_mxne_inverse_standard(forward): assert_stc_res(evoked_l21, stc, forward, residual) # Do with TF-MxNE for test memory savings - alpha = 60. # overall regularization parameter + alpha = 60.0 # overall regularization parameter l1_ratio = 0.01 # temporal regularization proportion - stc, _ = tf_mixed_norm(evoked, forward, cov, - loose=loose, depth=depth, maxit=100, tol=1e-4, - tstep=4, wsize=16, window=0.1, weights=stc_dspm, - weights_min=weights_min, return_residual=True, - alpha=alpha, l1_ratio=l1_ratio) + stc, _ = tf_mixed_norm( + evoked, + forward, + cov, + loose=loose, + depth=depth, + maxit=100, + tol=1e-4, + tstep=4, + wsize=16, + window=0.1, + weights=stc_dspm, + weights_min=weights_min, + return_residual=True, + alpha=alpha, + l1_ratio=l1_ratio, + ) assert_array_almost_equal(stc.times, evoked.times, 5) assert stc.vertices[1][0] in label.vertices # vector stc_nrm = tf_mixed_norm( - evoked, forward, cov, loose=1, depth=depth, maxit=2, tol=1e-4, - tstep=4, wsize=16, window=0.1, weights=stc_dspm, - weights_min=weights_min, alpha=alpha, l1_ratio=l1_ratio) + evoked, + forward, + cov, + loose=1, + depth=depth, + maxit=2, + tol=1e-4, + tstep=4, + wsize=16, + window=0.1, + weights=stc_dspm, + weights_min=weights_min, + alpha=alpha, + l1_ratio=l1_ratio, + ) stc_vec, residual = tf_mixed_norm( - evoked, forward, cov, loose=1, depth=depth, maxit=2, tol=1e-4, - tstep=4, wsize=16, window=0.1, weights=stc_dspm, - weights_min=weights_min, alpha=alpha, l1_ratio=l1_ratio, - pick_ori='vector', return_residual=True) + evoked, + forward, + cov, + loose=1, + depth=depth, + maxit=2, + tol=1e-4, + tstep=4, + wsize=16, + window=0.1, + weights=stc_dspm, + weights_min=weights_min, + alpha=alpha, + l1_ratio=l1_ratio, + pick_ori="vector", + return_residual=True, + ) assert_stcs_equal(stc_vec.magnitude(), stc_nrm) - pytest.raises(ValueError, tf_mixed_norm, evoked, forward, cov, - alpha=101, l1_ratio=0.03) - pytest.raises(ValueError, tf_mixed_norm, evoked, forward, cov, - alpha=50., l1_ratio=1.01) + pytest.raises( + ValueError, tf_mixed_norm, evoked, forward, cov, alpha=101, l1_ratio=0.03 + ) + pytest.raises( + ValueError, tf_mixed_norm, evoked, forward, cov, alpha=50.0, l1_ratio=1.01 + ) @pytest.mark.slowtest @@ -195,123 
+305,169 @@ def test_mxne_vol_sphere(): evoked_l21.crop(tmin=0.081, tmax=0.1) info = evoked.info - sphere = mne.make_sphere_model(r0=(0., 0., 0.), head_radius=0.080) - src = mne.setup_volume_source_space(subject=None, pos=15., mri=None, - sphere=(0.0, 0.0, 0.0, 0.08), - bem=None, mindist=5.0, - exclude=2.0, sphere_units='m') - fwd = mne.make_forward_solution(info, trans=None, src=src, - bem=sphere, eeg=False, meg=True) + sphere = mne.make_sphere_model(r0=(0.0, 0.0, 0.0), head_radius=0.080) + src = mne.setup_volume_source_space( + subject=None, + pos=15.0, + mri=None, + sphere=(0.0, 0.0, 0.0, 0.08), + bem=None, + mindist=5.0, + exclude=2.0, + sphere_units="m", + ) + fwd = mne.make_forward_solution( + info, trans=None, src=src, bem=sphere, eeg=False, meg=True + ) - alpha = 80. + alpha = 80.0 # Computing inverse with restricted orientations should also work, since # we have a discrete source space. - stc = mixed_norm(evoked_l21, fwd, cov, alpha, loose=0.2, - return_residual=False, maxit=3, tol=1e-8, - active_set_size=10) + stc = mixed_norm( + evoked_l21, + fwd, + cov, + alpha, + loose=0.2, + return_residual=False, + maxit=3, + tol=1e-8, + active_set_size=10, + ) assert_array_almost_equal(stc.times, evoked_l21.times, 5) # irMxNE tests with catch_logging() as log: - stc = mixed_norm(evoked_l21, fwd, cov, alpha, - n_mxne_iter=1, maxit=30, tol=1e-8, - active_set_size=10, verbose=True) + stc = mixed_norm( + evoked_l21, + fwd, + cov, + alpha, + n_mxne_iter=1, + maxit=30, + tol=1e-8, + active_set_size=10, + verbose=True, + ) assert isinstance(stc, VolSourceEstimate) assert_array_almost_equal(stc.times, evoked_l21.times, 5) assert_var_exp_log(log.getvalue(), 9, 11) # 10.2 # Compare orientation obtained using fit_dipole and gamma_map # for a simulated evoked containing a single dipole - stc = mne.VolSourceEstimate(50e-9 * np.random.RandomState(42).randn(1, 4), - vertices=[stc.vertices[0][:1]], - tmin=stc.tmin, - tstep=stc.tstep) - evoked_dip = mne.simulation.simulate_evoked(fwd, stc, info, cov, nave=1e9, - use_cps=True) + stc = mne.VolSourceEstimate( + 50e-9 * np.random.RandomState(42).randn(1, 4), + vertices=[stc.vertices[0][:1]], + tmin=stc.tmin, + tstep=stc.tstep, + ) + evoked_dip = mne.simulation.simulate_evoked( + fwd, stc, info, cov, nave=1e9, use_cps=True + ) - dip_mxne = mixed_norm(evoked_dip, fwd, cov, alpha=80, - n_mxne_iter=1, maxit=30, tol=1e-8, - active_set_size=10, return_as_dipoles=True) + dip_mxne = mixed_norm( + evoked_dip, + fwd, + cov, + alpha=80, + n_mxne_iter=1, + maxit=30, + tol=1e-8, + active_set_size=10, + return_as_dipoles=True, + ) amp_max = [np.max(d.amplitude) for d in dip_mxne] dip_mxne = dip_mxne[np.argmax(amp_max)] - assert dip_mxne.pos[0] in src[0]['rr'][stc.vertices[0]] + assert dip_mxne.pos[0] in src[0]["rr"][stc.vertices[0]] dip_fit = mne.fit_dipole(evoked_dip, cov, sphere)[0] assert np.abs(np.dot(dip_fit.ori[0], dip_mxne.ori[0])) > 0.99 dist = 1000 * np.linalg.norm(dip_fit.pos[0] - dip_mxne.pos[0]) - assert dist < 4. # within 4 mm + assert dist < 4.0 # within 4 mm # Do with TF-MxNE for test memory savings - alpha = 60. 
# overall regularization parameter + alpha = 60.0 # overall regularization parameter l1_ratio = 0.01 # temporal regularization proportion - stc, _ = tf_mixed_norm(evoked, fwd, cov, maxit=3, tol=1e-4, - tstep=16, wsize=32, window=0.1, alpha=alpha, - l1_ratio=l1_ratio, return_residual=True) + stc, _ = tf_mixed_norm( + evoked, + fwd, + cov, + maxit=3, + tol=1e-4, + tstep=16, + wsize=32, + window=0.1, + alpha=alpha, + l1_ratio=l1_ratio, + return_residual=True, + ) assert isinstance(stc, VolSourceEstimate) assert_array_almost_equal(stc.times, evoked.times, 5) -@pytest.mark.parametrize('mod', ( - None, 'mult', 'augment', 'sign', 'zero', 'less')) +@pytest.mark.parametrize("mod", (None, "mult", "augment", "sign", "zero", "less")) def test_split_gof_basic(mod): """Test splitting the goodness of fit.""" # first a trivial case - gain = np.array([[0., 1., 1.], [1., 1., 0.]]).T + gain = np.array([[0.0, 1.0, 1.0], [1.0, 1.0, 0.0]]).T M = np.ones((3, 1)) X = np.ones((2, 1)) M_est = gain @ X - assert_allclose(M_est, np.array([[1., 2., 1.]]).T) # a reasonable estimate - if mod == 'mult': - gain *= [1., -0.5] + assert_allclose(M_est, np.array([[1.0, 2.0, 1.0]]).T) # a reasonable estimate + if mod == "mult": + gain *= [1.0, -0.5] X[1] *= -2 - elif mod == 'augment': + elif mod == "augment": gain = np.concatenate((gain, np.zeros((3, 1))), axis=1) - X = np.concatenate((X, [[1.]])) - elif mod == 'sign': + X = np.concatenate((X, [[1.0]])) + elif mod == "sign": gain[1] *= -1 M[1] *= -1 M_est[1] *= -1 - elif mod in ('zero', 'less'): - gain = np.array([[1, 1., 1.], [1., 1., 1.]]).T - if mod == 'zero': - X[:, 0] = [1., 0.] + elif mod in ("zero", "less"): + gain = np.array([[1, 1.0, 1.0], [1.0, 1.0, 1.0]]).T + if mod == "zero": + X[:, 0] = [1.0, 0.0] else: - X[:, 0] = [1., 0.5] + X[:, 0] = [1.0, 0.5] M_est = gain @ X else: assert mod is None res = M - M_est - gof = 100 * (1. 
- (res * res).sum() / (M * M).sum()) + gof = 100 * (1.0 - (res * res).sum() / (M * M).sum()) gof_split = _split_gof(M, X, gain) assert_allclose(gof_split.sum(), gof) want = gof_split[[0, 0]] - if mod == 'augment': + if mod == "augment": want = np.concatenate((want, [[0]])) - if mod in ('mult', 'less'): + if mod in ("mult", "less"): assert_array_less(gof_split[1], gof_split[0]) - elif mod == 'zero': + elif mod == "zero": assert_allclose(gof_split[0], gof_split.sum(0)) - assert_allclose(gof_split[1], 0., atol=1e-6) + assert_allclose(gof_split[1], 0.0, atol=1e-6) else: assert_allclose(gof_split, want, atol=1e-12) @testing.requires_testing_data -@pytest.mark.parametrize('idx, weights', [ - # empirically determined approximately orthogonal columns: 0, 15157, 19448 - ([0], [1]), - ([0, 15157], [1, 1]), - ([0, 15157], [1, 3]), - ([0, 15157], [5, -1]), - ([0, 15157, 19448], [1, 1, 1]), - ([0, 15157, 19448], [1e-2, 1, 5]), -]) +@pytest.mark.parametrize( + "idx, weights", + [ + # empirically determined approximately orthogonal columns: 0, 15157, 19448 + ([0], [1]), + ([0, 15157], [1, 1]), + ([0, 15157], [1, 3]), + ([0, 15157], [5, -1]), + ([0, 15157, 19448], [1, 1, 1]), + ([0, 15157, 19448], [1e-2, 1, 5]), + ], +) def test_split_gof_meg(forward, idx, weights): """Test GOF splitting on MEG data.""" - gain = forward['sol']['data'][:, idx] + gain = forward["sol"]["data"][:, idx] # close to orthogonal norms = np.linalg.norm(gain, axis=0) triu = np.triu_indices(len(idx), 1) @@ -320,7 +476,7 @@ def test_split_gof_meg(forward, idx, weights): # first, split across time (one dipole per time point) M = gain * weights gof_split = _split_gof(M, np.diag(weights), gain) - assert_allclose(gof_split.sum(0), 100., atol=1e-5) # all sum to 100 + assert_allclose(gof_split.sum(0), 100.0, atol=1e-5) # all sum to 100 assert_allclose(gof_split, 100 * np.eye(len(weights)), atol=1) # loc # next, summed to a single time point (all dipoles active at one time pt) weights = np.array(weights)[:, np.newaxis] @@ -333,33 +489,36 @@ def test_split_gof_meg(forward, idx, weights): assert_allclose(gof_split.sum(), 100, rtol=1e-5) -@pytest.mark.parametrize('n_sensors, n_dipoles, n_times', [ - (10, 15, 7), - (20, 60, 20), -]) -@pytest.mark.parametrize('nnz', [2, 4]) -@pytest.mark.parametrize('corr', [0.75]) -@pytest.mark.parametrize('n_orient', [1, 3]) -def test_mxne_inverse_sure_synthetic(n_sensors, n_dipoles, n_times, nnz, corr, - n_orient, snr=4): +@pytest.mark.parametrize( + "n_sensors, n_dipoles, n_times", + [ + (10, 15, 7), + (20, 60, 20), + ], +) +@pytest.mark.parametrize("nnz", [2, 4]) +@pytest.mark.parametrize("corr", [0.75]) +@pytest.mark.parametrize("n_orient", [1, 3]) +def test_mxne_inverse_sure_synthetic( + n_sensors, n_dipoles, n_times, nnz, corr, n_orient, snr=4 +): """Tests SURE criterion for automatic alpha selection on synthetic data.""" rng = np.random.RandomState(0) - sigma = np.sqrt(1 - corr ** 2) + sigma = np.sqrt(1 - corr**2) U = rng.randn(n_sensors) # generate gain matrix - G = np.empty([n_sensors, n_dipoles], order='F') + G = np.empty([n_sensors, n_dipoles], order="F") G[:, :n_orient] = np.expand_dims(U, axis=-1) n_dip_per_pos = n_dipoles // n_orient for j in range(1, n_dip_per_pos): U *= corr U += sigma * rng.randn(n_sensors) - G[:, j * n_orient:(j + 1) * n_orient] = np.expand_dims(U, axis=-1) + G[:, j * n_orient : (j + 1) * n_orient] = np.expand_dims(U, axis=-1) # generate coefficient matrix support = rng.choice(n_dip_per_pos, nnz, replace=False) X = np.zeros((n_dipoles, n_times)) for k in support: - X[k * 
n_orient:(k + 1) * n_orient, :] = rng.normal( - size=(n_orient, n_times)) + X[k * n_orient : (k + 1) * n_orient, :] = rng.normal(size=(n_orient, n_times)) # generate measurement matrix M = G @ X noise = rng.randn(n_sensors, n_times) @@ -368,12 +527,22 @@ def test_mxne_inverse_sure_synthetic(n_sensors, n_dipoles, n_times, nnz, corr, # inverse modeling with sure alpha_max = norm_l2inf(np.dot(G.T, M), n_orient, copy=False) alpha_grid = np.geomspace(alpha_max, alpha_max / 10, num=15) - _, active_set, _ = _compute_mxne_sure(M, G, alpha_grid, sigma=sigma, - n_mxne_iter=5, maxit=3000, tol=1e-4, - n_orient=n_orient, - active_set_size=10, debias=True, - solver="auto", dgap_freq=10, - random_state=0, verbose=False) + _, active_set, _ = _compute_mxne_sure( + M, + G, + alpha_grid, + sigma=sigma, + n_mxne_iter=5, + maxit=3000, + tol=1e-4, + n_orient=n_orient, + active_set_size=10, + debias=True, + solver="auto", + dgap_freq=10, + random_state=0, + verbose=False, + ) assert np.count_nonzero(active_set, axis=-1) == n_orient * nnz @@ -381,38 +550,45 @@ def test_mxne_inverse_sure_synthetic(n_sensors, n_dipoles, n_times, nnz, corr, @testing.requires_testing_data def test_mxne_inverse_sure(): """Tests SURE criterion for automatic alpha selection on MEG data.""" + def data_fun(times): data = np.zeros(times.shape) data[times >= 0] = 50e-9 return data + n_dipoles = 2 raw = mne.io.read_raw_fif(fname_raw) info = mne.io.read_info(fname_data) with info._unlock(): - info['projs'] = [] + info["projs"] = [] noise_cov = mne.make_ad_hoc_cov(info) - label_names = ['Aud-lh', 'Aud-rh'] + label_names = ["Aud-lh", "Aud-rh"] labels = [ - mne.read_label(data_path / 'MEG' / 'sample' / 'labels' / f'{ln}.label') - for ln in label_names] + mne.read_label(data_path / "MEG" / "sample" / "labels" / f"{ln}.label") + for ln in label_names + ] fname_fwd = ( - data_path - / "MEG" - / "sample" - / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" + data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" ) forward = mne.read_forward_solution(fname_fwd) - forward = mne.pick_types_forward(forward, meg="grad", eeg=False, - exclude=raw.info['bads']) - times = np.arange(100, dtype=np.float64) / raw.info['sfreq'] - 0.1 - stc = simulate_sparse_stc(forward['src'], n_dipoles=n_dipoles, times=times, - random_state=1, labels=labels, data_fun=data_fun) + forward = mne.pick_types_forward( + forward, meg="grad", eeg=False, exclude=raw.info["bads"] + ) + times = np.arange(100, dtype=np.float64) / raw.info["sfreq"] - 0.1 + stc = simulate_sparse_stc( + forward["src"], + n_dipoles=n_dipoles, + times=times, + random_state=1, + labels=labels, + data_fun=data_fun, + ) nave = 30 - evoked = simulate_evoked(forward, stc, info, noise_cov, nave=nave, - use_cps=False, iir_filter=None) + evoked = simulate_evoked( + forward, stc, info, noise_cov, nave=nave, use_cps=False, iir_filter=None + ) evoked = evoked.crop(tmin=0, tmax=10e-3) - stc_ = mixed_norm(evoked, forward, noise_cov, loose=0.9, n_mxne_iter=5, - depth=0.9) + stc_ = mixed_norm(evoked, forward, noise_cov, loose=0.9, n_mxne_iter=5, depth=0.9) assert_array_equal(stc_.vertices, stc.vertices) @@ -423,19 +599,17 @@ def test_mxne_inverse_empty(): evoked = read_evokeds(fname_data, condition=0, baseline=(None, 0)) evoked.pick("grad", exclude="bads") fname_fwd = ( - data_path - / "MEG" - / "sample" - / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" + data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" ) forward = mne.read_forward_solution(fname_fwd) - forward = 
mne.pick_types_forward(forward, meg="grad", eeg=False, - exclude=evoked.info['bads']) + forward = mne.pick_types_forward( + forward, meg="grad", eeg=False, exclude=evoked.info["bads"] + ) cov = read_cov(fname_cov) - with pytest.warns(RuntimeWarning, match='too big'): + with pytest.warns(RuntimeWarning, match="too big"): stc, residual = mixed_norm( - evoked, forward, cov, n_mxne_iter=3, alpha=99, - return_residual=True) + evoked, forward, cov, n_mxne_iter=3, alpha=99, return_residual=True + ) assert stc.data.size == 0 assert stc.vertices[0].size == 0 assert stc.vertices[1].size == 0 diff --git a/mne/inverse_sparse/tests/test_mxne_optim.py b/mne/inverse_sparse/tests/test_mxne_optim.py index c3288528400..b0779c01e7c 100644 --- a/mne/inverse_sparse/tests/test_mxne_optim.py +++ b/mne/inverse_sparse/tests/test_mxne_optim.py @@ -5,15 +5,24 @@ import pytest import numpy as np -from numpy.testing import (assert_array_equal, assert_array_almost_equal, - assert_allclose, assert_array_less) - -from mne.inverse_sparse.mxne_optim import (mixed_norm_solver, - tf_mixed_norm_solver, - iterative_mixed_norm_solver, - iterative_tf_mixed_norm_solver, - norm_epsilon_inf, norm_epsilon, - _Phi, _PhiT, dgap_l21l1) +from numpy.testing import ( + assert_array_equal, + assert_array_almost_equal, + assert_allclose, + assert_array_less, +) + +from mne.inverse_sparse.mxne_optim import ( + mixed_norm_solver, + tf_mixed_norm_solver, + iterative_mixed_norm_solver, + iterative_tf_mixed_norm_solver, + norm_epsilon_inf, + norm_epsilon, + _Phi, + _PhiT, + dgap_l21l1, +) from mne.time_frequency._stft import stft_norm2 from mne.utils import catch_logging, _record_warnings @@ -37,7 +46,7 @@ def _generate_tf_data(): def test_l21_mxne(): """Test convergence of MxNE solver.""" - n, p, t, alpha = 30, 40, 20, 1. 
+ n, p, t, alpha = 30, 40, 20, 1.0 rng = np.random.RandomState(0) G = rng.randn(n, p) G /= np.std(G, axis=0)[None, :] @@ -49,47 +58,61 @@ def test_l21_mxne(): args = (M, G, alpha, 1000, 1e-8) with _record_warnings(): # CD X_hat_cd, active_set, _, gap_cd = mixed_norm_solver( - *args, active_set_size=None, - debias=True, solver='cd', return_gap=True) + *args, active_set_size=None, debias=True, solver="cd", return_gap=True + ) assert_array_less(gap_cd, 1e-8) assert_array_equal(np.where(active_set)[0], [0, 4]) with _record_warnings(): # CD X_hat_bcd, active_set, E, gap_bcd = mixed_norm_solver( - M, G, alpha, maxit=1000, tol=1e-8, active_set_size=None, - debias=True, solver='bcd', return_gap=True) + M, + G, + alpha, + maxit=1000, + tol=1e-8, + active_set_size=None, + debias=True, + solver="bcd", + return_gap=True, + ) assert_array_less(gap_bcd, 9.6e-9) assert_array_equal(np.where(active_set)[0], [0, 4]) assert_allclose(X_hat_bcd, X_hat_cd, rtol=1e-2) with _record_warnings(): # CD X_hat_cd, active_set, _ = mixed_norm_solver( - *args, active_set_size=2, debias=True, solver='cd') + *args, active_set_size=2, debias=True, solver="cd" + ) assert_array_equal(np.where(active_set)[0], [0, 4]) with _record_warnings(): # CD X_hat_bcd, active_set, _ = mixed_norm_solver( - *args, active_set_size=2, debias=True, solver='bcd') + *args, active_set_size=2, debias=True, solver="bcd" + ) assert_array_equal(np.where(active_set)[0], [0, 4]) assert_allclose(X_hat_bcd, X_hat_cd, rtol=1e-2) with _record_warnings(): # CD X_hat_bcd, active_set, _ = mixed_norm_solver( - *args, active_set_size=2, debias=True, n_orient=2, solver='bcd') + *args, active_set_size=2, debias=True, n_orient=2, solver="bcd" + ) assert_array_equal(np.where(active_set)[0], [0, 1, 4, 5]) # suppress a coordinate-descent warning here - with pytest.warns(RuntimeWarning, match='descent'): + with pytest.warns(RuntimeWarning, match="descent"): X_hat_cd, active_set, _ = mixed_norm_solver( - *args, active_set_size=2, debias=True, n_orient=2, solver='cd') + *args, active_set_size=2, debias=True, n_orient=2, solver="cd" + ) assert_array_equal(np.where(active_set)[0], [0, 1, 4, 5]) assert_allclose(X_hat_bcd, X_hat_cd, rtol=1e-2) with _record_warnings(): # CD X_hat_bcd, active_set, _ = mixed_norm_solver( - *args, active_set_size=2, debias=True, n_orient=5, solver='bcd') + *args, active_set_size=2, debias=True, n_orient=5, solver="bcd" + ) assert_array_equal(np.where(active_set)[0], [0, 1, 2, 3, 4]) - with pytest.warns(RuntimeWarning, match='descent'): + with pytest.warns(RuntimeWarning, match="descent"): X_hat_cd, active_set, _ = mixed_norm_solver( - *args, active_set_size=2, debias=True, n_orient=5, solver='cd') + *args, active_set_size=2, debias=True, n_orient=5, solver="cd" + ) assert_array_equal(np.where(active_set)[0], [0, 1, 2, 3, 4]) assert_allclose(X_hat_bcd, X_hat_cd) @@ -98,7 +121,7 @@ def test_l21_mxne(): @pytest.mark.slowtest def test_non_convergence(): """Test non-convergence of MxNE solver to catch unexpected bugs.""" - n, p, t, alpha = 30, 40, 20, 1. + n, p, t, alpha = 30, 40, 20, 1.0 rng = np.random.RandomState(0) G = rng.randn(n, p) G /= np.std(G, axis=0)[None, :] @@ -111,23 +134,34 @@ def test_non_convergence(): # In case of non-convegence, we test that no error is returned. 
args = (M, G, alpha, 1, 1e-12) with catch_logging() as log: - mixed_norm_solver(*args, active_set_size=None, debias=True, - solver='bcd', verbose=True) + mixed_norm_solver( + *args, active_set_size=None, debias=True, solver="bcd", verbose=True + ) log = log.getvalue() - assert 'Convergence reached' not in log + assert "Convergence reached" not in log def test_tf_mxne(): """Test convergence of TF-MxNE solver.""" - alpha_space = 10. - alpha_time = 5. + alpha_space = 10.0 + alpha_time = 5.0 M, G, active_set = _generate_tf_data() with _record_warnings(): # CD X_hat_tf, active_set_hat_tf, E, gap_tfmxne = tf_mixed_norm_solver( - M, G, alpha_space, alpha_time, maxit=200, tol=1e-8, verbose=True, - n_orient=1, tstep=4, wsize=32, return_gap=True) + M, + G, + alpha_space, + alpha_time, + maxit=200, + tol=1e-8, + verbose=True, + n_orient=1, + tstep=4, + wsize=32, + return_gap=True, + ) assert_array_less(gap_tfmxne, 1e-8) assert_array_equal(np.where(active_set_hat_tf)[0], active_set) @@ -144,35 +178,40 @@ def test_norm_epsilon(): Y = np.zeros((n_steps * n_freqs).item()) l1_ratio = 0.03 - assert_allclose(norm_epsilon(Y, l1_ratio, phi), 0.) + assert_allclose(norm_epsilon(Y, l1_ratio, phi), 0.0) - Y[0] = 2. + Y[0] = 2.0 assert_allclose(norm_epsilon(Y, l1_ratio, phi), np.max(Y)) - l1_ratio = 1. + l1_ratio = 1.0 assert_allclose(norm_epsilon(Y, l1_ratio, phi), np.max(Y)) # dummy value without random: Y = np.arange((n_steps * n_freqs).item()) l1_ratio = 0.0 - assert_allclose(norm_epsilon(Y, l1_ratio, phi) ** 2, - stft_norm2(Y.reshape(-1, n_freqs[0], n_steps[0]))) + assert_allclose( + norm_epsilon(Y, l1_ratio, phi) ** 2, + stft_norm2(Y.reshape(-1, n_freqs[0], n_steps[0])), + ) l1_ratio = 0.03 # test that vanilla epsilon norm = weights equal to 1 w_time = np.ones(n_coefs[0]) Y = np.abs(np.random.randn(n_coefs[0])) - assert_allclose(norm_epsilon(Y, l1_ratio, phi), - norm_epsilon(Y, l1_ratio, phi, w_time=w_time)) + assert_allclose( + norm_epsilon(Y, l1_ratio, phi), norm_epsilon(Y, l1_ratio, phi, w_time=w_time) + ) # scaling w_time and w_space by the same amount should divide # epsilon norm by the same amount Y = np.arange(n_coefs.item()) + 1 - mult = 2. + mult = 2.0 assert_allclose( - norm_epsilon(Y, l1_ratio, phi, w_space=1, - w_time=np.ones(n_coefs.item())) / mult, - norm_epsilon(Y, l1_ratio, phi, w_space=mult, - w_time=mult * np.ones(n_coefs.item()))) + norm_epsilon(Y, l1_ratio, phi, w_space=1, w_time=np.ones(n_coefs.item())) + / mult, + norm_epsilon( + Y, l1_ratio, phi, w_space=mult, w_time=mult * np.ones(n_coefs.item()) + ), + ) @pytest.mark.slowtest # slow-ish on Travis OSX @@ -192,30 +231,59 @@ def test_dgapl21l1(): for l1_ratio in [0.05, 0.1]: alpha_max = norm_epsilon_inf(G, M, phi, l1_ratio, n_orient) - alpha_space = (1. 
- l1_ratio) * alpha_max + alpha_space = (1.0 - l1_ratio) * alpha_max alpha_time = l1_ratio * alpha_max Z = np.zeros([n_sources, phi.n_coefs.sum()]) # for alpha = alpha_max, Z = 0 is the solution so the dgap is 0 - gap = dgap_l21l1(M, G, Z, np.ones(n_sources, dtype=bool), - alpha_space, alpha_time, phi, phiT, - n_orient, -np.inf)[0] - - assert_allclose(0., gap) + gap = dgap_l21l1( + M, + G, + Z, + np.ones(n_sources, dtype=bool), + alpha_space, + alpha_time, + phi, + phiT, + n_orient, + -np.inf, + )[0] + + assert_allclose(0.0, gap) # check that solution for alpha smaller than alpha_max is non 0: X_hat_tf, active_set_hat_tf, E, gap = tf_mixed_norm_solver( - M, G, alpha_space / 1.01, alpha_time / 1.01, maxit=200, tol=1e-8, - verbose=True, debias=False, n_orient=n_orient, tstep=tstep, - wsize=wsize, return_gap=True) + M, + G, + alpha_space / 1.01, + alpha_time / 1.01, + maxit=200, + tol=1e-8, + verbose=True, + debias=False, + n_orient=n_orient, + tstep=tstep, + wsize=wsize, + return_gap=True, + ) # allow possible small numerical errors (negative gap) assert_array_less(-1e-10, gap) assert_array_less(gap, 1e-8) assert_array_less(1, len(active_set_hat_tf)) X_hat_tf, active_set_hat_tf, E, gap = tf_mixed_norm_solver( - M, G, alpha_space / 5., alpha_time / 5., maxit=200, tol=1e-8, - verbose=True, debias=False, n_orient=n_orient, tstep=tstep, - wsize=wsize, return_gap=True) + M, + G, + alpha_space / 5.0, + alpha_time / 5.0, + maxit=200, + tol=1e-8, + verbose=True, + debias=False, + n_orient=n_orient, + tstep=tstep, + wsize=wsize, + return_gap=True, + ) assert_array_less(-1e-10, gap) assert_array_less(gap, 1e-8) assert_array_less(1, len(active_set_hat_tf)) @@ -223,19 +291,37 @@ def test_dgapl21l1(): def test_tf_mxne_vs_mxne(): """Test equivalence of TF-MxNE (with alpha_time=0) and MxNE.""" - alpha_space = 60. - alpha_time = 0. 
+ alpha_space = 60.0 + alpha_time = 0.0 M, G, active_set = _generate_tf_data() X_hat_tf, active_set_hat_tf, E = tf_mixed_norm_solver( - M, G, alpha_space, alpha_time, maxit=200, tol=1e-8, - verbose=True, debias=False, n_orient=1, tstep=4, wsize=32) + M, + G, + alpha_space, + alpha_time, + maxit=200, + tol=1e-8, + verbose=True, + debias=False, + n_orient=1, + tstep=4, + wsize=32, + ) # Also run L21 and check that we get the same X_hat_l21, _, _ = mixed_norm_solver( - M, G, alpha_space, maxit=200, tol=1e-8, verbose=False, n_orient=1, - active_set_size=None, debias=False) + M, + G, + alpha_space, + maxit=200, + tol=1e-8, + verbose=False, + n_orient=1, + active_set_size=None, + debias=False, + ) assert_allclose(X_hat_tf, X_hat_l21, rtol=1e-1) @@ -254,47 +340,107 @@ def test_iterative_reweighted_mxne(): with _record_warnings(): # CD X_hat_l21, _, _ = mixed_norm_solver( - M, G, alpha, maxit=1000, tol=1e-8, verbose=False, n_orient=1, - active_set_size=None, debias=False, solver='bcd') + M, + G, + alpha, + maxit=1000, + tol=1e-8, + verbose=False, + n_orient=1, + active_set_size=None, + debias=False, + solver="bcd", + ) with _record_warnings(): # CD X_hat_bcd, active_set, _ = iterative_mixed_norm_solver( - M, G, alpha, 1, maxit=1000, tol=1e-8, active_set_size=None, - debias=False, solver='bcd') + M, + G, + alpha, + 1, + maxit=1000, + tol=1e-8, + active_set_size=None, + debias=False, + solver="bcd", + ) assert_allclose(X_hat_bcd, X_hat_l21, rtol=1e-3) with _record_warnings(): # CD X_hat_bcd, active_set, _ = iterative_mixed_norm_solver( - M, G, alpha, 5, maxit=1000, tol=1e-8, active_set_size=2, - debias=True, solver='bcd') + M, + G, + alpha, + 5, + maxit=1000, + tol=1e-8, + active_set_size=2, + debias=True, + solver="bcd", + ) assert_array_equal(np.where(active_set)[0], [0, 4]) with _record_warnings(): # CD X_hat_cd, active_set, _ = iterative_mixed_norm_solver( - M, G, alpha, 5, maxit=1000, tol=1e-8, active_set_size=None, - debias=True, solver='cd') + M, + G, + alpha, + 5, + maxit=1000, + tol=1e-8, + active_set_size=None, + debias=True, + solver="cd", + ) assert_array_equal(np.where(active_set)[0], [0, 4]) assert_array_almost_equal(X_hat_bcd, X_hat_cd, 5) with _record_warnings(): # CD X_hat_bcd, active_set, _ = iterative_mixed_norm_solver( - M, G, alpha, 5, maxit=1000, tol=1e-8, active_set_size=2, - debias=True, n_orient=2, solver='bcd') + M, + G, + alpha, + 5, + maxit=1000, + tol=1e-8, + active_set_size=2, + debias=True, + n_orient=2, + solver="bcd", + ) assert_array_equal(np.where(active_set)[0], [0, 1, 4, 5]) # suppress a coordinate-descent warning here - with pytest.warns(RuntimeWarning, match='descent'): + with pytest.warns(RuntimeWarning, match="descent"): X_hat_cd, active_set, _ = iterative_mixed_norm_solver( - M, G, alpha, 5, maxit=1000, tol=1e-8, active_set_size=2, - debias=True, n_orient=2, solver='cd') + M, + G, + alpha, + 5, + maxit=1000, + tol=1e-8, + active_set_size=2, + debias=True, + n_orient=2, + solver="cd", + ) assert_array_equal(np.where(active_set)[0], [0, 1, 4, 5]) assert_allclose(X_hat_bcd, X_hat_cd) X_hat_bcd, active_set, _ = iterative_mixed_norm_solver( - M, G, alpha, 5, maxit=1000, tol=1e-8, active_set_size=2, debias=True, - n_orient=5) + M, G, alpha, 5, maxit=1000, tol=1e-8, active_set_size=2, debias=True, n_orient=5 + ) assert_array_equal(np.where(active_set)[0], [0, 1, 2, 3, 4]) - with pytest.warns(RuntimeWarning, match='descent'): + with pytest.warns(RuntimeWarning, match="descent"): X_hat_cd, active_set, _ = iterative_mixed_norm_solver( - M, G, alpha, 5, maxit=1000, 
tol=1e-8, active_set_size=2, - debias=True, n_orient=5, solver='cd') + M, + G, + alpha, + 5, + maxit=1000, + tol=1e-8, + active_set_size=2, + debias=True, + n_orient=5, + solver="cd", + ) assert_array_equal(np.where(active_set)[0], [0, 1, 2, 3, 4]) assert_allclose(X_hat_bcd, X_hat_cd) @@ -303,27 +449,69 @@ def test_iterative_reweighted_mxne(): def test_iterative_reweighted_tfmxne(): """Test convergence of irTF-MxNE solver.""" M, G, true_active_set = _generate_tf_data() - alpha_space = 38. + alpha_space = 38.0 alpha_time = 0.5 tstep, wsize = [4, 2], [64, 16] X_hat_tf, _, _ = tf_mixed_norm_solver( - M, G, alpha_space, alpha_time, maxit=1000, tol=1e-4, wsize=wsize, - tstep=tstep, verbose=False, n_orient=1, debias=False) + M, + G, + alpha_space, + alpha_time, + maxit=1000, + tol=1e-4, + wsize=wsize, + tstep=tstep, + verbose=False, + n_orient=1, + debias=False, + ) X_hat_bcd, active_set, _ = iterative_tf_mixed_norm_solver( - M, G, alpha_space, alpha_time, 1, wsize=wsize, tstep=tstep, - maxit=1000, tol=1e-4, debias=False, verbose=False) + M, + G, + alpha_space, + alpha_time, + 1, + wsize=wsize, + tstep=tstep, + maxit=1000, + tol=1e-4, + debias=False, + verbose=False, + ) assert_allclose(X_hat_tf, X_hat_bcd, rtol=1e-3) assert_array_equal(np.where(active_set)[0], true_active_set) - alpha_space = 50. + alpha_space = 50.0 X_hat_bcd, active_set, _ = iterative_tf_mixed_norm_solver( - M, G, alpha_space, alpha_time, 3, wsize=wsize, tstep=tstep, - n_orient=5, maxit=1000, tol=1e-4, debias=False, verbose=False) + M, + G, + alpha_space, + alpha_time, + 3, + wsize=wsize, + tstep=tstep, + n_orient=5, + maxit=1000, + tol=1e-4, + debias=False, + verbose=False, + ) assert_array_equal(np.where(active_set)[0], [0, 1, 2, 3, 4]) - alpha_space = 40. + alpha_space = 40.0 X_hat_bcd, active_set, _ = iterative_tf_mixed_norm_solver( - M, G, alpha_space, alpha_time, 2, wsize=wsize, tstep=tstep, - n_orient=2, maxit=1000, tol=1e-4, debias=False, verbose=False) + M, + G, + alpha_space, + alpha_time, + 2, + wsize=wsize, + tstep=tstep, + n_orient=2, + maxit=1000, + tol=1e-4, + debias=False, + verbose=False, + ) assert_array_equal(np.where(active_set)[0], [0, 1, 4, 5]) diff --git a/mne/io/__init__.py b/mne/io/__init__.py index 8af9495499d..e51df7c9183 100644 --- a/mne/io/__init__.py +++ b/mne/io/__init__.py @@ -6,9 +6,18 @@ # License: BSD-3-Clause from .open import fiff_open, show_fiff, _fiff_get_fid -from .meas_info import (read_fiducials, write_fiducials, read_info, write_info, - _empty_info, _merge_info, _force_update_info, Info, - anonymize_info, _writing_info_hdf5) +from .meas_info import ( + read_fiducials, + write_fiducials, + read_info, + write_info, + _empty_info, + _merge_info, + _force_update_info, + Info, + anonymize_info, + _writing_info_hdf5, +) from .proj import make_eeg_average_ref_proj, Projection from .tag import _loc_to_coil_trans, _coil_trans_to_loc, _loc_to_eeg_loc @@ -56,8 +65,7 @@ from .boxy import read_raw_boxy from .snirf import read_raw_snirf from .persyst import read_raw_persyst -from .fieldtrip import (read_raw_fieldtrip, read_epochs_fieldtrip, - read_evoked_fieldtrip) +from .fieldtrip import read_raw_fieldtrip, read_epochs_fieldtrip, read_evoked_fieldtrip from .nihon import read_raw_nihon from ._read_raw import read_raw from .eyelink import read_raw_eyelink @@ -67,5 +75,4 @@ from .fiff import Raw from .fiff import Raw as RawFIF from .base import concatenate_raws, match_channel_orders -from .reference import (set_eeg_reference, set_bipolar_reference, - add_reference_channels) +from .reference import 
set_eeg_reference, set_bipolar_reference, add_reference_channels diff --git a/mne/io/_digitization.py b/mne/io/_digitization.py index c6baf9f507b..020200b634b 100644 --- a/mne/io/_digitization.py +++ b/mne/io/_digitization.py @@ -19,28 +19,36 @@ from .constants import FIFF, _coord_frame_named from .tree import dir_tree_find from .tag import read_tag -from .write import (start_and_end_file, write_dig_points) - -from ..transforms import (apply_trans, Transform, - get_ras_to_neuromag_trans, combine_transforms, - invert_transform, _to_const, _str_to_frame, - _coord_frame_name) +from .write import start_and_end_file, write_dig_points + +from ..transforms import ( + apply_trans, + Transform, + get_ras_to_neuromag_trans, + combine_transforms, + invert_transform, + _to_const, + _str_to_frame, + _coord_frame_name, +) from .. import __version__ _dig_kind_dict = { - 'cardinal': FIFF.FIFFV_POINT_CARDINAL, - 'hpi': FIFF.FIFFV_POINT_HPI, - 'eeg': FIFF.FIFFV_POINT_EEG, - 'extra': FIFF.FIFFV_POINT_EXTRA, + "cardinal": FIFF.FIFFV_POINT_CARDINAL, + "hpi": FIFF.FIFFV_POINT_HPI, + "eeg": FIFF.FIFFV_POINT_EEG, + "extra": FIFF.FIFFV_POINT_EXTRA, } _dig_kind_ints = tuple(sorted(_dig_kind_dict.values())) -_dig_kind_proper = {'cardinal': 'Cardinal', - 'hpi': 'HPI', - 'eeg': 'EEG', - 'extra': 'Extra', - 'unknown': 'Unknown'} +_dig_kind_proper = { + "cardinal": "Cardinal", + "hpi": "HPI", + "eeg": "EEG", + "extra": "Extra", + "unknown": "Unknown", +} _dig_kind_rev = {val: key for key, val in _dig_kind_dict.items()} -_cardinal_kind_rev = {1: 'LPA', 2: 'Nasion', 3: 'RPA', 4: 'Inion'} +_cardinal_kind_rev = {1: "LPA", 2: "Nasion", 3: "RPA", 4: "Inion"} def _format_dig_points(dig, enforce_order=False): @@ -56,8 +64,8 @@ def _format_dig_points(dig, enforce_order=False): # use a heap to enforce order on FIDS, EEG, Extra for idx, digpoint in enumerate(dig): - ident = digpoint['ident'] - kind = digpoint['kind'] + ident = digpoint["ident"] + kind = digpoint["kind"] # push onto heap based on 'ident' (for the order) for # each of the possible DigPoint 'kind's @@ -80,9 +88,13 @@ def _format_dig_points(dig, enforce_order=False): eeg_digpoints.sort() extra_digpoints.sort(), head_digpoints.sort() new_dig = [] - for idx, d in enumerate(fids_digpoints + hpi_digpoints + - extra_digpoints + eeg_digpoints + - head_digpoints): + for idx, d in enumerate( + fids_digpoints + + hpi_digpoints + + extra_digpoints + + eeg_digpoints + + head_digpoints + ): new_dig.append(d[-1]) dig = new_dig @@ -90,12 +102,12 @@ def _format_dig_points(dig, enforce_order=False): def _get_dig_eeg(dig): - return [d for d in dig if d['kind'] == FIFF.FIFFV_POINT_EEG] + return [d for d in dig if d["kind"] == FIFF.FIFFV_POINT_EEG] def _count_points_by_type(dig): """Get the number of points of each type.""" - occurrences = Counter([d['kind'] for d in dig]) + occurrences = Counter([d["kind"] for d in dig]) return dict( fid=occurrences[FIFF.FIFFV_POINT_CARDINAL], hpi=occurrences[FIFF.FIFFV_POINT_HPI], @@ -104,7 +116,7 @@ def _count_points_by_type(dig): ) -_dig_keys = {'kind', 'ident', 'r', 'coord_frame'} +_dig_keys = {"kind", "ident", "r", "coord_frame"} class DigPoint(dict): @@ -129,27 +141,28 @@ class DigPoint(dict): """ def __repr__(self): # noqa: D105 - if self['kind'] == FIFF.FIFFV_POINT_CARDINAL: - id_ = _cardinal_kind_rev.get(self['ident'], 'Unknown cardinal') + if self["kind"] == FIFF.FIFFV_POINT_CARDINAL: + id_ = _cardinal_kind_rev.get(self["ident"], "Unknown cardinal") else: - id_ = _dig_kind_proper[ - _dig_kind_rev.get(self['kind'], 'unknown')] - id_ = ('%s 
#%s' % (id_, self['ident'])) + id_ = _dig_kind_proper[_dig_kind_rev.get(self["kind"], "unknown")] + id_ = "%s #%s" % (id_, self["ident"]) id_ = id_.rjust(10) - cf = _coord_frame_name(self['coord_frame']) - if 'voxel' in cf: - pos = ('(%0.1f, %0.1f, %0.1f)' % tuple(self['r'])).ljust(25) + cf = _coord_frame_name(self["coord_frame"]) + if "voxel" in cf: + pos = ("(%0.1f, %0.1f, %0.1f)" % tuple(self["r"])).ljust(25) else: - pos = ('(%0.1f, %0.1f, %0.1f) mm' % - tuple(1000 * self['r'])).ljust(25) - return ('' % (id_, pos, cf)) + pos = ("(%0.1f, %0.1f, %0.1f) mm" % tuple(1000 * self["r"])).ljust(25) + return "" % (id_, pos, cf) # speed up info copy by only deep copying the mutable item def __deepcopy__(self, memodict): """Make a deepcopy.""" return DigPoint( - kind=self['kind'], r=self['r'].copy(), - ident=self['ident'], coord_frame=self['coord_frame']) + kind=self["kind"], + r=self["r"].copy(), + ident=self["ident"], + coord_frame=self["coord_frame"], + ) def __eq__(self, other): # noqa: D105 """Compare two DigPoints. @@ -157,13 +170,13 @@ def __eq__(self, other): # noqa: D105 Two digpoints are equal if they are the same kind, share the same coordinate frame and position. """ - my_keys = ['kind', 'ident', 'coord_frame'] + my_keys = ["kind", "ident", "coord_frame"] if set(self.keys()) != set(other.keys()): return False elif any(self[_] != other[_] for _ in my_keys): return False else: - return np.allclose(self['r'], other['r']) + return np.allclose(self["r"], other["r"]) def _read_dig_fif(fid, meas_info): @@ -171,16 +184,16 @@ def _read_dig_fif(fid, meas_info): isotrak = dir_tree_find(meas_info, FIFF.FIFFB_ISOTRAK) dig = None if len(isotrak) == 0: - logger.info('Isotrak not found') + logger.info("Isotrak not found") elif len(isotrak) > 1: - warn('Multiple Isotrak found') + warn("Multiple Isotrak found") else: isotrak = isotrak[0] coord_frame = FIFF.FIFFV_COORD_HEAD dig = [] - for k in range(isotrak['nent']): - kind = isotrak['directory'][k].kind - pos = isotrak['directory'][k].pos + for k in range(isotrak["nent"]): + kind = isotrak["directory"][k].kind + pos = isotrak["directory"][k].pos if kind == FIFF.FIFF_DIG_POINT: tag = read_tag(fid, pos) dig.append(tag.data) @@ -188,7 +201,7 @@ def _read_dig_fif(fid, meas_info): tag = read_tag(fid, pos) coord_frame = _coord_frame_named.get(int(tag.data.item())) for d in dig: - d['coord_frame'] = coord_frame + d["coord_frame"] = coord_frame return _format_dig_points(dig) @@ -217,21 +230,22 @@ def write_dig(fname, pts, coord_frame=None, *, overwrite=False, verbose=None): fname = _check_fname(fname, overwrite=overwrite) if coord_frame is not None: coord_frame = _to_const(coord_frame) - pts_frames = {pt.get('coord_frame', coord_frame) for pt in pts} + pts_frames = {pt.get("coord_frame", coord_frame) for pt in pts} bad_frames = pts_frames - {coord_frame} if len(bad_frames) > 0: raise ValueError( - 'Points have coord_frame entries that are incompatible with ' - 'coord_frame=%i: %s.' % (coord_frame, str(tuple(bad_frames)))) + "Points have coord_frame entries that are incompatible with " + "coord_frame=%i: %s." 
% (coord_frame, str(tuple(bad_frames))) + ) with start_and_end_file(fname) as fid: write_dig_points(fid, pts, block=True, coord_frame=coord_frame) _cardinal_ident_mapping = { - FIFF.FIFFV_POINT_NASION: 'nasion', - FIFF.FIFFV_POINT_LPA: 'lpa', - FIFF.FIFFV_POINT_RPA: 'rpa', + FIFF.FIFFV_POINT_NASION: "nasion", + FIFF.FIFFV_POINT_LPA: "lpa", + FIFF.FIFFV_POINT_RPA: "rpa", } @@ -239,8 +253,8 @@ def _ensure_fiducials_head(dig): # Ensure that there are all three fiducials in the head coord frame fids = dict() for d in dig: - if d['kind'] == FIFF.FIFFV_POINT_CARDINAL: - name = _cardinal_ident_mapping.get(d['ident'], None) + if d["kind"] == FIFF.FIFFV_POINT_CARDINAL: + name = _cardinal_ident_mapping.get(d["ident"], None) if name is not None: fids[name] = d radius = None @@ -253,17 +267,22 @@ def _ensure_fiducials_head(dig): if name not in fids: if radius is None: radius = [ - np.linalg.norm(d['r']) for d in dig - if d['coord_frame'] == FIFF.FIFFV_COORD_HEAD - and not np.isnan(d['r']).any()] + np.linalg.norm(d["r"]) + for d in dig + if d["coord_frame"] == FIFF.FIFFV_COORD_HEAD + and not np.isnan(d["r"]).any() + ] if not radius: return # can't complete, no head points radius = np.mean(radius) - dig.append(DigPoint( - kind=FIFF.FIFFV_POINT_CARDINAL, ident=ident, - r=np.array(mults[name], float) * radius, - coord_frame=FIFF.FIFFV_COORD_HEAD, - )) + dig.append( + DigPoint( + kind=FIFF.FIFFV_POINT_CARDINAL, + ident=ident, + r=np.array(mults[name], float) * radius, + coord_frame=FIFF.FIFFV_COORD_HEAD, + ) + ) # XXXX: @@ -288,27 +307,29 @@ def _get_data_as_dict_from_dig(dig, exclude_ref_channel=True): fids, dig_ch_pos_location = dict(), list() for d in dig: - if d['kind'] == FIFF.FIFFV_POINT_CARDINAL: - fids[_cardinal_ident_mapping[d['ident']]] = d['r'] - elif d['kind'] == FIFF.FIFFV_POINT_HPI: - hpi.append(d['r']) - elp.append(d['r']) - elif d['kind'] == FIFF.FIFFV_POINT_EXTRA: - hsp.append(d['r']) - elif d['kind'] == FIFF.FIFFV_POINT_EEG: - if d['ident'] != 0 or not exclude_ref_channel: - dig_ch_pos_location.append(d['r']) - - dig_coord_frames = set([d['coord_frame'] for d in dig]) + if d["kind"] == FIFF.FIFFV_POINT_CARDINAL: + fids[_cardinal_ident_mapping[d["ident"]]] = d["r"] + elif d["kind"] == FIFF.FIFFV_POINT_HPI: + hpi.append(d["r"]) + elp.append(d["r"]) + elif d["kind"] == FIFF.FIFFV_POINT_EXTRA: + hsp.append(d["r"]) + elif d["kind"] == FIFF.FIFFV_POINT_EEG: + if d["ident"] != 0 or not exclude_ref_channel: + dig_ch_pos_location.append(d["r"]) + + dig_coord_frames = set([d["coord_frame"] for d in dig]) if len(dig_coord_frames) != 1: - raise RuntimeError('Only single coordinate frame in dig is supported, ' - f'got {dig_coord_frames}') + raise RuntimeError( + "Only single coordinate frame in dig is supported, " + f"got {dig_coord_frames}" + ) dig_ch_pos_location = np.array(dig_ch_pos_location) dig_ch_pos_location.shape = (-1, 3) # empty will be (0, 3) return Bunch( - nasion=fids.get('nasion', None), - lpa=fids.get('lpa', None), - rpa=fids.get('rpa', None), + nasion=fids.get("nasion", None), + lpa=fids.get("lpa", None), + rpa=fids.get("rpa", None), hsp=np.array(hsp) if len(hsp) else None, hpi=np.array(hpi) if len(hpi) else None, elp=np.array(elp) if len(elp) else None, @@ -322,20 +343,21 @@ def _get_fid_coords(dig, raise_error=True): fid_coord_frames = dict() for d in dig: - if d['kind'] == FIFF.FIFFV_POINT_CARDINAL: - key = _cardinal_ident_mapping[d['ident']] - fid_coords[key] = d['r'] - fid_coord_frames[key] = d['coord_frame'] + if d["kind"] == FIFF.FIFFV_POINT_CARDINAL: + key = 
_cardinal_ident_mapping[d["ident"]] + fid_coords[key] = d["r"] + fid_coord_frames[key] = d["coord_frame"] if len(fid_coord_frames) > 0 and raise_error: - if set(fid_coord_frames.keys()) != set(['nasion', 'lpa', 'rpa']): - raise ValueError("Some fiducial points are missing (got %s)." % - fid_coord_frames.keys()) + if set(fid_coord_frames.keys()) != set(["nasion", "lpa", "rpa"]): + raise ValueError( + "Some fiducial points are missing (got %s)." % fid_coord_frames.keys() + ) if len(set(fid_coord_frames.values())) > 1: raise ValueError( - 'All fiducial points must be in the same coordinate system ' - '(got %s)' % len(fid_coord_frames) + "All fiducial points must be in the same coordinate system " + "(got %s)" % len(fid_coord_frames) ) coord_frame = fid_coord_frames.popitem()[1] if fid_coord_frames else None @@ -357,18 +379,19 @@ def _write_dig_points(fname, dig_points): _, ext = op.splitext(fname) dig_points = np.asarray(dig_points) if (dig_points.ndim != 2) or (dig_points.shape[1] != 3): - err = ("Points must be of shape (n_points, 3), " - "not %s" % (dig_points.shape,)) + err = "Points must be of shape (n_points, 3), " "not %s" % (dig_points.shape,) raise ValueError(err) - if ext == '.txt': - with open(fname, 'wb') as fid: + if ext == ".txt": + with open(fname, "wb") as fid: version = __version__ now = datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y") - fid.write(b'%% Ascii 3D points file created by mne-python version' - b' %s at %s\n' % (version.encode(), now.encode())) - fid.write(b'%% %d 3D points, x y z per line\n' % len(dig_points)) - np.savetxt(fid, dig_points, delimiter='\t', newline='\n') + fid.write( + b"%% Ascii 3D points file created by mne-python version" + b" %s at %s\n" % (version.encode(), now.encode()) + ) + fid.write(b"%% %d 3D points, x y z per line\n" % len(dig_points)) + np.savetxt(fid, dig_points, delimiter="\t", newline="\n") else: msg = "Unrecognized extension: %r. Need '.txt'." % ext raise ValueError(msg) @@ -376,14 +399,24 @@ def _write_dig_points(fname, dig_points): def _coord_frame_const(coord_frame): if not isinstance(coord_frame, str) or coord_frame not in _str_to_frame: - raise ValueError('coord_frame must be one of %s, got %s' - % (sorted(_str_to_frame.keys()), coord_frame)) + raise ValueError( + "coord_frame must be one of %s, got %s" + % (sorted(_str_to_frame.keys()), coord_frame) + ) return _str_to_frame[coord_frame] -def _make_dig_points(nasion=None, lpa=None, rpa=None, hpi=None, - extra_points=None, dig_ch_pos=None, *, - coord_frame='head', add_missing_fiducials=False): +def _make_dig_points( + nasion=None, + lpa=None, + rpa=None, + hpi=None, + extra_points=None, + dig_ch_pos=None, + *, + coord_frame="head", + add_missing_fiducials=False, +): """Construct digitizer info for the info. 
Parameters @@ -419,67 +452,105 @@ def _make_dig_points(nasion=None, lpa=None, rpa=None, hpi=None, if lpa is not None: lpa = np.asarray(lpa) if lpa.shape != (3,): - raise ValueError('LPA should have the shape (3,) instead of %s' - % (lpa.shape,)) - dig.append({'r': lpa, 'ident': FIFF.FIFFV_POINT_LPA, - 'kind': FIFF.FIFFV_POINT_CARDINAL, - 'coord_frame': coord_frame}) + raise ValueError( + "LPA should have the shape (3,) instead of %s" % (lpa.shape,) + ) + dig.append( + { + "r": lpa, + "ident": FIFF.FIFFV_POINT_LPA, + "kind": FIFF.FIFFV_POINT_CARDINAL, + "coord_frame": coord_frame, + } + ) if nasion is not None: nasion = np.asarray(nasion) if nasion.shape != (3,): - raise ValueError('Nasion should have the shape (3,) instead of %s' - % (nasion.shape,)) - dig.append({'r': nasion, 'ident': FIFF.FIFFV_POINT_NASION, - 'kind': FIFF.FIFFV_POINT_CARDINAL, - 'coord_frame': coord_frame}) + raise ValueError( + "Nasion should have the shape (3,) instead of %s" % (nasion.shape,) + ) + dig.append( + { + "r": nasion, + "ident": FIFF.FIFFV_POINT_NASION, + "kind": FIFF.FIFFV_POINT_CARDINAL, + "coord_frame": coord_frame, + } + ) if rpa is not None: rpa = np.asarray(rpa) if rpa.shape != (3,): - raise ValueError('RPA should have the shape (3,) instead of %s' - % (rpa.shape,)) - dig.append({'r': rpa, 'ident': FIFF.FIFFV_POINT_RPA, - 'kind': FIFF.FIFFV_POINT_CARDINAL, - 'coord_frame': coord_frame}) + raise ValueError( + "RPA should have the shape (3,) instead of %s" % (rpa.shape,) + ) + dig.append( + { + "r": rpa, + "ident": FIFF.FIFFV_POINT_RPA, + "kind": FIFF.FIFFV_POINT_CARDINAL, + "coord_frame": coord_frame, + } + ) if hpi is not None: hpi = np.asarray(hpi) if hpi.ndim != 2 or hpi.shape[1] != 3: - raise ValueError('HPI should have the shape (n_points, 3) instead ' - 'of %s' % (hpi.shape,)) + raise ValueError( + "HPI should have the shape (n_points, 3) instead " + "of %s" % (hpi.shape,) + ) for idx, point in enumerate(hpi): - dig.append({'r': point, 'ident': idx + 1, - 'kind': FIFF.FIFFV_POINT_HPI, - 'coord_frame': coord_frame}) + dig.append( + { + "r": point, + "ident": idx + 1, + "kind": FIFF.FIFFV_POINT_HPI, + "coord_frame": coord_frame, + } + ) if extra_points is not None: extra_points = np.asarray(extra_points) if len(extra_points) and extra_points.shape[1] != 3: - raise ValueError('Points should have the shape (n_points, 3) ' - 'instead of %s' % (extra_points.shape,)) + raise ValueError( + "Points should have the shape (n_points, 3) " + "instead of %s" % (extra_points.shape,) + ) for idx, point in enumerate(extra_points): - dig.append({'r': point, 'ident': idx + 1, - 'kind': FIFF.FIFFV_POINT_EXTRA, - 'coord_frame': coord_frame}) + dig.append( + { + "r": point, + "ident": idx + 1, + "kind": FIFF.FIFFV_POINT_EXTRA, + "coord_frame": coord_frame, + } + ) if dig_ch_pos is not None: idents = [] use_arange = False for key, value in dig_ch_pos.items(): - _validate_type(key, str, 'dig_ch_pos') + _validate_type(key, str, "dig_ch_pos") try: idents.append(int(key[-3:])) except ValueError: use_arange = True - _validate_type(value, (np.ndarray, list, tuple), 'dig_ch_pos') + _validate_type(value, (np.ndarray, list, tuple), "dig_ch_pos") value = np.array(value, dtype=float) dig_ch_pos[key] = value - if value.shape != (3, ): + if value.shape != (3,): raise RuntimeError( "The position should be a 1D array of 3 floats. " - f"Provided shape {value.shape}.") + f"Provided shape {value.shape}." 
+ ) if use_arange: idents = np.arange(1, len(dig_ch_pos) + 1) for key, ident in zip(dig_ch_pos, idents): - dig.append({'r': dig_ch_pos[key], 'ident': int(ident), - 'kind': FIFF.FIFFV_POINT_EEG, - 'coord_frame': coord_frame}) + dig.append( + { + "r": dig_ch_pos[key], + "ident": int(ident), + "kind": FIFF.FIFFV_POINT_EEG, + "coord_frame": coord_frame, + } + ) if add_missing_fiducials: assert coord_frame == FIFF.FIFFV_COORD_HEAD # These being none is really an assumption that if you have one you @@ -506,13 +577,11 @@ def _call_make_dig_points(nasion, lpa, rpa, hpi, extra, convert=True): else: neuromag_trans = None - ctf_head_t = Transform(fro='ctf_head', to='head', trans=neuromag_trans) + ctf_head_t = Transform(fro="ctf_head", to="head", trans=neuromag_trans) - info_dig = _make_dig_points(nasion=nasion, - lpa=lpa, - rpa=rpa, - hpi=hpi, - extra_points=extra) + info_dig = _make_dig_points( + nasion=nasion, lpa=lpa, rpa=rpa, hpi=hpi, extra_points=extra + ) return info_dig, ctf_head_t @@ -527,19 +596,26 @@ def _artemis123_read_pos(nas, lpa, rpa, hpi, extra): ############################################################################## # From bti -def _make_bti_dig_points(nasion, lpa, rpa, hpi, extra, - convert=False, use_hpi=False, - bti_dev_t=False, dev_ctf_t=False): - +def _make_bti_dig_points( + nasion, + lpa, + rpa, + hpi, + extra, + convert=False, + use_hpi=False, + bti_dev_t=False, + dev_ctf_t=False, +): _hpi = hpi if use_hpi else None - info_dig, ctf_head_t = _call_make_dig_points(nasion, lpa, rpa, _hpi, extra, - convert) + info_dig, ctf_head_t = _call_make_dig_points(nasion, lpa, rpa, _hpi, extra, convert) if convert: - t = combine_transforms(invert_transform(bti_dev_t), dev_ctf_t, - 'meg', 'ctf_head') - dev_head_t = combine_transforms(t, ctf_head_t, 'meg', 'head') + t = combine_transforms( + invert_transform(bti_dev_t), dev_ctf_t, "meg", "ctf_head" + ) + dev_head_t = combine_transforms(t, ctf_head_t, "meg", "head") else: - dev_head_t = Transform('meg', 'head', trans=None) + dev_head_t = Transform("meg", "head", trans=None) return info_dig, dev_head_t, ctf_head_t # ctf_head_t should not be needed diff --git a/mne/io/_read_raw.py b/mne/io/_read_raw.py index 5227b33ae86..d3c5b1eea07 100644 --- a/mne/io/_read_raw.py +++ b/mne/io/_read_raw.py @@ -8,12 +8,27 @@ from pathlib import Path from functools import partial -from . import (read_raw_edf, read_raw_bdf, read_raw_gdf, read_raw_brainvision, - read_raw_fif, read_raw_eeglab, read_raw_cnt, read_raw_egi, - read_raw_eximia, read_raw_nirx, read_raw_fieldtrip, - read_raw_artemis123, read_raw_nicolet, read_raw_kit, - read_raw_ctf, read_raw_boxy, read_raw_snirf, read_raw_fil, - read_raw_nihon) +from . 
import ( + read_raw_edf, + read_raw_bdf, + read_raw_gdf, + read_raw_brainvision, + read_raw_fif, + read_raw_eeglab, + read_raw_cnt, + read_raw_egi, + read_raw_eximia, + read_raw_nirx, + read_raw_fieldtrip, + read_raw_artemis123, + read_raw_nicolet, + read_raw_kit, + read_raw_ctf, + read_raw_boxy, + read_raw_snirf, + read_raw_fil, + read_raw_nihon, +) from ..utils import fill_doc @@ -45,8 +60,8 @@ def _read_unsupported(fname, **kwargs): ".snirf": dict(SNIRF=read_raw_snirf), ".mat": dict(fieldtrip=read_raw_fieldtrip), ".bin": { - 'ARTEMIS': read_raw_artemis123, - 'UCL FIL OPM': read_raw_fil, + "ARTEMIS": read_raw_artemis123, + "UCL FIL OPM": read_raw_fil, }, ".data": dict(Nicolet=read_raw_nicolet), ".sqd": dict(KIT=read_raw_kit), @@ -72,7 +87,7 @@ def split_name_ext(fname): for si in range(-maxsuffixes, 0): ext = "".join(suffixes[si:]).lower() if ext in readers: - return Path(fname).name[:-len(ext)], ext + return Path(fname).name[: -len(ext)], ext return fname, None # unknown file extension @@ -112,8 +127,8 @@ def read_raw(fname, *, preload=False, verbose=None, **kwargs): Raw object. """ _, ext = split_name_ext(fname) - kwargs['verbose'] = verbose - kwargs['preload'] = preload + kwargs["verbose"] = verbose + kwargs["preload"] = preload if ext not in readers: _read_unsupported(fname) these_readers = list(readers[ext].values()) @@ -124,10 +139,12 @@ def read_raw(fname, *, preload=False, verbose=None, **kwargs): if len(these_readers) == 1: raise else: - choices = '\n'.join( - f'mne.io.{func.__name__.ljust(20)} ({kind})' - for kind, func in readers[ext].items()) + choices = "\n".join( + f"mne.io.{func.__name__.ljust(20)} ({kind})" + for kind, func in readers[ext].items() + ) raise RuntimeError( - 'Could not read file using any of the possible readers for ' - f'extension {ext}. Consider trying to read the file directly with ' - f'one of:\n{choices}') + "Could not read file using any of the possible readers for " + f"extension {ext}. 
Consider trying to read the file directly with " + f"one of:\n{choices}" + ) diff --git a/mne/io/array/array.py b/mne/io/array/array.py index b4cf25d9f65..7e7ffded42a 100644 --- a/mne/io/array/array.py +++ b/mne/io/array/array.py @@ -51,37 +51,50 @@ class RawArray(BaseRaw): """ @verbose - def __init__(self, data, info, first_samp=0, copy='auto', - verbose=None): # noqa: D102 - _validate_type(info, 'info', 'info') - _check_option('copy', copy, ('data', 'info', 'both', 'auto', None)) + def __init__( + self, data, info, first_samp=0, copy="auto", verbose=None + ): # noqa: D102 + _validate_type(info, "info", "info") + _check_option("copy", copy, ("data", "info", "both", "auto", None)) dtype = np.complex128 if np.any(np.iscomplex(data)) else np.float64 orig_data = data data = np.asanyarray(orig_data, dtype=dtype) if data.ndim != 2: - raise ValueError('Data must be a 2D array of shape (n_channels, ' - 'n_samples), got shape %s' % (data.shape,)) - if len(data) != len(info['ch_names']): - raise ValueError('len(data) (%s) does not match ' - 'len(info["ch_names"]) (%s)' - % (len(data), len(info['ch_names']))) - assert len(info['ch_names']) == info['nchan'] - if copy in ('auto', 'info', 'both'): + raise ValueError( + "Data must be a 2D array of shape (n_channels, " + "n_samples), got shape %s" % (data.shape,) + ) + if len(data) != len(info["ch_names"]): + raise ValueError( + "len(data) (%s) does not match " + 'len(info["ch_names"]) (%s)' % (len(data), len(info["ch_names"])) + ) + assert len(info["ch_names"]) == info["nchan"] + if copy in ("auto", "info", "both"): info = info.copy() - if copy in ('data', 'both'): + if copy in ("data", "both"): if data is orig_data: data = data.copy() - elif copy != 'auto' and data is not orig_data: - raise ValueError('data copying was not requested by copy=%r but ' - 'it was required to get to double floating point ' - 'precision' % (copy,)) - logger.info('Creating RawArray with %s data, n_channels=%s, n_times=%s' - % (dtype.__name__, data.shape[0], data.shape[1])) - super(RawArray, self).__init__(info, data, - first_samps=(int(first_samp),), - dtype=dtype, verbose=verbose) - logger.info(' Range : %d ... %d = %9.3f ... %9.3f secs' % ( - self.first_samp, self.last_samp, - float(self.first_samp) / info['sfreq'], - float(self.last_samp) / info['sfreq'])) - logger.info('Ready.') + elif copy != "auto" and data is not orig_data: + raise ValueError( + "data copying was not requested by copy=%r but " + "it was required to get to double floating point " + "precision" % (copy,) + ) + logger.info( + "Creating RawArray with %s data, n_channels=%s, n_times=%s" + % (dtype.__name__, data.shape[0], data.shape[1]) + ) + super(RawArray, self).__init__( + info, data, first_samps=(int(first_samp),), dtype=dtype, verbose=verbose + ) + logger.info( + " Range : %d ... %d = %9.3f ... 
%9.3f secs" + % ( + self.first_samp, + self.last_samp, + float(self.first_samp) / info["sfreq"], + float(self.last_samp) / info["sfreq"], + ) + ) + logger.info("Ready.") diff --git a/mne/io/array/tests/test_array.py b/mne/io/array/tests/test_array.py index 4ab3587b8f6..1a96b9e4488 100644 --- a/mne/io/array/tests/test_array.py +++ b/mne/io/array/tests/test_array.py @@ -5,8 +5,7 @@ from pathlib import Path import numpy as np -from numpy.testing import (assert_array_almost_equal, assert_allclose, - assert_equal) +from numpy.testing import assert_array_almost_equal, assert_allclose, assert_equal import pytest import matplotlib.pyplot as plt @@ -24,23 +23,24 @@ def test_long_names(): """Test long name support.""" - info = create_info(['a' * 15 + 'b', 'a' * 16], 1000., verbose='error') + info = create_info(["a" * 15 + "b", "a" * 16], 1000.0, verbose="error") data = np.zeros((2, 1000)) raw = RawArray(data, info) - assert raw.ch_names == ['a' * 15 + 'b', 'a' * 16] + assert raw.ch_names == ["a" * 15 + "b", "a" * 16] # and a way to get the old behavior - raw.rename_channels({k: k[:13] for k in raw.ch_names}, - allow_duplicates=True, verbose='error') - assert raw.ch_names == ['a' * 13 + '-0', 'a' * 13 + '-1'] - info = create_info(['a' * 16] * 11, 1000., verbose='error') + raw.rename_channels( + {k: k[:13] for k in raw.ch_names}, allow_duplicates=True, verbose="error" + ) + assert raw.ch_names == ["a" * 13 + "-0", "a" * 13 + "-1"] + info = create_info(["a" * 16] * 11, 1000.0, verbose="error") data = np.zeros((11, 1000)) raw = RawArray(data, info) - assert raw.ch_names == ['a' * 16 + '-%s' % ii for ii in range(11)] + assert raw.ch_names == ["a" * 16 + "-%s" % ii for ii in range(11)] def test_array_copy(): """Test copying during construction.""" - info = create_info(1, 1000.) 
+ info = create_info(1, 1000.0) data = np.zeros((1, 1000)) # 'auto' (default) raw = RawArray(data, info) @@ -50,27 +50,27 @@ def test_array_copy(): assert raw._data is not data assert raw.info is not info # 'info' (more restrictive) - raw = RawArray(data, info, copy='info') + raw = RawArray(data, info, copy="info") assert raw._data is data assert raw.info is not info with pytest.raises(ValueError, match="data copying was not .* copy='info"): - RawArray(data.astype(np.float32), info, copy='info') + RawArray(data.astype(np.float32), info, copy="info") # 'data' - raw = RawArray(data, info, copy='data') + raw = RawArray(data, info, copy="data") assert raw._data is not data assert raw.info is info # 'both' - raw = RawArray(data, info, copy='both') + raw = RawArray(data, info, copy="both") assert raw._data is not data assert raw.info is not info - raw = RawArray(data.astype(np.float32), info, copy='both') + raw = RawArray(data.astype(np.float32), info, copy="both") assert raw._data is not data assert raw.info is not info # None raw = RawArray(data, info, copy=None) assert raw._data is data assert raw.info is info - with pytest.raises(ValueError, match='data copying was not .* copy=None'): + with pytest.raises(ValueError, match="data copying was not .* copy=None"): RawArray(data.astype(np.float32), info, copy=None) @@ -80,19 +80,24 @@ def test_array_raw(): # creating raw = read_raw_fif(fif_fname).crop(2, 5) data, times = raw[:, :] - sfreq = raw.info['sfreq'] - ch_names = [(ch[4:] if 'STI' not in ch else ch) - for ch in raw.info['ch_names']] # change them, why not + sfreq = raw.info["sfreq"] + ch_names = [ + (ch[4:] if "STI" not in ch else ch) for ch in raw.info["ch_names"] + ] # change them, why not types = list() for ci in range(101): - types.extend(('grad', 'grad', 'mag')) - types.extend(['ecog', 'seeg', 'hbo']) # really 4 meg channels - types.extend(['stim'] * 9) - types.extend(['dbs']) # really eeg channel - types.extend(['eeg'] * 60) - picks = np.concatenate([pick_types(raw.info, meg=True)[::20], - pick_types(raw.info, meg=False, stim=True), - pick_types(raw.info, meg=False, eeg=True)[::20]]) + types.extend(("grad", "grad", "mag")) + types.extend(["ecog", "seeg", "hbo"]) # really 4 meg channels + types.extend(["stim"] * 9) + types.extend(["dbs"]) # really eeg channel + types.extend(["eeg"] * 60) + picks = np.concatenate( + [ + pick_types(raw.info, meg=True)[::20], + pick_types(raw.info, meg=False, stim=True), + pick_types(raw.info, meg=False, eeg=True)[::20], + ] + ) del raw data = data[picks] ch_names = np.array(ch_names)[picks].tolist() @@ -101,37 +106,39 @@ def test_array_raw(): # wrong length pytest.raises(ValueError, create_info, ch_names, sfreq, types) # bad entry - types.append('foo') + types.append("foo") pytest.raises(KeyError, create_info, ch_names, sfreq, types) - types[-1] = 'eog' + types[-1] = "eog" # default type info = create_info(ch_names, sfreq) - assert_equal(info['chs'][0]['kind'], - get_channel_type_constants()['misc']['kind']) + assert_equal(info["chs"][0]["kind"], get_channel_type_constants()["misc"]["kind"]) # use real types info = create_info(ch_names, sfreq, types) - raw2 = _test_raw_reader(RawArray, test_preloading=False, - data=data, info=info, first_samp=2 * data.shape[1]) + raw2 = _test_raw_reader( + RawArray, + test_preloading=False, + data=data, + info=info, + first_samp=2 * data.shape[1], + ) data2, times2 = raw2[:, :] assert_allclose(data, data2) assert_allclose(times, times2) - assert ('RawArray' in repr(raw2)) + assert "RawArray" in repr(raw2) 
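    # The next check relies on the constructor signature being
    # RawArray(data, info, ...), so passing (info, data) must raise.
    # A minimal usage sketch (channel names below are illustrative only):
    #
    #     info = create_info(["ch1"], sfreq=100.0, ch_types="eeg")
    #     raw = RawArray(np.zeros((1, 100)), info)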
pytest.raises(TypeError, RawArray, info, data) # filtering - picks = pick_types(raw2.info, meg=True, misc=True, exclude='bads')[:4] + picks = pick_types(raw2.info, meg=True, misc=True, exclude="bads")[:4] assert_equal(len(picks), 4) raw_lp = raw2.copy() - kwargs = dict(fir_design='firwin', picks=picks) - raw_lp.filter(None, 4.0, h_trans_bandwidth=4., **kwargs) + kwargs = dict(fir_design="firwin", picks=picks) + raw_lp.filter(None, 4.0, h_trans_bandwidth=4.0, **kwargs) raw_hp = raw2.copy() - raw_hp.filter(16.0, None, l_trans_bandwidth=4., **kwargs) + raw_hp.filter(16.0, None, l_trans_bandwidth=4.0, **kwargs) raw_bp = raw2.copy() - raw_bp.filter(8.0, 12.0, l_trans_bandwidth=4., h_trans_bandwidth=4., - **kwargs) + raw_bp.filter(8.0, 12.0, l_trans_bandwidth=4.0, h_trans_bandwidth=4.0, **kwargs) raw_bs = raw2.copy() - raw_bs.filter(16.0, 4.0, l_trans_bandwidth=4., h_trans_bandwidth=4., - **kwargs) + raw_bs.filter(16.0, 4.0, l_trans_bandwidth=4.0, h_trans_bandwidth=4.0, **kwargs) data, _ = raw2[picks, :] lp_data, _ = raw_lp[picks, :] hp_data, _ = raw_hp[picks, :] @@ -143,12 +150,11 @@ def test_array_raw(): # plotting raw2.plot() - (raw2.compute_psd(tmax=2., n_fft=1024) - .plot(average=True, spatial_colors=False)) - plt.close('all') + (raw2.compute_psd(tmax=2.0, n_fft=1024).plot(average=True, spatial_colors=False)) + plt.close("all") # epoching - events = find_events(raw2, stim_channel='STI 014') + events = find_events(raw2, stim_channel="STI 014") events[:, 2] = 1 assert len(events) > 2 epochs = Epochs(raw2, events, 1, -0.2, 0.4, preload=True) @@ -158,22 +164,21 @@ def test_array_raw(): # complex data rng = np.random.RandomState(0) data = rng.randn(1, 100) + 1j * rng.randn(1, 100) - raw = RawArray(data, create_info(1, 1000., 'eeg')) + raw = RawArray(data, create_info(1, 1000.0, "eeg")) assert_allclose(raw._data, data) # Using digital montage to give MNI electrode coordinates n_elec = 10 ts_size = 10000 - Fs = 512. + Fs = 512.0 ch_names = [str(i) for i in range(n_elec)] ch_pos_loc = np.random.randint(60, size=(n_elec, 3)).tolist() data = np.random.rand(n_elec, ts_size) montage = make_dig_montage( - ch_pos=dict(zip(ch_names, ch_pos_loc)), - coord_frame='head' + ch_pos=dict(zip(ch_names, ch_pos_loc)), coord_frame="head" ) - info = create_info(ch_names, Fs, 'ecog') + info = create_info(ch_names, Fs, "ecog") raw = RawArray(data, info) raw.set_montage(montage) diff --git a/mne/io/artemis123/artemis123.py b/mne/io/artemis123/artemis123.py index e959158d9c4..c92494dbb26 100644 --- a/mne/io/artemis123/artemis123.py +++ b/mne/io/artemis123/artemis123.py @@ -18,8 +18,9 @@ @verbose -def read_raw_artemis123(input_fname, preload=False, verbose=None, - pos_fname=None, add_head_trans=True): +def read_raw_artemis123( + input_fname, preload=False, verbose=None, pos_fname=None, add_head_trans=True +): """Read Artemis123 data as raw object. Parameters @@ -47,27 +48,39 @@ def read_raw_artemis123(input_fname, preload=False, verbose=None, -------- mne.io.Raw : Documentation of attributes and methods. 
""" - return RawArtemis123(input_fname, preload=preload, verbose=verbose, - pos_fname=pos_fname, add_head_trans=add_head_trans) + return RawArtemis123( + input_fname, + preload=preload, + verbose=verbose, + pos_fname=pos_fname, + add_head_trans=add_head_trans, + ) def _get_artemis123_info(fname, pos_fname=None): """Generate info struct from artemis123 header file.""" fname = op.splitext(fname)[0] - header = fname + '.txt' + header = fname + ".txt" - logger.info('Reading header...') + logger.info("Reading header...") # key names for artemis channel info... - chan_keys = ['name', 'scaling', 'FLL_Gain', 'FLL_Mode', 'FLL_HighPass', - 'FLL_AutoReset', 'FLL_ResetLock'] + chan_keys = [ + "name", + "scaling", + "FLL_Gain", + "FLL_Mode", + "FLL_HighPass", + "FLL_AutoReset", + "FLL_ResetLock", + ] header_info = dict() - header_info['filter_hist'] = [] - header_info['comments'] = '' - header_info['channels'] = [] + header_info["filter_hist"] = [] + header_info["comments"] = "" + header_info["channels"] = [] - with open(header, 'r') as fid: + with open(header, "r") as fid: # section flag # 0 - None # 1 - main header @@ -78,12 +91,11 @@ def _get_artemis123_info(fname, pos_fname=None): sectionFlag = 0 for line in fid: # skip emptylines or header line for channel info - if ((not line.strip()) or - (sectionFlag == 2 and line.startswith('DAQ Map'))): + if (not line.strip()) or (sectionFlag == 2 and line.startswith("DAQ Map")): continue # set sectionFlag - if line.startswith('"): sectionFlag = 1 @@ -99,149 +111,165 @@ def _get_artemis123_info(fname, pos_fname=None): # parse header info lines # part of main header - lines are name value pairs if sectionFlag == 1: - values = line.strip().split('\t') + values = line.strip().split("\t") if len(values) == 1: - values.append('') + values.append("") header_info[values[0]] = values[1] # part of channel header - lines are Channel Info elif sectionFlag == 2: - values = line.strip().split('\t') + values = line.strip().split("\t") if len(values) != 7: - raise OSError('Error parsing line \n\t:%s\n' % line + - 'from file %s' % header) + raise OSError( + "Error parsing line \n\t:%s\n" % line + + "from file %s" % header + ) tmp = dict() for k, v in zip(chan_keys, values): tmp[k] = v - header_info['channels'].append(tmp) + header_info["channels"].append(tmp) elif sectionFlag == 3: - header_info['comments'] = '%s%s' \ - % (header_info['comments'], line.strip()) + header_info["comments"] = "%s%s" % ( + header_info["comments"], + line.strip(), + ) elif sectionFlag == 4: - header_info['num_samples'] = int(line.strip()) + header_info["num_samples"] = int(line.strip()) elif sectionFlag == 5: - header_info['filter_hist'].append(line.strip()) - - for k in ['Temporal Filter Active?', 'Decimation Active?', - 'Spatial Filter Active?']: - if header_info[k] != 'FALSE': - warn('%s - set to but is not supported' % k) - if header_info['filter_hist']: - warn('Non-Empty Filter history found, BUT is not supported' % k) + header_info["filter_hist"].append(line.strip()) + + for k in [ + "Temporal Filter Active?", + "Decimation Active?", + "Spatial Filter Active?", + ]: + if header_info[k] != "FALSE": + warn("%s - set to but is not supported" % k) + if header_info["filter_hist"]: + warn("Non-Empty Filter history found, BUT is not supported" % k) # build mne info struct - info = _empty_info(float(header_info['DAQ Sample Rate'])) + info = _empty_info(float(header_info["DAQ Sample Rate"])) # Attempt to get time/date from fname # Artemis123 files saved from the scanner observe the following # naming 
convention 'Artemis_Data_YYYY-MM-DD-HHh-MMm_[chosen by user].bin' try: date = datetime.datetime.strptime( - op.basename(fname).split('_')[2], '%Y-%m-%d-%Hh-%Mm') + op.basename(fname).split("_")[2], "%Y-%m-%d-%Hh-%Mm" + ) meas_date = (calendar.timegm(date.utctimetuple()), 0) except Exception: meas_date = None # build subject info must be an integer (as per FIFF) try: - subject_info = {'id': int(header_info['Subject ID'])} + subject_info = {"id": int(header_info["Subject ID"])} except ValueError: - subject_info = {'id': 0} + subject_info = {"id": 0} # build description - desc = '' - for k in ['Purpose', 'Notes']: - desc += '{} : {}\n'.format(k, header_info[k]) - desc += 'Comments : {}'.format(header_info['comments']) - - info.update({'meas_date': meas_date, - 'description': desc, - 'subject_info': subject_info, - 'proj_name': header_info['Project Name']}) + desc = "" + for k in ["Purpose", "Notes"]: + desc += "{} : {}\n".format(k, header_info[k]) + desc += "Comments : {}".format(header_info["comments"]) + + info.update( + { + "meas_date": meas_date, + "description": desc, + "subject_info": subject_info, + "proj_name": header_info["Project Name"], + } + ) # Channel Names by type - ref_mag_names = ['REF_001', 'REF_002', 'REF_003', - 'REF_004', 'REF_005', 'REF_006'] + ref_mag_names = ["REF_001", "REF_002", "REF_003", "REF_004", "REF_005", "REF_006"] - ref_grad_names = ['REF_007', 'REF_008', 'REF_009', - 'REF_010', 'REF_011', 'REF_012'] + ref_grad_names = ["REF_007", "REF_008", "REF_009", "REF_010", "REF_011", "REF_012"] # load mne loc dictionary loc_dict = _load_mne_locs() - info['chs'] = [] - info['bads'] = [] + info["chs"] = [] + info["bads"] = [] - for i, chan in enumerate(header_info['channels']): + for i, chan in enumerate(header_info["channels"]): # build chs struct - t = {'cal': float(chan['scaling']), 'ch_name': chan['name'], - 'logno': i + 1, 'scanno': i + 1, 'range': 1.0, - 'unit_mul': FIFF.FIFF_UNITM_NONE, - 'coord_frame': FIFF.FIFFV_COORD_DEVICE} + t = { + "cal": float(chan["scaling"]), + "ch_name": chan["name"], + "logno": i + 1, + "scanno": i + 1, + "range": 1.0, + "unit_mul": FIFF.FIFF_UNITM_NONE, + "coord_frame": FIFF.FIFFV_COORD_DEVICE, + } # REF_018 has a zero cal which can cause problems. Let's set it to # a value of another ref channel to make writers/readers happy. - if t['cal'] == 0: - t['cal'] = 4.716e-10 - info['bads'].append(t['ch_name']) - t['loc'] = loc_dict.get(chan['name'], np.zeros(12)) - - if (chan['name'].startswith('MEG')): - t['coil_type'] = FIFF.FIFFV_COIL_ARTEMIS123_GRAD - t['kind'] = FIFF.FIFFV_MEG_CH + if t["cal"] == 0: + t["cal"] = 4.716e-10 + info["bads"].append(t["ch_name"]) + t["loc"] = loc_dict.get(chan["name"], np.zeros(12)) + + if chan["name"].startswith("MEG"): + t["coil_type"] = FIFF.FIFFV_COIL_ARTEMIS123_GRAD + t["kind"] = FIFF.FIFFV_MEG_CH # While gradiometer units are T/m, the meg sensors referred to as # gradiometers report the field difference between 2 pick-up coils. # Therefore the units of the measurements should be T # *AND* the baseline (difference between pickup coils) # should not be used in leadfield / forwardfield computations. 
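            # Concretely: the channel is calibrated in tesla with a femto
            # multiplier (FIFF_UNITM_F == -15, i.e. raw values arrive in fT);
            # the multiplier is folded into 'cal' once, further below.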
- t['unit'] = FIFF.FIFF_UNIT_T - t['unit_mul'] = FIFF.FIFF_UNITM_F + t["unit"] = FIFF.FIFF_UNIT_T + t["unit_mul"] = FIFF.FIFF_UNITM_F # 3 axis reference magnetometers - elif (chan['name'] in ref_mag_names): - t['coil_type'] = FIFF.FIFFV_COIL_ARTEMIS123_REF_MAG - t['kind'] = FIFF.FIFFV_REF_MEG_CH - t['unit'] = FIFF.FIFF_UNIT_T - t['unit_mul'] = FIFF.FIFF_UNITM_F + elif chan["name"] in ref_mag_names: + t["coil_type"] = FIFF.FIFFV_COIL_ARTEMIS123_REF_MAG + t["kind"] = FIFF.FIFFV_REF_MEG_CH + t["unit"] = FIFF.FIFF_UNIT_T + t["unit_mul"] = FIFF.FIFF_UNITM_F # reference gradiometers - elif (chan['name'] in ref_grad_names): - t['coil_type'] = FIFF.FIFFV_COIL_ARTEMIS123_REF_GRAD - t['kind'] = FIFF.FIFFV_REF_MEG_CH + elif chan["name"] in ref_grad_names: + t["coil_type"] = FIFF.FIFFV_COIL_ARTEMIS123_REF_GRAD + t["kind"] = FIFF.FIFFV_REF_MEG_CH # While gradiometer units are T/m, the meg sensors referred to as # gradiometers report the field difference between 2 pick-up coils. # Therefore the units of the measurements should be T # *AND* the baseline (difference between pickup coils) # should not be used in leadfield / forwardfield computations. - t['unit'] = FIFF.FIFF_UNIT_T - t['unit_mul'] = FIFF.FIFF_UNITM_F + t["unit"] = FIFF.FIFF_UNIT_T + t["unit_mul"] = FIFF.FIFF_UNITM_F # other reference channels are unplugged and should be ignored. - elif (chan['name'].startswith('REF')): - t['coil_type'] = FIFF.FIFFV_COIL_NONE - t['kind'] = FIFF.FIFFV_MISC_CH - t['unit'] = FIFF.FIFF_UNIT_V - info['bads'].append(t['ch_name']) - - elif (chan['name'].startswith(('AUX', 'TRG', 'MIO'))): - t['coil_type'] = FIFF.FIFFV_COIL_NONE - t['unit'] = FIFF.FIFF_UNIT_V - if (chan['name'].startswith('TRG')): - t['kind'] = FIFF.FIFFV_STIM_CH + elif chan["name"].startswith("REF"): + t["coil_type"] = FIFF.FIFFV_COIL_NONE + t["kind"] = FIFF.FIFFV_MISC_CH + t["unit"] = FIFF.FIFF_UNIT_V + info["bads"].append(t["ch_name"]) + + elif chan["name"].startswith(("AUX", "TRG", "MIO")): + t["coil_type"] = FIFF.FIFFV_COIL_NONE + t["unit"] = FIFF.FIFF_UNIT_V + if chan["name"].startswith("TRG"): + t["kind"] = FIFF.FIFFV_STIM_CH else: - t['kind'] = FIFF.FIFFV_MISC_CH + t["kind"] = FIFF.FIFFV_MISC_CH else: - raise ValueError('Channel does not match expected' + - ' channel Types:"%s"' % chan['name']) + raise ValueError( + "Channel does not match expected" + ' channel Types:"%s"' % chan["name"] + ) # incorporate multiplier (unit_mul) into calibration - t['cal'] *= 10 ** t['unit_mul'] - t['unit_mul'] = FIFF.FIFF_UNITM_NONE + t["cal"] *= 10 ** t["unit_mul"] + t["unit_mul"] = FIFF.FIFF_UNITM_NONE # append this channel to the info - info['chs'].append(t) - if chan['FLL_ResetLock'] == 'TRUE': - info['bads'].append(t['ch_name']) + info["chs"].append(t) + if chan["FLL_ResetLock"] == "TRUE": + info["bads"].append(t["ch_name"]) # reduce info['bads'] to unique set - info['bads'] = list(set(info['bads'])) + info["bads"] = list(set(info["bads"])) # HPI information # print header_info.keys() @@ -249,38 +277,37 @@ def _get_artemis123_info(fname, pos_fname=None): # Don't know what event_channel is don't think we have it HPIs are either # always on or always off. # hpi_sub['event_channel'] = ??? 
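    # For orientation, the structures filled in below end up shaped roughly
    # like this (a sketch, not a formal spec):
    #   info['hpi_subsystem'] = {'hpi_coils': [{'event_bits': [...]}, ...]}
    #   info['hpi_meas'] = [{'hpi_coils': [{'number': ..., 'drive_chan': ...,
    #                                       'coil_freq': ...}, ...]}]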
- hpi_sub['hpi_coils'] = [dict(), dict(), dict(), dict()] + hpi_sub["hpi_coils"] = [dict(), dict(), dict(), dict()] hpi_coils = [dict(), dict(), dict(), dict()] - drive_channels = ['MIO_001', 'MIO_003', 'MIO_009', 'MIO_011'] - key_base = 'Head Tracking %s %d' + drive_channels = ["MIO_001", "MIO_003", "MIO_009", "MIO_011"] + key_base = "Head Tracking %s %d" # set default HPI frequencies - if info['sfreq'] == 1000: + if info["sfreq"] == 1000: default_freqs = [140, 150, 160, 40] else: default_freqs = [700, 750, 800, 40] for i in range(4): # build coil structure - hpi_coils[i]['number'] = i + 1 - hpi_coils[i]['drive_chan'] = drive_channels[i] - this_freq = header_info.pop(key_base % ('Frequency', i + 1), - default_freqs[i]) - hpi_coils[i]['coil_freq'] = this_freq + hpi_coils[i]["number"] = i + 1 + hpi_coils[i]["drive_chan"] = drive_channels[i] + this_freq = header_info.pop(key_base % ("Frequency", i + 1), default_freqs[i]) + hpi_coils[i]["coil_freq"] = this_freq # check if coil is on - if header_info[key_base % ('Channel', i + 1)] == 'OFF': - hpi_sub['hpi_coils'][i]['event_bits'] = [0] + if header_info[key_base % ("Channel", i + 1)] == "OFF": + hpi_sub["hpi_coils"][i]["event_bits"] = [0] else: - hpi_sub['hpi_coils'][i]['event_bits'] = [256] + hpi_sub["hpi_coils"][i]["event_bits"] = [256] - info['hpi_subsystem'] = hpi_sub - info['hpi_meas'] = [{'hpi_coils': hpi_coils}] + info["hpi_subsystem"] = hpi_sub + info["hpi_meas"] = [{"hpi_coils": hpi_coils}] # read in digitized points if supplied if pos_fname is not None: - info['dig'] = _read_pos(pos_fname) + info["dig"] = _read_pos(pos_fname) else: - info['dig'] = [] + info["dig"] = [] info._unlocked = False info._update_redundant() @@ -303,133 +330,163 @@ class RawArtemis123(BaseRaw): """ @verbose - def __init__(self, input_fname, preload=False, verbose=None, - pos_fname=None, add_head_trans=True): # noqa: D102 + def __init__( + self, + input_fname, + preload=False, + verbose=None, + pos_fname=None, + add_head_trans=True, + ): # noqa: D102 from scipy.spatial.distance import cdist - from ...chpi import (compute_chpi_amplitudes, compute_chpi_locs, - _fit_coil_order_dev_head_trans) - input_fname = str( - _check_fname(input_fname, "read", True, "input_fname") + from ...chpi import ( + compute_chpi_amplitudes, + compute_chpi_locs, + _fit_coil_order_dev_head_trans, ) + + input_fname = str(_check_fname(input_fname, "read", True, "input_fname")) fname, ext = op.splitext(input_fname) - if ext == '.txt': - input_fname = fname + '.bin' - elif ext != '.bin': - raise RuntimeError('Valid artemis123 files must end in "txt"' + - ' or ".bin".') + if ext == ".txt": + input_fname = fname + ".bin" + elif ext != ".bin": + raise RuntimeError( + 'Valid artemis123 files must end in "txt"' + ' or ".bin".' 
+ ) if not op.exists(input_fname): - raise RuntimeError('%s - Not Found' % input_fname) + raise RuntimeError("%s - Not Found" % input_fname) - info, header_info = _get_artemis123_info(input_fname, - pos_fname=pos_fname) + info, header_info = _get_artemis123_info(input_fname, pos_fname=pos_fname) - last_samps = [header_info.get('num_samples', 1) - 1] + last_samps = [header_info.get("num_samples", 1) - 1] super(RawArtemis123, self).__init__( - info, preload, filenames=[input_fname], raw_extras=[header_info], - last_samps=last_samps, orig_format="single", - verbose=verbose) + info, + preload, + filenames=[input_fname], + raw_extras=[header_info], + last_samps=last_samps, + orig_format="single", + verbose=verbose, + ) if add_head_trans: n_hpis = 0 - for d in info['hpi_subsystem']['hpi_coils']: - if d['event_bits'] == [256]: + for d in info["hpi_subsystem"]["hpi_coils"]: + if d["event_bits"] == [256]: n_hpis += 1 if n_hpis < 3: - warn('%d HPIs active. At least 3 needed to perform' % n_hpis + - 'head localization\n *NO* head localization performed') + warn( + "%d HPIs active. At least 3 needed to perform" % n_hpis + + "head localization\n *NO* head localization performed" + ) else: # Localized HPIs using the 1st 250 milliseconds of data. with info._unlock(): - info['hpi_results'] = [dict( - dig_points=[dict( - r=np.zeros(3), - coord_frame=FIFF.FIFFV_COORD_DEVICE, - ident=ii + 1) for ii in range(n_hpis)], - coord_trans=Transform('meg', 'head'))] + info["hpi_results"] = [ + dict( + dig_points=[ + dict( + r=np.zeros(3), + coord_frame=FIFF.FIFFV_COORD_DEVICE, + ident=ii + 1, + ) + for ii in range(n_hpis) + ], + coord_trans=Transform("meg", "head"), + ) + ] coil_amplitudes = compute_chpi_amplitudes( - self, tmin=0, tmax=0.25, t_window=0.25, t_step_min=0.25) - assert len(coil_amplitudes['times']) == 1 + self, tmin=0, tmax=0.25, t_window=0.25, t_step_min=0.25 + ) + assert len(coil_amplitudes["times"]) == 1 coil_locs = compute_chpi_locs(self.info, coil_amplitudes) with info._unlock(): - info['hpi_results'] = None - hpi_g = coil_locs['gofs'][0] - hpi_dev = coil_locs['rrs'][0] + info["hpi_results"] = None + hpi_g = coil_locs["gofs"][0] + hpi_dev = coil_locs["rrs"][0] # only use HPI coils with localizaton goodness_of_fit > 0.98 bad_idx = [] for i, g in enumerate(hpi_g): - msg = 'HPI coil %d - location goodness of fit (%0.3f)' + msg = "HPI coil %d - location goodness of fit (%0.3f)" if g < 0.98: bad_idx.append(i) - msg += ' *Removed from coregistration*' + msg += " *Removed from coregistration*" logger.info(msg % (i + 1, g)) hpi_dev = np.delete(hpi_dev, bad_idx, axis=0) hpi_g = np.delete(hpi_g, bad_idx, axis=0) if pos_fname is not None: # Digitized HPI points are needed. - hpi_head = np.array([d['r'] - for d in self.info.get('dig', []) - if d['kind'] == FIFF.FIFFV_POINT_HPI]) - - if (len(hpi_head) != len(hpi_dev)): - mesg = ("number of digitized (%d) and " + - "active (%d) HPI coils are " + - "not the same.") - raise RuntimeError(mesg % (len(hpi_head), - len(hpi_dev))) + hpi_head = np.array( + [ + d["r"] + for d in self.info.get("dig", []) + if d["kind"] == FIFF.FIFFV_POINT_HPI + ] + ) + + if len(hpi_head) != len(hpi_dev): + mesg = ( + "number of digitized (%d) and " + + "active (%d) HPI coils are " + + "not the same." 
+ ) + raise RuntimeError(mesg % (len(hpi_head), len(hpi_dev))) # compute initial head to dev transform and hpi ordering - head_to_dev_t, order, trans_g = \ - _fit_coil_order_dev_head_trans(hpi_dev, hpi_head) + head_to_dev_t, order, trans_g = _fit_coil_order_dev_head_trans( + hpi_dev, hpi_head + ) # set the device to head transform - self.info['dev_head_t'] = \ - Transform(FIFF.FIFFV_COORD_DEVICE, - FIFF.FIFFV_COORD_HEAD, head_to_dev_t) + self.info["dev_head_t"] = Transform( + FIFF.FIFFV_COORD_DEVICE, FIFF.FIFFV_COORD_HEAD, head_to_dev_t + ) # add hpi_meg_dev to dig... for idx, point in enumerate(hpi_dev): - d = {'r': point, 'ident': idx + 1, - 'kind': FIFF.FIFFV_POINT_HPI, - 'coord_frame': FIFF.FIFFV_COORD_DEVICE} - self.info['dig'].append(DigPoint(d)) + d = { + "r": point, + "ident": idx + 1, + "kind": FIFF.FIFFV_POINT_HPI, + "coord_frame": FIFF.FIFFV_COORD_DEVICE, + } + self.info["dig"].append(DigPoint(d)) dig_dists = cdist(hpi_head[order], hpi_head[order]) dev_dists = cdist(hpi_dev, hpi_dev) tmp_dists = np.abs(dig_dists - dev_dists) dist_limit = tmp_dists.max() * 1.1 - msg = 'HPI-Dig corrregsitration\n' - msg += '\tGOF : %0.3f\n' % trans_g - msg += '\tMax Coil Error : %0.3f cm\n' % (100 * - tmp_dists.max()) + msg = "HPI-Dig corrregsitration\n" + msg += "\tGOF : %0.3f\n" % trans_g + msg += "\tMax Coil Error : %0.3f cm\n" % (100 * tmp_dists.max()) logger.info(msg) else: - logger.info('Assuming Cardinal HPIs') + logger.info("Assuming Cardinal HPIs") nas = hpi_dev[0] lpa = hpi_dev[2] rpa = hpi_dev[1] t = get_ras_to_neuromag_trans(nas, lpa, rpa) with self.info._unlock(): - self.info['dev_head_t'] = \ - Transform(FIFF.FIFFV_COORD_DEVICE, - FIFF.FIFFV_COORD_HEAD, t) + self.info["dev_head_t"] = Transform( + FIFF.FIFFV_COORD_DEVICE, FIFF.FIFFV_COORD_HEAD, t + ) # transform fiducial points nas = apply_trans(t, nas) lpa = apply_trans(t, lpa) rpa = apply_trans(t, rpa) - hpi = apply_trans(self.info['dev_head_t'], hpi_dev) + hpi = apply_trans(self.info["dev_head_t"], hpi_dev) with self.info._unlock(): - self.info['dig'] = _make_dig_points(nasion=nas, - lpa=lpa, - rpa=rpa, - hpi=hpi) + self.info["dig"] = _make_dig_points( + nasion=nas, lpa=lpa, rpa=rpa, hpi=hpi + ) order = np.array([0, 1, 2]) dist_limit = 0.005 @@ -439,33 +496,39 @@ def __init__(self, input_fname, preload=False, verbose=None, # add HPI points in device coords... dig = [] for idx, point in enumerate(hpi_dev): - dig.append({'r': point, 'ident': idx + 1, - 'kind': FIFF.FIFFV_POINT_HPI, - 'coord_frame': FIFF.FIFFV_COORD_DEVICE}) - hpi_result['dig_points'] = dig + dig.append( + { + "r": point, + "ident": idx + 1, + "kind": FIFF.FIFFV_POINT_HPI, + "coord_frame": FIFF.FIFFV_COORD_DEVICE, + } + ) + hpi_result["dig_points"] = dig # attach Transform - hpi_result['coord_trans'] = self.info['dev_head_t'] + hpi_result["coord_trans"] = self.info["dev_head_t"] # 1 based indexing - hpi_result['order'] = order + 1 - hpi_result['used'] = np.arange(3) + 1 - hpi_result['dist_limit'] = dist_limit - hpi_result['good_limit'] = 0.98 + hpi_result["order"] = order + 1 + hpi_result["used"] = np.arange(3) + 1 + hpi_result["dist_limit"] = dist_limit + hpi_result["good_limit"] = 0.98 # Warn for large discrepancies between digitized and fit # cHPI locations - if hpi_result['dist_limit'] > 0.005: - warn('Large difference between digitized geometry' + - ' and HPI geometry. Max coil to coil difference' + - ' is %0.2f cm\n' % (100. 
* tmp_dists.max()) + - 'beware of *POOR* head localization') + if hpi_result["dist_limit"] > 0.005: + warn( + "Large difference between digitized geometry" + + " and HPI geometry. Max coil to coil difference" + + " is %0.2f cm\n" % (100.0 * tmp_dists.max()) + + "beware of *POOR* head localization" + ) # store it with self.info._unlock(): - self.info['hpi_results'] = [hpi_result] + self.info["hpi_results"] = [hpi_result] def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of raw data.""" - _read_segments_file( - self, data, idx, fi, start, stop, cals, mult, dtype='>f4') + _read_segments_file(self, data, idx, fi, start, stop, cals, mult, dtype=">f4") diff --git a/mne/io/artemis123/tests/test_artemis123.py b/mne/io/artemis123/tests/test_artemis123.py index 21b7204b775..35ba7dab668 100644 --- a/mne/io/artemis123/tests/test_artemis123.py +++ b/mne/io/artemis123/tests/test_artemis123.py @@ -16,19 +16,17 @@ artemis123_dir = testing.data_path(download=False) / "ARTEMIS123" short_HPI_dip_fname = ( - artemis123_dir - / "Artemis_Data_2017-04-04-15h-44m-22s_Motion_Translation-z.bin" + artemis123_dir / "Artemis_Data_2017-04-04-15h-44m-22s_Motion_Translation-z.bin" ) dig_fname = artemis123_dir / "Phantom_040417_dig.pos" short_hpi_1kz_fname = ( - artemis123_dir - / "Artemis_Data_2017-04-14-10h-38m-59s_Phantom_1k_HPI_1s.bin" + artemis123_dir / "Artemis_Data_2017-04-14-10h-38m-59s_Phantom_1k_HPI_1s.bin" ) # XXX this tol is way too high, but it's not clear which is correct # (old or new) -def _assert_trans(actual, desired, dist_tol=0.017, angle_tol=5.): +def _assert_trans(actual, desired, dist_tol=0.017, angle_tol=5.0): __tracebackhide__ = True trans_est = actual[0:3, 3] quat_est = rot_to_quat(actual[0:3, 0:3]) @@ -37,18 +35,23 @@ def _assert_trans(actual, desired, dist_tol=0.017, angle_tol=5.): angle = np.rad2deg(_angle_between_quats(quat_est, quat)) dist = np.linalg.norm(trans - trans_est) - assert dist <= dist_tol, \ - '%0.3f > %0.3f mm translation' % (1000 * dist, 1000 * dist_tol) - assert angle <= angle_tol, \ - '%0.3f > %0.3f° rotation' % (angle, angle_tol) + assert dist <= dist_tol, "%0.3f > %0.3f mm translation" % ( + 1000 * dist, + 1000 * dist_tol, + ) + assert angle <= angle_tol, "%0.3f > %0.3f° rotation" % (angle, angle_tol) @pytest.mark.timeout(60) # ~25 s on Travis Linux OpenBLAS @testing.requires_testing_data def test_artemis_reader(): """Test reading raw Artemis123 files.""" - _test_raw_reader(read_raw_artemis123, input_fname=short_hpi_1kz_fname, - pos_fname=dig_fname, verbose='error') + _test_raw_reader( + read_raw_artemis123, + input_fname=short_hpi_1kz_fname, + pos_fname=dig_fname, + verbose="error", + ) @pytest.mark.timeout(60) @@ -56,47 +59,60 @@ def test_artemis_reader(): def test_dev_head_t(): """Test dev_head_t computation for Artemis123.""" # test a random selected point - raw = read_raw_artemis123(short_hpi_1kz_fname, preload=True, - add_head_trans=False) + raw = read_raw_artemis123(short_hpi_1kz_fname, preload=True, add_head_trans=False) meg_picks = pick_types(raw.info, meg=True, eeg=False) # checked against matlab reader. 
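    # i.e. a single raw sample (channel meg_picks[12], sample index 123) is
    # pinned to a reference value from that MATLAB reader.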
assert_allclose(raw[meg_picks[12]][0][0][123], 1.08239606023e-11) - dev_head_t_1 = np.array([[9.713e-01, 2.340e-01, -4.164e-02, 1.302e-04], - [-2.371e-01, 9.664e-01, -9.890e-02, 1.977e-03], - [1.710e-02, 1.059e-01, 9.942e-01, -8.159e-03], - [0.0, 0.0, 0.0, 1.0]]) - - dev_head_t_2 = np.array([[9.890e-01, 1.475e-01, -8.090e-03, 4.997e-04], - [-1.476e-01, 9.846e-01, -9.389e-02, 1.962e-03], - [-5.888e-03, 9.406e-02, 9.955e-01, -1.610e-02], - [0.0, 0.0, 0.0, 1.0]]) - - expected_dev_hpi_rr = np.array([[-0.01579644, 0.06527367, 0.00152648], - [0.06666813, 0.0148956, 0.00545488], - [-0.06699212, -0.01732376, 0.0112027]]) + dev_head_t_1 = np.array( + [ + [9.713e-01, 2.340e-01, -4.164e-02, 1.302e-04], + [-2.371e-01, 9.664e-01, -9.890e-02, 1.977e-03], + [1.710e-02, 1.059e-01, 9.942e-01, -8.159e-03], + [0.0, 0.0, 0.0, 1.0], + ] + ) + + dev_head_t_2 = np.array( + [ + [9.890e-01, 1.475e-01, -8.090e-03, 4.997e-04], + [-1.476e-01, 9.846e-01, -9.389e-02, 1.962e-03], + [-5.888e-03, 9.406e-02, 9.955e-01, -1.610e-02], + [0.0, 0.0, 0.0, 1.0], + ] + ) + + expected_dev_hpi_rr = np.array( + [ + [-0.01579644, 0.06527367, 0.00152648], + [0.06666813, 0.0148956, 0.00545488], + [-0.06699212, -0.01732376, 0.0112027], + ] + ) # test with head loc no digitization raw = read_raw_artemis123(short_HPI_dip_fname, add_head_trans=True) - _assert_trans(raw.info['dev_head_t']['trans'], dev_head_t_1) - assert_equal(raw.info['sfreq'], 5000.0) + _assert_trans(raw.info["dev_head_t"]["trans"], dev_head_t_1) + assert_equal(raw.info["sfreq"], 5000.0) # test with head loc and digitization - with pytest.warns(RuntimeWarning, match='Large difference'): - raw = read_raw_artemis123(short_HPI_dip_fname, add_head_trans=True, - pos_fname=dig_fname) - _assert_trans(raw.info['dev_head_t']['trans'], dev_head_t_1) + with pytest.warns(RuntimeWarning, match="Large difference"): + raw = read_raw_artemis123( + short_HPI_dip_fname, add_head_trans=True, pos_fname=dig_fname + ) + _assert_trans(raw.info["dev_head_t"]["trans"], dev_head_t_1) # test cHPI localization.. 
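    # Gather the fitted coil positions stored in info['dig'] in device
    # coordinates and compare them with expected_dev_hpi_rr defined above.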
- dev_hpi_rr = np.array([p['r'] for p in raw.info['dig'] - if p['coord_frame'] == FIFF.FIFFV_COORD_DEVICE]) + dev_hpi_rr = np.array( + [p["r"] for p in raw.info["dig"] if p["coord_frame"] == FIFF.FIFFV_COORD_DEVICE] + ) # points should be within 0.1 mm (1e-4m) and within 1% assert_allclose(dev_hpi_rr, expected_dev_hpi_rr, atol=1e-4, rtol=0.01) # test 1kz hpi head loc (different freq) raw = read_raw_artemis123(short_hpi_1kz_fname, add_head_trans=True) - _assert_trans(raw.info['dev_head_t']['trans'], dev_head_t_2) - assert_equal(raw.info['sfreq'], 1000.0) + _assert_trans(raw.info["dev_head_t"]["trans"], dev_head_t_2) + assert_equal(raw.info["sfreq"], 1000.0) def test_utils(tmp_path): diff --git a/mne/io/artemis123/utils.py b/mne/io/artemis123/utils.py index a2448b5fdfc..cbc4c4b7e31 100644 --- a/mne/io/artemis123/utils.py +++ b/mne/io/artemis123/utils.py @@ -7,19 +7,19 @@ def _load_mne_locs(fname=None): """Load MNE locs structure from file (if exists) or recreate it.""" - if (not fname): + if not fname: # find input file - resource_dir = op.join(op.dirname(op.abspath(__file__)), 'resources') - fname = op.join(resource_dir, 'Artemis123_mneLoc.csv') + resource_dir = op.join(op.dirname(op.abspath(__file__)), "resources") + fname = op.join(resource_dir, "Artemis123_mneLoc.csv") if not op.exists(fname): raise OSError('MNE locs file "%s" does not exist' % (fname)) - logger.info('Loading mne loc file {}'.format(fname)) + logger.info("Loading mne loc file {}".format(fname)) locs = dict() - with open(fname, 'r') as fid: + with open(fname, "r") as fid: for line in fid: - vals = line.strip().split(',') + vals = line.strip().split(",") locs[vals[0]] = np.array(vals[1::], np.float64) return locs @@ -27,41 +27,39 @@ def _load_mne_locs(fname=None): def _generate_mne_locs_file(output_fname): """Generate mne coil locs and save to supplied file.""" - logger.info('Converting Tristan coil file to mne loc file...') - resource_dir = op.join(op.dirname(op.abspath(__file__)), 'resources') - chan_fname = op.join(resource_dir, 'Artemis123_ChannelMap.csv') + logger.info("Converting Tristan coil file to mne loc file...") + resource_dir = op.join(op.dirname(op.abspath(__file__)), "resources") + chan_fname = op.join(resource_dir, "Artemis123_ChannelMap.csv") chans = _load_tristan_coil_locs(chan_fname) # compute a dict of loc structs locs = {n: _compute_mne_loc(cinfo) for n, cinfo in chans.items()} # write it out to output_fname - with open(output_fname, 'w') as fid: + with open(output_fname, "w") as fid: for n in sorted(locs.keys()): - fid.write('%s,' % n) - fid.write(','.join(locs[n].astype(str))) - fid.write('\n') + fid.write("%s," % n) + fid.write(",".join(locs[n].astype(str))) + fid.write("\n") def _load_tristan_coil_locs(coil_loc_path): """Load the Coil locations from Tristan CAD drawings.""" channel_info = dict() - with open(coil_loc_path, 'r') as fid: + with open(coil_loc_path, "r") as fid: # skip 2 Header lines fid.readline() fid.readline() for line in fid: line = line.strip() - vals = line.split(',') + vals = line.split(",") channel_info[vals[0]] = dict() if vals[6]: - channel_info[vals[0]]['inner_coil'] = \ - np.array(vals[2:5], np.float64) - channel_info[vals[0]]['outer_coil'] = \ - np.array(vals[5:8], np.float64) + channel_info[vals[0]]["inner_coil"] = np.array(vals[2:5], np.float64) + channel_info[vals[0]]["outer_coil"] = np.array(vals[5:8], np.float64) else: # nothing supplied - channel_info[vals[0]]['inner_coil'] = np.zeros(3) - channel_info[vals[0]]['outer_coil'] = np.zeros(3) + 
channel_info[vals[0]]["inner_coil"] = np.zeros(3) + channel_info[vals[0]]["outer_coil"] = np.zeros(3) return channel_info @@ -71,15 +69,16 @@ def _compute_mne_loc(coil_loc): Note input coil locations are in inches. """ loc = np.zeros((12)) - if (np.linalg.norm(coil_loc['inner_coil']) == 0) and \ - (np.linalg.norm(coil_loc['outer_coil']) == 0): + if (np.linalg.norm(coil_loc["inner_coil"]) == 0) and ( + np.linalg.norm(coil_loc["outer_coil"]) == 0 + ): return loc # channel location is inner coil location converted to meters From inches - loc[0:3] = coil_loc['inner_coil'] / 39.370078 + loc[0:3] = coil_loc["inner_coil"] / 39.370078 # figure out rotation - z_axis = coil_loc['outer_coil'] - coil_loc['inner_coil'] + z_axis = coil_loc["outer_coil"] - coil_loc["inner_coil"] R = rotation3d_align_z_axis(z_axis) loc[3:13] = R.T.reshape(9) return loc @@ -88,7 +87,7 @@ def _compute_mne_loc(coil_loc): def _read_pos(fname): """Read the .pos file and return positions as dig points.""" nas, lpa, rpa, hpi, extra = None, None, None, None, None - with open(fname, 'r') as fid: + with open(fname, "r") as fid: for line in fid: line = line.strip() if len(line) > 0: @@ -100,20 +99,19 @@ def _read_pos(fname): if len(parts) not in [4, 5]: continue - if parts[0].lower() == 'nasion': - nas = np.array([float(p) for p in parts[-3:]]) / 100. - elif parts[0].lower() == 'left': - lpa = np.array([float(p) for p in parts[-3:]]) / 100. - elif parts[0].lower() == 'right': - rpa = np.array([float(p) for p in parts[-3:]]) / 100. - elif 'hpi' in parts[0].lower(): + if parts[0].lower() == "nasion": + nas = np.array([float(p) for p in parts[-3:]]) / 100.0 + elif parts[0].lower() == "left": + lpa = np.array([float(p) for p in parts[-3:]]) / 100.0 + elif parts[0].lower() == "right": + rpa = np.array([float(p) for p in parts[-3:]]) / 100.0 + elif "hpi" in parts[0].lower(): if hpi is None: hpi = list() - hpi.append(np.array([float(p) for p in parts[-3:]]) / 100.) + hpi.append(np.array([float(p) for p in parts[-3:]]) / 100.0) else: if extra is None: extra = list() - extra.append(np.array([float(p) - for p in parts[-3:]]) / 100.) 
+ extra.append(np.array([float(p) for p in parts[-3:]]) / 100.0) return _artemis123_read_pos(nas, lpa, rpa, hpi, extra) diff --git a/mne/io/base.py b/mne/io/base.py index e85ebe005d6..f44dce8daa5 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -21,33 +21,78 @@ from .constants import FIFF from .utils import _construct_bids_filename, _check_orig_units -from .pick import (pick_types, pick_channels, pick_info, _picks_to_idx, - channel_type) +from .pick import pick_types, pick_channels, pick_info, _picks_to_idx, channel_type from .meas_info import write_meas_info, _ensure_infos_match, ContainsMixin from .proj import setup_proj, activate_proj, _proj_equal, ProjMixin -from ..channels.channels import (UpdateChannelsMixin, SetChannelsMixin, - InterpolationMixin, _unit2human) +from ..channels.channels import ( + UpdateChannelsMixin, + SetChannelsMixin, + InterpolationMixin, + _unit2human, +) from .compensator import set_current_comp, make_compensator -from .write import (start_and_end_file, start_block, end_block, - write_dau_pack16, write_float, write_double, - write_complex64, write_complex128, write_int, - write_id, write_string, _get_split_size, _NEXT_FILE_BUFFER) - -from ..annotations import (Annotations, _annotations_starts_stops, - _combine_annotations, _handle_meas_date, - _sync_onset, _write_annotations) -from ..filter import (FilterMixin, notch_filter, resample, _resamp_ratio_len, - _resample_stim_channels, _check_fun) +from .write import ( + start_and_end_file, + start_block, + end_block, + write_dau_pack16, + write_float, + write_double, + write_complex64, + write_complex128, + write_int, + write_id, + write_string, + _get_split_size, + _NEXT_FILE_BUFFER, +) + +from ..annotations import ( + Annotations, + _annotations_starts_stops, + _combine_annotations, + _handle_meas_date, + _sync_onset, + _write_annotations, +) +from ..filter import ( + FilterMixin, + notch_filter, + resample, + _resamp_ratio_len, + _resample_stim_channels, + _check_fun, +) from ..parallel import parallel_func -from ..utils import (_check_fname, _check_pandas_installed, sizeof_fmt, - _check_pandas_index_arguments, fill_doc, copy_doc, - check_fname, _get_stim_channel, _stamp_to_dt, - logger, verbose, _time_mask, warn, SizeMixin, - copy_function_doc_to_method_doc, _validate_type, - _check_preload, _get_argvalues, _check_option, - _build_data_frame, _convert_times, _scale_dataframe_data, - _check_time_format, _arange_div, TimeMixin, repr_html, - _pl) +from ..utils import ( + _check_fname, + _check_pandas_installed, + sizeof_fmt, + _check_pandas_index_arguments, + fill_doc, + copy_doc, + check_fname, + _get_stim_channel, + _stamp_to_dt, + logger, + verbose, + _time_mask, + warn, + SizeMixin, + copy_function_doc_to_method_doc, + _validate_type, + _check_preload, + _get_argvalues, + _check_option, + _build_data_frame, + _convert_times, + _scale_dataframe_data, + _check_time_format, + _arange_div, + TimeMixin, + repr_html, + _pl, +) from ..defaults import _handle_default from ..viz import plot_raw, _RAW_CLIP_DEF from ..event import find_events, concatenate_events @@ -55,9 +100,17 @@ @fill_doc -class BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin, SetChannelsMixin, - InterpolationMixin, TimeMixin, SizeMixin, FilterMixin, - SpectrumMixin): +class BaseRaw( + ProjMixin, + ContainsMixin, + UpdateChannelsMixin, + SetChannelsMixin, + InterpolationMixin, + TimeMixin, + SizeMixin, + FilterMixin, + SpectrumMixin, +): """Base class for Raw data. 
Parameters @@ -118,20 +171,30 @@ class BaseRaw(ProjMixin, ContainsMixin, UpdateChannelsMixin, SetChannelsMixin, """ @verbose - def __init__(self, info, preload=False, - first_samps=(0,), last_samps=None, - filenames=(None,), raw_extras=(None,), - orig_format='double', dtype=np.float64, - buffer_size_sec=1., orig_units=None, - *, verbose=None): # noqa: D102 + def __init__( + self, + info, + preload=False, + first_samps=(0,), + last_samps=None, + filenames=(None,), + raw_extras=(None,), + orig_format="double", + dtype=np.float64, + buffer_size_sec=1.0, + orig_units=None, + *, + verbose=None, + ): # noqa: D102 # wait until the end to preload data, but triage here if isinstance(preload, np.ndarray): # some functions (e.g., filtering) only work w/64-bit data if preload.dtype not in (np.float64, np.complex128): - raise RuntimeError('datatype must be float64 or complex128, ' - 'not %s' % preload.dtype) + raise RuntimeError( + "datatype must be float64 or complex128, " "not %s" % preload.dtype + ) if preload.dtype != dtype: - raise ValueError('preload and dtype must match') + raise ValueError("preload and dtype must match") self._data = preload self.preload = True assert len(first_samps) == 1 @@ -139,8 +202,9 @@ def __init__(self, info, preload=False, load_from_disk = False else: if last_samps is None: - raise ValueError('last_samps must be given unless preload is ' - 'an ndarray') + raise ValueError( + "last_samps must be given unless preload is " "an ndarray" + ) if not preload: self.preload = False load_from_disk = False @@ -148,51 +212,50 @@ def __init__(self, info, preload=False, load_from_disk = True self._last_samps = np.array(last_samps) self._first_samps = np.array(first_samps) - orig_ch_names = info['ch_names'] + orig_ch_names = info["ch_names"] with info._unlock(check_after=True): # be permissive of old code - if isinstance(info['meas_date'], tuple): - info['meas_date'] = _stamp_to_dt(info['meas_date']) + if isinstance(info["meas_date"], tuple): + info["meas_date"] = _stamp_to_dt(info["meas_date"]) self.info = info self.buffer_size_sec = float(buffer_size_sec) - cals = np.empty(info['nchan']) - for k in range(info['nchan']): - cals[k] = info['chs'][k]['range'] * info['chs'][k]['cal'] + cals = np.empty(info["nchan"]) + for k in range(info["nchan"]): + cals[k] = info["chs"][k]["range"] * info["chs"][k]["cal"] bad = np.where(cals == 0)[0] if len(bad) > 0: - raise ValueError('Bad cals for channels %s' - % {ii: self.ch_names[ii] for ii in bad}) + raise ValueError( + "Bad cals for channels %s" % {ii: self.ch_names[ii] for ii in bad} + ) self._cals = cals self._raw_extras = list(dict() if r is None else r for r in raw_extras) for r in self._raw_extras: - r['orig_nchan'] = info['nchan'] - self._read_picks = [np.arange(info['nchan']) - for _ in range(len(raw_extras))] + r["orig_nchan"] = info["nchan"] + self._read_picks = [np.arange(info["nchan"]) for _ in range(len(raw_extras))] # deal with compensation (only relevant for CTF data, either CTF # reader or MNE-C converted CTF->FIF files) self._read_comp_grade = self.compensation_grade # read property - if self._read_comp_grade is not None and len(info['comps']): - logger.info('Current compensation grade : %d' - % self._read_comp_grade) + if self._read_comp_grade is not None and len(info["comps"]): + logger.info("Current compensation grade : %d" % self._read_comp_grade) self._comp = None self._filenames = list(filenames) _validate_type(orig_format, str, "orig_format") - _check_option( - "orig_format", orig_format, ("double", "single", "int", 
"short") - ) + _check_option("orig_format", orig_format, ("double", "single", "int", "short")) self.orig_format = orig_format # Sanity check and set original units, if provided by the reader: if orig_units: if not isinstance(orig_units, dict): - raise ValueError('orig_units must be of type dict, but got ' - ' {}'.format(type(orig_units))) + raise ValueError( + "orig_units must be of type dict, but got " + " {}".format(type(orig_units)) + ) # original units need to be truncated to 15 chars or renamed # to match MNE conventions (channel name unique and less than # 15 characters). orig_units = deepcopy(orig_units) - for old_ch, new_ch in zip(orig_ch_names, info['ch_names']): + for old_ch, new_ch in zip(orig_ch_names, info["ch_names"]): if old_ch in orig_units: this_unit = orig_units[old_ch] del orig_units[old_ch] @@ -200,18 +263,19 @@ def __init__(self, info, preload=False, # STI 014 channel is native only to fif ... for all other formats # this was artificially added by the IO procedure, so remove it - ch_names = list(info['ch_names']) - if ('STI 014' in ch_names) and not \ - (self.filenames[0].endswith('.fif')): - ch_names.remove('STI 014') + ch_names = list(info["ch_names"]) + if ("STI 014" in ch_names) and not (self.filenames[0].endswith(".fif")): + ch_names.remove("STI 014") # Each channel in the data must have a corresponding channel in # the original units. ch_correspond = [ch in orig_units for ch in ch_names] if not all(ch_correspond): ch_without_orig_unit = ch_names[ch_correspond.index(False)] - raise ValueError('Channel {} has no associated original ' - 'unit.'.format(ch_without_orig_unit)) + raise ValueError( + "Channel {} has no associated original " + "unit.".format(ch_without_orig_unit) + ) # Final check of orig_units, editing a unit if it is not a valid # unit @@ -254,22 +318,25 @@ def apply_gradient_compensation(self, grade, verbose=None): current_comp = self.compensation_grade if current_comp != grade: if self.proj: - raise RuntimeError('Cannot change compensation on data where ' - 'projectors have been applied') + raise RuntimeError( + "Cannot change compensation on data where " + "projectors have been applied" + ) # Figure out what operator to use (varies depending on preload) from_comp = current_comp if self.preload else self._read_comp_grade comp = make_compensator(self.info, from_comp, grade) - logger.info('Compensator constructed to change %d -> %d' - % (current_comp, grade)) + logger.info( + "Compensator constructed to change %d -> %d" % (current_comp, grade) + ) set_current_comp(self.info, grade) # We might need to apply it to our data now if self.preload: - logger.info('Applying compensator to loaded data') - lims = np.concatenate([np.arange(0, len(self.times), 10000), - [len(self.times)]]) + logger.info("Applying compensator to loaded data") + lims = np.concatenate( + [np.arange(0, len(self.times), 10000), [len(self.times)]] + ) for start, stop in zip(lims[:-1], lims[1:]): - self._data[:, start:stop] = np.dot( - comp, self._data[:, start:stop]) + self._data[:, start:stop] = np.dot(comp, self._data[:, start:stop]) else: self._comp = comp # store it for later use return self @@ -281,8 +348,9 @@ def _dtype(self): return self._dtype_ @verbose - def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None, *, - verbose=None): + def _read_segment( + self, start=0, stop=None, sel=None, data_buffer=None, *, verbose=None + ): """Read a chunk of raw data. 
Parameters @@ -313,34 +381,35 @@ def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None, *, stop = self.n_times if stop is None else min([int(stop), self.n_times]) if start >= stop: - raise ValueError('No data in this range') + raise ValueError("No data in this range") # Initialize the data and calibration vector if sel is None: - n_out = self.info['nchan'] + n_out = self.info["nchan"] idx = slice(None) else: n_out = len(sel) idx = _convert_slice(sel) del sel - assert n_out <= self.info['nchan'] + assert n_out <= self.info["nchan"] data_shape = (n_out, stop - start) dtype = self._dtype if isinstance(data_buffer, np.ndarray): if data_buffer.shape != data_shape: - raise ValueError('data_buffer has incorrect shape: %s != %s' - % (data_buffer.shape, data_shape)) + raise ValueError( + "data_buffer has incorrect shape: %s != %s" + % (data_buffer.shape, data_shape) + ) data = data_buffer else: data = _allocate_data(data_buffer, data_shape, dtype) # deal with having multiple files accessed by the raw object - cumul_lens = np.concatenate(([0], np.array(self._raw_lengths, - dtype='int'))) + cumul_lens = np.concatenate(([0], np.array(self._raw_lengths, dtype="int"))) cumul_lens = np.cumsum(cumul_lens) - files_used = np.logical_and(np.less(start, cumul_lens[1:]), - np.greater_equal(stop - 1, - cumul_lens[:-1])) + files_used = np.logical_and( + np.less(start, cumul_lens[1:]), np.greater_equal(stop - 1, cumul_lens[:-1]) + ) # set up cals and mult (cals, compensation, and projector) n_out = len(np.arange(len(self.ch_names))[idx]) @@ -366,8 +435,9 @@ def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None, *, need_idx = np.where(np.any(mult, axis=0))[0] mult = mult[:, need_idx] logger.debug( - f'Reading {len(need_idx)}/{len(self.ch_names)} channels ' - f'due to projection') + f"Reading {len(need_idx)}/{len(self.ch_names)} channels " + f"due to projection" + ) assert (mult is None) ^ (cals is None) # xor # read from necessary files @@ -377,17 +447,27 @@ def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None, *, # first iteration (only) could start in the middle somewhere if offset == 0: start_file += start - cumul_lens[fi] - stop_file = np.min([stop - cumul_lens[fi] + self._first_samps[fi], - self._last_samps[fi] + 1]) + stop_file = np.min( + [ + stop - cumul_lens[fi] + self._first_samps[fi], + self._last_samps[fi] + 1, + ] + ) if start_file < self._first_samps[fi] or stop_file < start_file: - raise ValueError('Bad array indexing, could be a bug') + raise ValueError("Bad array indexing, could be a bug") n_read = stop_file - start_file this_sl = slice(offset, offset + n_read) # reindex back to original file orig_idx = _convert_slice(self._read_picks[fi][need_idx]) _ReadSegmentFileProtector(self)._read_segment_file( - data[:, this_sl], orig_idx, fi, - int(start_file), int(stop_file), cals, mult) + data[:, this_sl], + orig_idx, + fi, + int(start_file), + int(stop_file), + cals, + mult, + ) offset += n_read return data @@ -424,9 +504,9 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """ raise NotImplementedError - def _check_bad_segment(self, start, stop, picks, - reject_start, reject_stop, - reject_by_annotation=False): + def _check_bad_segment( + self, start, stop, picks, reject_start, reject_stop, reject_by_annotation=False + ): """Check if data segment is bad. If the slice is good, returns the data in desired range. 
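The multi-file bookkeeping reflowed in the _read_segment hunk above is easiest to check in isolation. Below is a minimal numpy sketch (not part of the patch) of the same cumulative-length test that decides which on-disk files a [start, stop) sample range touches; the per-file lengths are hypothetical, everything else mirrors the diff:

    import numpy as np

    raw_lengths = [1000, 2000, 1500]  # hypothetical samples per file
    cumul_lens = np.cumsum(np.concatenate(([0], np.array(raw_lengths, dtype="int"))))
    start, stop = 2500, 4000  # requested range, in concatenated samples
    # A file is needed iff the range starts before the file's end and
    # stops after the file's start -- exactly the logical_and above.
    files_used = np.logical_and(
        np.less(start, cumul_lens[1:]), np.greater_equal(stop - 1, cumul_lens[:-1])
    )
    print(np.where(files_used)[0])  # -> [1 2]: the read spans files 1 and 2
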
@@ -459,13 +539,14 @@ def _check_bad_segment(self, start, stop, picks, return None if reject_by_annotation and len(self.annotations) > 0: annot = self.annotations - sfreq = self.info['sfreq'] + sfreq = self.info["sfreq"] onset = _sync_onset(self, annot.onset) overlaps = np.where(onset < reject_stop / sfreq) - overlaps = np.where(onset[overlaps] + annot.duration[overlaps] > - reject_start / sfreq) + overlaps = np.where( + onset[overlaps] + annot.duration[overlaps] > reject_start / sfreq + ) for descr in annot.description[overlaps]: - if descr.lower().startswith('bad'): + if descr.lower().startswith("bad"): return descr return self._getitem((picks, slice(start, stop)), return_times=False) @@ -498,17 +579,19 @@ def _preload_data(self, preload): data_buffer = preload if isinstance(preload, (bool, np.bool_)) and not preload: data_buffer = None - logger.info('Reading %d ... %d = %9.3f ... %9.3f secs...' % - (0, len(self.times) - 1, 0., self.times[-1])) + logger.info( + "Reading %d ... %d = %9.3f ... %9.3f secs..." + % (0, len(self.times) - 1, 0.0, self.times[-1]) + ) self._data = self._read_segment(data_buffer=data_buffer) - assert len(self._data) == self.info['nchan'] + assert len(self._data) == self.info["nchan"] self.preload = True self._comp = None # no longer needed self.close() @property def _first_time(self): - return self.first_samp / float(self.info['sfreq']) + return self.first_samp / float(self.info["sfreq"]) @property def first_samp(self): @@ -530,7 +613,7 @@ def last_samp(self): @property def _last_time(self): - return self.last_samp / float(self.info['sfreq']) + return self.last_samp / float(self.info["sfreq"]) def time_as_index(self, times, use_rounding=False, origin=None): """Convert time to indices. @@ -557,12 +640,15 @@ def time_as_index(self, times, use_rounding=False, origin=None): origin = _handle_meas_date(origin) if origin is None: delta = 0 - elif self.info['meas_date'] is None: - raise ValueError('origin must be None when info["meas_date"] ' - 'is None, got %s' % (origin,)) + elif self.info["meas_date"] is None: + raise ValueError( + 'origin must be None when info["meas_date"] ' + "is None, got %s" % (origin,) + ) else: - first_samp_in_abs_time = (self.info['meas_date'] + - timedelta(0, self._first_time)) + first_samp_in_abs_time = self.info["meas_date"] + timedelta( + 0, self._first_time + ) delta = (origin - first_samp_in_abs_time).total_seconds() times = np.atleast_1d(times) + delta @@ -571,8 +657,7 @@ def time_as_index(self, times, use_rounding=False, origin=None): @property def _raw_lengths(self): return [ - last - first + 1 - for first, last in zip(self._first_samps, self._last_samps) + last - first + 1 for first, last in zip(self._first_samps, self._last_samps) ] @property @@ -586,8 +671,9 @@ def filenames(self): return tuple(self._filenames) @verbose - def set_annotations(self, annotations, emit_warning=True, - on_missing='raise', *, verbose=None): + def set_annotations( + self, annotations, emit_warning=True, on_missing="raise", *, verbose=None + ): """Setter for annotations. This setter checks if they are inside the data range. @@ -607,37 +693,40 @@ def set_annotations(self, annotations, emit_warning=True, self : instance of Raw The raw object with annotations. 
""" - meas_date = _handle_meas_date(self.info['meas_date']) + meas_date = _handle_meas_date(self.info["meas_date"]) if annotations is None: self._annotations = Annotations([], [], [], meas_date) else: - _validate_type(annotations, Annotations, 'annotations') + _validate_type(annotations, Annotations, "annotations") if meas_date is None and annotations.orig_time is not None: - raise RuntimeError('Ambiguous operation. Setting an Annotation' - ' object with known ``orig_time`` to a raw' - ' object which has ``meas_date`` set to' - ' None is ambiguous. Please, either set a' - ' meaningful ``meas_date`` to the raw' - ' object; or set ``orig_time`` to None in' - ' which case the annotation onsets would be' - ' taken in reference to the first sample of' - ' the raw object.') - - delta = 1. / self.info['sfreq'] + raise RuntimeError( + "Ambiguous operation. Setting an Annotation" + " object with known ``orig_time`` to a raw" + " object which has ``meas_date`` set to" + " None is ambiguous. Please, either set a" + " meaningful ``meas_date`` to the raw" + " object; or set ``orig_time`` to None in" + " which case the annotation onsets would be" + " taken in reference to the first sample of" + " the raw object." + ) + + delta = 1.0 / self.info["sfreq"] new_annotations = annotations.copy() new_annotations._prune_ch_names(self.info, on_missing) if annotations.orig_time is None: - new_annotations.crop(0, self.times[-1] + delta, - emit_warning=emit_warning) + new_annotations.crop( + 0, self.times[-1] + delta, emit_warning=emit_warning + ) new_annotations.onset += self._first_time else: tmin = meas_date + timedelta(0, self._first_time) tmax = tmin + timedelta(seconds=self.times[-1] + delta) - new_annotations.crop(tmin=tmin, tmax=tmax, - emit_warning=emit_warning) + new_annotations.crop(tmin=tmin, tmax=tmax, emit_warning=emit_warning) new_annotations.onset -= ( - meas_date - new_annotations.orig_time).total_seconds() + meas_date - new_annotations.orig_time + ).total_seconds() new_annotations._orig_time = meas_date self._annotations = new_annotations @@ -646,8 +735,7 @@ def set_annotations(self, annotations, emit_warning=True, def __del__(self): # noqa: D105 # remove file for memmap - if hasattr(self, '_data') and \ - getattr(self._data, 'filename', None) is not None: + if hasattr(self, "_data") and getattr(self._data, "filename", None) is not None: # First, close the file out; happens automatically on del filename = self._data.filename del self._data @@ -675,29 +763,29 @@ def _parse_get_set_params(self, item): item = (item, slice(None, None, None)) if len(item) != 2: # should be channels and time instants - raise RuntimeError("Unable to access raw data (need both channels " - "and time)") + raise RuntimeError( + "Unable to access raw data (need both channels " "and time)" + ) sel = _picks_to_idx(self.info, item[0]) if isinstance(item[1], slice): time_slice = item[1] - start, stop, step = (time_slice.start, time_slice.stop, - time_slice.step) + start, stop, step = (time_slice.start, time_slice.stop, time_slice.step) else: item1 = item[1] # Let's do automated type conversion to integer here - if np.array(item[1]).dtype.kind == 'i': + if np.array(item[1]).dtype.kind == "i": item1 = int(item1) if isinstance(item1, (int, np.integer)): start, stop, step = item1, item1 + 1, 1 else: - raise ValueError('Must pass int or slice to __getitem__') + raise ValueError("Must pass int or slice to __getitem__") if start is None: start = 0 if step is not None and step != 1: - raise ValueError('step needs to be 1 : %d given' % 
step) + raise ValueError("step needs to be 1 : %d given" % step) if isinstance(sel, (int, np.integer)): sel = np.array([sel]) @@ -757,22 +845,32 @@ def _getitem(self, item, return_times=True): # times = self.times[start:stop] # stop can be None here so don't use it directly times = np.arange(start, start + data.shape[1], dtype=float) - times /= self.info['sfreq'] + times /= self.info["sfreq"] return data, times else: return data def __setitem__(self, item, value): """Set raw data content.""" - _check_preload(self, 'Modifying data of Raw') + _check_preload(self, "Modifying data of Raw") sel, start, stop = self._parse_get_set_params(item) # set the data self._data[sel, start:stop] = value @verbose - def get_data(self, picks=None, start=0, stop=None, - reject_by_annotation=None, return_times=False, units=None, - *, tmin=None, tmax=None, verbose=None): + def get_data( + self, + picks=None, + start=0, + stop=None, + reject_by_annotation=None, + return_times=False, + units=None, + *, + tmin=None, + tmax=None, + verbose=None, + ): """Get data in the given range. Parameters @@ -815,12 +913,12 @@ def get_data(self, picks=None, start=0, stop=None, .. versionadded:: 0.14.0 """ # validate types - _validate_type(start, types=('int-like'), item_name='start', - type_name='int') - _validate_type(stop, types=('int-like', None), item_name='stop', - type_name='int, None') + _validate_type(start, types=("int-like"), item_name="start", type_name="int") + _validate_type( + stop, types=("int-like", None), item_name="stop", type_name="int, None" + ) - picks = _picks_to_idx(self.info, picks, 'all', exclude=()) + picks = _picks_to_idx(self.info, picks, "all", exclude=()) # Get channel factors for conversion into specified unit # (vector of ones if no conversion needed) @@ -828,7 +926,7 @@ def get_data(self, picks=None, start=0, stop=None, ch_factors = _get_ch_factors(self, units, picks) # convert to ints - picks = np.atleast_1d(np.arange(self.info['nchan'])[picks]) + picks = np.atleast_1d(np.arange(self.info["nchan"])[picks]) # handle start/tmin stop/tmax tmin_start, tmax_stop = self._handle_tmin_tmax(tmin, tmax) @@ -844,7 +942,8 @@ def get_data(self, picks=None, start=0, stop=None, if len(self.annotations) == 0 or reject_by_annotation is None: getitem = self._getitem( - (picks, slice(start, stop)), return_times=return_times) + (picks, slice(start, stop)), return_times=return_times + ) if return_times: data, times = getitem if units is not None: @@ -853,9 +952,10 @@ def get_data(self, picks=None, start=0, stop=None, if units is not None: getitem *= ch_factors[:, np.newaxis] return getitem - _check_option('reject_by_annotation', reject_by_annotation.lower(), - ['omit', 'nan']) - onsets, ends = _annotations_starts_stops(self, ['BAD']) + _check_option( + "reject_by_annotation", reject_by_annotation.lower(), ["omit", "nan"] + ) + onsets, ends = _annotations_starts_stops(self, ["BAD"]) keep = (onsets < stop) & (ends > start) onsets = np.maximum(onsets[keep], start) ends = np.minimum(ends[keep], stop) @@ -871,19 +971,27 @@ def get_data(self, picks=None, start=0, stop=None, for onset, end in zip(onsets, ends): if onset >= end: continue - used[onset - start: end - start] = False + used[onset - start : end - start] = False used = np.concatenate([[False], used, [False]]) starts = np.where(~used[:-1] & used[1:])[0] + start stops = np.where(used[:-1] & ~used[1:])[0] + start n_kept = (stops - starts).sum() # kept samples n_rejected = n_samples - n_kept # rejected samples if n_rejected > 0: - if reject_by_annotation == 'omit': - 
msg = ("Omitting {} of {} ({:.2%}) samples, retaining {}" - " ({:.2%}) samples.") - logger.info(msg.format(n_rejected, n_samples, - n_rejected / n_samples, - n_kept, n_kept / n_samples)) + if reject_by_annotation == "omit": + msg = ( + "Omitting {} of {} ({:.2%}) samples, retaining {}" + " ({:.2%}) samples." + ) + logger.info( + msg.format( + n_rejected, + n_samples, + n_rejected / n_samples, + n_kept, + n_kept / n_samples, + ) + ) data = np.zeros((len(picks), n_kept)) times = np.zeros(data.shape[1]) idx = 0 @@ -894,11 +1002,19 @@ def get_data(self, picks=None, start=0, stop=None, data[:, idx:end], times[idx:end] = self[picks, start:stop] idx = end else: - msg = ("Setting {} of {} ({:.2%}) samples to NaN, retaining {}" - " ({:.2%}) samples.") - logger.info(msg.format(n_rejected, n_samples, - n_rejected / n_samples, - n_kept, n_kept / n_samples)) + msg = ( + "Setting {} of {} ({:.2%}) samples to NaN, retaining {}" + " ({:.2%}) samples." + ) + logger.info( + msg.format( + n_rejected, + n_samples, + n_rejected / n_samples, + n_kept, + n_kept / n_samples, + ) + ) data, times = self[picks, start:stop] data[:, ~used[1:-1]] = np.nan else: @@ -911,8 +1027,16 @@ def get_data(self, picks=None, start=0, stop=None, return data @verbose - def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, - channel_wise=True, verbose=None, **kwargs): + def apply_function( + self, + fun, + picks=None, + dtype=None, + n_jobs=None, + channel_wise=True, + verbose=None, + **kwargs, + ): """Apply a function to a subset of channels. %(applyfun_summary_raw)s @@ -935,11 +1059,11 @@ def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, self : instance of Raw The raw object with transformed data. """ - _check_preload(self, 'raw.apply_function') + _check_preload(self, "raw.apply_function") picks = _picks_to_idx(self.info, picks, exclude=(), with_ref_meg=False) if not callable(fun): - raise ValueError('fun needs to be a function') + raise ValueError("fun needs to be a function") data_in = self._data if dtype is not None and dtype != self._data.dtype: @@ -950,43 +1074,77 @@ def apply_function(self, fun, picks=None, dtype=None, n_jobs=None, if n_jobs == 1: # modify data inplace to save memory for idx in picks: - self._data[idx, :] = _check_fun(fun, data_in[idx, :], - **kwargs) + self._data[idx, :] = _check_fun(fun, data_in[idx, :], **kwargs) else: # use parallel function data_picks_new = parallel( - p_fun(fun, data_in[p], **kwargs) for p in picks) + p_fun(fun, data_in[p], **kwargs) for p in picks + ) for pp, p in enumerate(picks): self._data[p, :] = data_picks_new[pp] else: - self._data[picks, :] = _check_fun( - fun, data_in[picks, :], **kwargs) + self._data[picks, :] = _check_fun(fun, data_in[picks, :], **kwargs) return self # Need a separate method because the default pad is different for raw @copy_doc(FilterMixin.filter) - def filter(self, l_freq, h_freq, picks=None, filter_length='auto', - l_trans_bandwidth='auto', h_trans_bandwidth='auto', n_jobs=None, - method='fir', iir_params=None, phase='zero', - fir_window='hamming', fir_design='firwin', - skip_by_annotation=('edge', 'bad_acq_skip'), - pad='reflect_limited', verbose=None): # noqa: D102 + def filter( + self, + l_freq, + h_freq, + picks=None, + filter_length="auto", + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + n_jobs=None, + method="fir", + iir_params=None, + phase="zero", + fir_window="hamming", + fir_design="firwin", + skip_by_annotation=("edge", "bad_acq_skip"), + pad="reflect_limited", + verbose=None, + ): # noqa: D102 return 
super().filter( - l_freq, h_freq, picks, filter_length, l_trans_bandwidth, - h_trans_bandwidth, n_jobs=n_jobs, method=method, - iir_params=iir_params, phase=phase, fir_window=fir_window, - fir_design=fir_design, skip_by_annotation=skip_by_annotation, - pad=pad, verbose=verbose) + l_freq, + h_freq, + picks, + filter_length, + l_trans_bandwidth, + h_trans_bandwidth, + n_jobs=n_jobs, + method=method, + iir_params=iir_params, + phase=phase, + fir_window=fir_window, + fir_design=fir_design, + skip_by_annotation=skip_by_annotation, + pad=pad, + verbose=verbose, + ) @verbose - def notch_filter(self, freqs, picks=None, filter_length='auto', - notch_widths=None, trans_bandwidth=1.0, n_jobs=None, - method='fir', iir_params=None, mt_bandwidth=None, - p_value=0.05, phase='zero', fir_window='hamming', - fir_design='firwin', pad='reflect_limited', - skip_by_annotation=('edge', 'bad_acq_skip'), - verbose=None): + def notch_filter( + self, + freqs, + picks=None, + filter_length="auto", + notch_widths=None, + trans_bandwidth=1.0, + n_jobs=None, + method="fir", + iir_params=None, + mt_bandwidth=None, + p_value=0.05, + phase="zero", + fir_window="hamming", + fir_design="firwin", + pad="reflect_limited", + skip_by_annotation=("edge", "bad_acq_skip"), + verbose=None, + ): """Notch filter a subset of channels. Parameters @@ -1050,28 +1208,47 @@ def notch_filter(self, freqs, picks=None, filter_length='auto', For details, see :func:`mne.filter.notch_filter`. """ - fs = float(self.info['sfreq']) - picks = _picks_to_idx(self.info, picks, exclude=(), none='data_or_ica') - _check_preload(self, 'raw.notch_filter') - onsets, ends = _annotations_starts_stops( - self, skip_by_annotation, invert=True) - logger.info('Filtering raw data in %d contiguous segment%s' - % (len(onsets), _pl(onsets))) + fs = float(self.info["sfreq"]) + picks = _picks_to_idx(self.info, picks, exclude=(), none="data_or_ica") + _check_preload(self, "raw.notch_filter") + onsets, ends = _annotations_starts_stops(self, skip_by_annotation, invert=True) + logger.info( + "Filtering raw data in %d contiguous segment%s" % (len(onsets), _pl(onsets)) + ) for si, (start, stop) in enumerate(zip(onsets, ends)): notch_filter( - self._data[:, start:stop], fs, freqs, - filter_length=filter_length, notch_widths=notch_widths, - trans_bandwidth=trans_bandwidth, method=method, - iir_params=iir_params, mt_bandwidth=mt_bandwidth, - p_value=p_value, picks=picks, n_jobs=n_jobs, copy=False, - phase=phase, fir_window=fir_window, fir_design=fir_design, - pad=pad) + self._data[:, start:stop], + fs, + freqs, + filter_length=filter_length, + notch_widths=notch_widths, + trans_bandwidth=trans_bandwidth, + method=method, + iir_params=iir_params, + mt_bandwidth=mt_bandwidth, + p_value=p_value, + picks=picks, + n_jobs=n_jobs, + copy=False, + phase=phase, + fir_window=fir_window, + fir_design=fir_design, + pad=pad, + ) return self @verbose - def resample(self, sfreq, npad='auto', window='boxcar', stim_picks=None, - n_jobs=None, events=None, pad='reflect_limited', - verbose=None): + def resample( + self, + sfreq, + npad="auto", + window="boxcar", + stim_picks=None, + n_jobs=None, + events=None, + pad="reflect_limited", + verbose=None, + ): """Resample all channels. If appropriate, an anti-aliasing filter is applied before resampling. 
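The filter and notch_filter hunks above only reflow the signatures into black's one-argument-per-line style; the keyword names and defaults are untouched, so existing calls keep working. A short usage sketch under that assumption (the synthetic data and channel names are made up for illustration):

    import numpy as np
    import mne

    info = mne.create_info(["ch1", "ch2"], sfreq=200.0, ch_types="eeg")
    raw = mne.io.RawArray(np.random.randn(2, 5000), info)  # 25 s of noise
    raw.filter(l_freq=1.0, h_freq=40.0)  # band-pass 1-40 Hz, FIR by default
    raw.notch_filter(freqs=50.0)  # suppress 50 Hz line noise
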
@@ -1151,31 +1328,37 @@ def resample(self, sfreq, npad='auto', window='boxcar', stim_picks=None, pass sfreq = float(sfreq) - o_sfreq = float(self.info['sfreq']) + o_sfreq = float(self.info["sfreq"]) offsets = np.concatenate(([0], np.cumsum(self._raw_lengths))) # set up stim channel processing if stim_picks is None: - stim_picks = pick_types(self.info, meg=False, ref_meg=False, - stim=True, exclude=[]) + stim_picks = pick_types( + self.info, meg=False, ref_meg=False, stim=True, exclude=[] + ) else: - stim_picks = _picks_to_idx(self.info, stim_picks, exclude=(), - with_ref_meg=False) + stim_picks = _picks_to_idx( + self.info, stim_picks, exclude=(), with_ref_meg=False + ) - kwargs = dict(up=sfreq, down=o_sfreq, npad=npad, window=window, - n_jobs=n_jobs, pad=pad) - ratio, n_news = zip(*(_resamp_ratio_len(sfreq, o_sfreq, old_len) - for old_len in self._raw_lengths)) + kwargs = dict( + up=sfreq, down=o_sfreq, npad=npad, window=window, n_jobs=n_jobs, pad=pad + ) + ratio, n_news = zip( + *( + _resamp_ratio_len(sfreq, o_sfreq, old_len) + for old_len in self._raw_lengths + ) + ) ratio, n_news = ratio[0], np.array(n_news, int) new_offsets = np.cumsum([0] + list(n_news)) if self.preload: - new_data = np.empty( - (len(self.ch_names), new_offsets[-1]), self._data.dtype) + new_data = np.empty((len(self.ch_names), new_offsets[-1]), self._data.dtype) for ri, (n_orig, n_new) in enumerate(zip(self._raw_lengths, n_news)): this_sl = slice(new_offsets[ri], new_offsets[ri + 1]) if self.preload: - data_chunk = self._data[:, offsets[ri]:offsets[ri + 1]] + data_chunk = self._data[:, offsets[ri] : offsets[ri + 1]] new_data[:, this_sl] = resample(data_chunk, **kwargs) # In empirical testing, it was faster to resample all channels # (above) and then replace the stim channels than it was to @@ -1183,34 +1366,37 @@ def resample(self, sfreq, npad='auto', window='boxcar', stim_picks=None, # np.insert() to restore the stims. if len(stim_picks) > 0: new_data[stim_picks, this_sl] = _resample_stim_channels( - data_chunk[stim_picks], n_new, data_chunk.shape[1]) + data_chunk[stim_picks], n_new, data_chunk.shape[1] + ) else: # this will not be I/O efficient, but will be mem efficient for ci in range(len(self.ch_names)): data_chunk = self.get_data( - ci, offsets[ri], offsets[ri + 1], verbose='error')[0] + ci, offsets[ri], offsets[ri + 1], verbose="error" + )[0] if ci == 0 and ri == 0: new_data = np.empty( - (len(self.ch_names), new_offsets[-1]), - data_chunk.dtype) + (len(self.ch_names), new_offsets[-1]), data_chunk.dtype + ) if ci in stim_picks: resamp = _resample_stim_channels( - data_chunk, n_new, data_chunk.shape[-1])[0] + data_chunk, n_new, data_chunk.shape[-1] + )[0] else: resamp = resample(data_chunk, **kwargs) new_data[ci, this_sl] = resamp self._cropped_samp = int(np.round(self._cropped_samp * ratio)) self._first_samps = np.round(self._first_samps * ratio).astype(int) - self._last_samps = (np.array(self._first_samps) + n_news - 1) + self._last_samps = np.array(self._first_samps) + n_news - 1 self._raw_lengths[ri] = list(n_news) assert np.array_equal(n_news, self._last_samps - self._first_samps + 1) self._data = new_data self.preload = True - lowpass = self.info.get('lowpass') + lowpass = self.info.get("lowpass") lowpass = np.inf if lowpass is None else lowpass with self.info._unlock(): - self.info['lowpass'] = min(lowpass, sfreq / 2.) - self.info['sfreq'] = sfreq + self.info["lowpass"] = min(lowpass, sfreq / 2.0) + self.info["sfreq"] = sfreq # See the comment above why we ignore all errors here. 
if events is None: @@ -1218,10 +1404,12 @@ def resample(self, sfreq, npad='auto', window='boxcar', stim_picks=None, # Did we loose events? resampled_events = find_events(self) if len(resampled_events) != len(original_events): - warn('Resampling of the stim channels caused event ' - 'information to become unreliable. Consider finding ' - 'events on the original data and passing the event ' - 'matrix as a parameter.') + warn( + "Resampling of the stim channels caused event " + "information to become unreliable. Consider finding " + "events on the original data and passing the event " + "matrix as a parameter." + ) except Exception: pass @@ -1232,7 +1420,7 @@ def resample(self, sfreq, npad='auto', window='boxcar', stim_picks=None, events[:, 0] = np.minimum( np.round(events[:, 0] * ratio).astype(int), - self._data.shape[1] + self.first_samp - 1 + self._data.shape[1] + self.first_samp - 1, ) return self, events @@ -1261,27 +1449,34 @@ def crop(self, tmin=0.0, tmax=None, include_tmax=True, *, verbose=None): raw : instance of Raw The cropped raw object, modified in-place. """ - max_time = (self.n_times - 1) / self.info['sfreq'] + max_time = (self.n_times - 1) / self.info["sfreq"] if tmax is None: tmax = max_time if tmin > tmax: - raise ValueError('tmin (%s) must be less than tmax (%s)' - % (tmin, tmax)) + raise ValueError("tmin (%s) must be less than tmax (%s)" % (tmin, tmax)) if tmin < 0.0: - raise ValueError('tmin (%s) must be >= 0' % (tmin,)) - elif tmax - int(not include_tmax) / self.info['sfreq'] > max_time: - raise ValueError('tmax (%s) must be less than or equal to the max ' - 'time (%0.4f s)' % (tmax, max_time)) - - smin, smax = np.where(_time_mask( - self.times, tmin, tmax, sfreq=self.info['sfreq'], - include_tmax=include_tmax))[0][[0, -1]] - cumul_lens = np.concatenate(([0], np.array(self._raw_lengths, - dtype='int'))) + raise ValueError("tmin (%s) must be >= 0" % (tmin,)) + elif tmax - int(not include_tmax) / self.info["sfreq"] > max_time: + raise ValueError( + "tmax (%s) must be less than or equal to the max " + "time (%0.4f s)" % (tmax, max_time) + ) + + smin, smax = np.where( + _time_mask( + self.times, + tmin, + tmax, + sfreq=self.info["sfreq"], + include_tmax=include_tmax, + ) + )[0][[0, -1]] + cumul_lens = np.concatenate(([0], np.array(self._raw_lengths, dtype="int"))) cumul_lens = np.cumsum(cumul_lens) - keepers = np.logical_and(np.less(smin, cumul_lens[1:]), - np.greater_equal(smax, cumul_lens[:-1])) + keepers = np.logical_and( + np.less(smin, cumul_lens[1:]), np.greater_equal(smax, cumul_lens[:-1]) + ) keepers = np.where(keepers)[0] # if we drop file(s) from the beginning, we need to keep track of # how many samples we dropped relative to that one @@ -1292,18 +1487,17 @@ def crop(self, tmin=0.0, tmax=None, include_tmax=True, *, verbose=None): self._last_samps = np.atleast_1d(self._last_samps[keepers]) self._last_samps[-1] -= cumul_lens[keepers[-1] + 1] - 1 - smax self._read_picks = [self._read_picks[ri] for ri in keepers] - assert all(len(r) == len(self._read_picks[0]) - for r in self._read_picks) + assert all(len(r) == len(self._read_picks[0]) for r in self._read_picks) self._raw_extras = [self._raw_extras[ri] for ri in keepers] self._filenames = [self._filenames[ri] for ri in keepers] if self.preload: # slice and copy to avoid the reference to large array - self._data = self._data[:, smin:smax + 1].copy() + self._data = self._data[:, smin : smax + 1].copy() annotations = self.annotations # now call setter to filter out annotations outside of interval if annotations.orig_time 
is None: - assert self.info['meas_date'] is None + assert self.info["meas_date"] is None # When self.info['meas_date'] is None (which is guaranteed if # self.annotations.orig_time is None), when we do the # self.set_annotations, it's assumed that the annotations onset @@ -1338,7 +1532,7 @@ def crop_by_annotations(self, annotations=None, *, verbose=None): onset = annot["onset"] - self.first_time # be careful about near-zero errors (crop is very picky about this, # e.g., -1e-8 is an error) - if -self.info['sfreq'] / 2 < onset < 0: + if -self.info["sfreq"] / 2 < onset < 0: onset = 0 raw_crop = self.copy().crop(onset, onset + annot["duration"]) raws.append(raw_crop) @@ -1346,10 +1540,21 @@ def crop_by_annotations(self, annotations=None, *, verbose=None): return raws @verbose - def save(self, fname, picks=None, tmin=0, tmax=None, buffer_size_sec=None, - drop_small_buffer=False, proj=False, fmt='single', - overwrite=False, split_size='2GB', split_naming='neuromag', - verbose=None): + def save( + self, + fname, + picks=None, + tmin=0, + tmax=None, + buffer_size_sec=None, + drop_small_buffer=False, + proj=False, + fmt="single", + overwrite=False, + split_size="2GB", + split_naming="neuromag", + verbose=None, + ): """Save raw data to file. Parameters @@ -1415,48 +1620,59 @@ def save(self, fname, picks=None, tmin=0, tmax=None, buffer_size_sec=None, Samples annotated ``BAD_ACQ_SKIP`` are not stored in order to optimize memory. Whatever values, they will be loaded as 0s when reading file. """ - endings = ('raw.fif', 'raw_sss.fif', 'raw_tsss.fif', - '_meg.fif', '_eeg.fif', '_ieeg.fif') - endings += tuple([f'{e}.gz' for e in endings]) - endings_err = ('.fif', '.fif.gz') + endings = ( + "raw.fif", + "raw_sss.fif", + "raw_tsss.fif", + "_meg.fif", + "_eeg.fif", + "_ieeg.fif", + ) + endings += tuple([f"{e}.gz" for e in endings]) + endings_err = (".fif", ".fif.gz") # convert to str, check for overwrite a few lines later fname = str(_check_fname(fname, overwrite=True, verbose="error")) - check_fname(fname, 'raw', endings, endings_err=endings_err) + check_fname(fname, "raw", endings, endings_err=endings_err) split_size = _get_split_size(split_size) if not self.preload and fname in self._filenames: - raise ValueError('You cannot save data to the same file.' - ' Please use a different filename.') + raise ValueError( + "You cannot save data to the same file." + " Please use a different filename." + ) if self.preload: if np.iscomplexobj(self._data): - warn('Saving raw file with complex data. Loading with ' - 'command-line MNE tools will not work.') - - type_dict = dict(short=FIFF.FIFFT_DAU_PACK16, - int=FIFF.FIFFT_INT, - single=FIFF.FIFFT_FLOAT, - double=FIFF.FIFFT_DOUBLE) - _check_option('fmt', fmt, type_dict.keys()) + warn( + "Saving raw file with complex data. Loading with " + "command-line MNE tools will not work." 
+ ) + + type_dict = dict( + short=FIFF.FIFFT_DAU_PACK16, + int=FIFF.FIFFT_INT, + single=FIFF.FIFFT_FLOAT, + double=FIFF.FIFFT_DOUBLE, + ) + _check_option("fmt", fmt, type_dict.keys()) reset_dict = dict(short=False, int=False, single=True, double=True) reset_range = reset_dict[fmt] data_type = type_dict[fmt] data_test = self[0, 0][0] - if fmt == 'short' and np.iscomplexobj(data_test): - raise ValueError('Complex data must be saved as "single" or ' - '"double", not "short"') + if fmt == "short" and np.iscomplexobj(data_test): + raise ValueError( + 'Complex data must be saved as "single" or ' '"double", not "short"' + ) # check for file existence and expand `~` if present - fname = str( - _check_fname(fname=fname, overwrite=overwrite, verbose="error") - ) + fname = str(_check_fname(fname=fname, overwrite=overwrite, verbose="error")) if proj: info = deepcopy(self.info) projector, info = setup_proj(info) - activate_proj(info['projs'], copy=False) + activate_proj(info["projs"], copy=False) else: info = self.info projector = None @@ -1470,15 +1686,39 @@ def save(self, fname, picks=None, tmin=0, tmax=None, buffer_size_sec=None, buffer_size = self._get_buffer_size(buffer_size_sec) # write the raw file - _validate_type(split_naming, str, 'split_naming') - _check_option('split_naming', split_naming, ('neuromag', 'bids')) - _write_raw(fname, self, info, picks, fmt, data_type, reset_range, - start, stop, buffer_size, projector, drop_small_buffer, - split_size, split_naming, 0, None, overwrite) + _validate_type(split_naming, str, "split_naming") + _check_option("split_naming", split_naming, ("neuromag", "bids")) + _write_raw( + fname, + self, + info, + picks, + fmt, + data_type, + reset_range, + start, + stop, + buffer_size, + projector, + drop_small_buffer, + split_size, + split_naming, + 0, + None, + overwrite, + ) @verbose - def export(self, fname, fmt='auto', physical_range='auto', - add_ch_type=False, *, overwrite=False, verbose=None): + def export( + self, + fname, + fmt="auto", + physical_range="auto", + add_ch_type=False, + *, + overwrite=False, + verbose=None, + ): """Export Raw to external formats. 
%(export_fmt_support_raw)s @@ -1505,12 +1745,19 @@ def export(self, fname, fmt='auto', physical_range='auto', %(export_edf_note)s """ from ..export import export_raw - export_raw(fname, self, fmt, physical_range=physical_range, - add_ch_type=add_ch_type, overwrite=overwrite, - verbose=verbose) + + export_raw( + fname, + self, + fmt, + physical_range=physical_range, + add_ch_type=add_ch_type, + overwrite=overwrite, + verbose=verbose, + ) def _tmin_tmax_to_start_stop(self, tmin, tmax): - start = int(np.floor(tmin * self.info['sfreq'])) + start = int(np.floor(tmin * self.info["sfreq"])) # "stop" is the first sample *not* to save, so we need +1's here if tmax is None: @@ -1519,42 +1766,98 @@ def _tmin_tmax_to_start_stop(self, tmin, tmax): stop = self.time_as_index(float(tmax), use_rounding=True)[0] + 1 stop = min(stop, self.last_samp - self.first_samp + 1) if stop <= start or stop <= 0: - raise ValueError('tmin (%s) and tmax (%s) yielded no samples' - % (tmin, tmax)) + raise ValueError( + "tmin (%s) and tmax (%s) yielded no samples" % (tmin, tmax) + ) return start, stop @copy_function_doc_to_method_doc(plot_raw) - def plot(self, events=None, duration=10.0, start=0.0, n_channels=20, - bgcolor='w', color=None, bad_color='lightgray', - event_color='cyan', scalings=None, remove_dc=True, order=None, - show_options=False, title=None, show=True, block=False, - highpass=None, lowpass=None, filtorder=4, clipping=_RAW_CLIP_DEF, - show_first_samp=False, proj=True, group_by='type', - butterfly=False, decim='auto', noise_cov=None, event_id=None, - show_scrollbars=True, show_scalebars=True, time_format='float', - precompute=None, use_opengl=None, *, theme=None, - overview_mode=None, verbose=None): - return plot_raw(self, events, duration, start, n_channels, bgcolor, - color, bad_color, event_color, scalings, remove_dc, - order, show_options, title, show, block, highpass, - lowpass, filtorder, clipping, show_first_samp, - proj, group_by, butterfly, decim, noise_cov=noise_cov, - event_id=event_id, show_scrollbars=show_scrollbars, - show_scalebars=show_scalebars, time_format=time_format, - precompute=precompute, use_opengl=use_opengl, - theme=theme, overview_mode=overview_mode, - verbose=verbose) + def plot( + self, + events=None, + duration=10.0, + start=0.0, + n_channels=20, + bgcolor="w", + color=None, + bad_color="lightgray", + event_color="cyan", + scalings=None, + remove_dc=True, + order=None, + show_options=False, + title=None, + show=True, + block=False, + highpass=None, + lowpass=None, + filtorder=4, + clipping=_RAW_CLIP_DEF, + show_first_samp=False, + proj=True, + group_by="type", + butterfly=False, + decim="auto", + noise_cov=None, + event_id=None, + show_scrollbars=True, + show_scalebars=True, + time_format="float", + precompute=None, + use_opengl=None, + *, + theme=None, + overview_mode=None, + verbose=None, + ): + return plot_raw( + self, + events, + duration, + start, + n_channels, + bgcolor, + color, + bad_color, + event_color, + scalings, + remove_dc, + order, + show_options, + title, + show, + block, + highpass, + lowpass, + filtorder, + clipping, + show_first_samp, + proj, + group_by, + butterfly, + decim, + noise_cov=noise_cov, + event_id=event_id, + show_scrollbars=show_scrollbars, + show_scalebars=show_scalebars, + time_format=time_format, + precompute=precompute, + use_opengl=use_opengl, + theme=theme, + overview_mode=overview_mode, + verbose=verbose, + ) @property def ch_names(self): """Channel names.""" - return self.info['ch_names'] + return self.info["ch_names"] @property def 
times(self): """Time points.""" - out = _arange_div(self.n_times, float(self.info['sfreq'])) - out.flags['WRITEABLE'] = False + out = _arange_div(self.n_times, float(self.info["sfreq"])) + out.flags["WRITEABLE"] = False return out @property @@ -1599,29 +1902,31 @@ def load_bad_channels(self, bad_file=None, force=False, verbose=None): raising an error. Defaults to ``False``. %(verbose)s """ - prev_bads = self.info['bads'] + prev_bads = self.info["bads"] new_bads = [] if bad_file is not None: # Check to make sure bad channels are there - names = frozenset(self.info['ch_names']) + names = frozenset(self.info["ch_names"]) with open(bad_file) as fid: bad_names = [line for line in fid.read().splitlines() if line] new_bads = [ci for ci in bad_names if ci in names] count_diff = len(bad_names) - len(new_bads) if count_diff > 0: - msg = (f'{count_diff} bad channel(s) from:' - f'\n{bad_file}\nnot found in:\n{self.filenames[0]}') + msg = ( + f"{count_diff} bad channel(s) from:" + f"\n{bad_file}\nnot found in:\n{self.filenames[0]}" + ) if not force: raise ValueError(msg) else: warn(msg) if prev_bads != new_bads: - logger.info(f'Updating bad channels: {prev_bads} -> {new_bads}') - self.info['bads'] = new_bads + logger.info(f"Updating bad channels: {prev_bads} -> {new_bads}") + self.info["bads"] = new_bads else: - logger.info(f'No channels updated. Bads are: {prev_bads}') + logger.info(f"No channels updated. Bads are: {prev_bads}") @fill_doc def append(self, raws, preload=None): @@ -1661,7 +1966,7 @@ def append(self, raws, preload=None): self.preload = False else: # do the concatenation ourselves since preload might be a string - nchan = self.info['nchan'] + nchan = self.info["nchan"] c_ns = np.cumsum([rr.n_times for rr in ([self] + raws)]) nsamp = c_ns[-1] @@ -1672,51 +1977,58 @@ def append(self, raws, preload=None): # allocate the buffer _data = _allocate_data(preload, (nchan, nsamp), this_data.dtype) - _data[:, 0:c_ns[0]] = this_data + _data[:, 0 : c_ns[0]] = this_data for ri in range(len(raws)): if not raws[ri].preload: # read the data directly into the buffer - data_buffer = _data[:, c_ns[ri]:c_ns[ri + 1]] + data_buffer = _data[:, c_ns[ri] : c_ns[ri + 1]] raws[ri]._read_segment(data_buffer=data_buffer) else: - _data[:, c_ns[ri]:c_ns[ri + 1]] = raws[ri]._data + _data[:, c_ns[ri] : c_ns[ri + 1]] = raws[ri]._data self._data = _data self.preload = True # now combine information from each raw file to construct new self annotations = self.annotations - assert annotations.orig_time == self.info['meas_date'] + assert annotations.orig_time == self.info["meas_date"] edge_samps = list() for ri, r in enumerate(raws): n_samples = self.last_samp - self.first_samp + 1 annotations = _combine_annotations( - annotations, r.annotations, n_samples, - self.first_samp, r.first_samp, - self.info['sfreq']) - edge_samps.append(sum(self._last_samps) - - sum(self._first_samps) + (ri + 1)) + annotations, + r.annotations, + n_samples, + self.first_samp, + r.first_samp, + self.info["sfreq"], + ) + edge_samps.append(sum(self._last_samps) - sum(self._first_samps) + (ri + 1)) self._first_samps = np.r_[self._first_samps, r._first_samps] self._last_samps = np.r_[self._last_samps, r._last_samps] self._read_picks += r._read_picks self._raw_extras += r._raw_extras self._filenames += r._filenames - assert annotations.orig_time == self.info['meas_date'] + assert annotations.orig_time == self.info["meas_date"] # The above _combine_annotations gets everything synchronized to # first_samp. 
set_annotations (with no absolute time reference) assumes # that the annotations being set are relative to first_samp, and will # add it back on. So here we have to remove it: if annotations.orig_time is None: - annotations.onset -= self.first_samp / self.info['sfreq'] + annotations.onset -= self.first_samp / self.info["sfreq"] self.set_annotations(annotations) for edge_samp in edge_samps: - onset = _sync_onset(self, (edge_samp) / self.info['sfreq'], True) - self.annotations.append(onset, 0., 'BAD boundary') - self.annotations.append(onset, 0., 'EDGE boundary') - if not (len(self._first_samps) == len(self._last_samps) == - len(self._raw_extras) == len(self._filenames) == - len(self._read_picks)): - raise RuntimeError('Append error') # should never happen + onset = _sync_onset(self, (edge_samp) / self.info["sfreq"], True) + self.annotations.append(onset, 0.0, "BAD boundary") + self.annotations.append(onset, 0.0, "EDGE boundary") + if not ( + len(self._first_samps) + == len(self._last_samps) + == len(self._raw_extras) + == len(self._filenames) + == len(self._read_picks) + ): + raise RuntimeError("Append error") # should never happen def close(self): """Clean up the object. @@ -1738,20 +2050,23 @@ def copy(self): def __repr__(self): # noqa: D105 name = self.filenames[0] - name = '' if name is None else op.basename(name) + ', ' + name = "" if name is None else op.basename(name) + ", " size_str = str(sizeof_fmt(self._size)) # str in case it fails -> None - size_str += ', data%s loaded' % ('' if self.preload else ' not') - s = ('%s%s x %s (%0.1f s), ~%s' - % (name, len(self.ch_names), self.n_times, self.times[-1], - size_str)) + size_str += ", data%s loaded" % ("" if self.preload else " not") + s = "%s%s x %s (%0.1f s), ~%s" % ( + name, + len(self.ch_names), + self.n_times, + self.times[-1], + size_str, + ) return "<%s | %s>" % (self.__class__.__name__, s) @repr_html def _repr_html_(self, caption=None): from ..html_templates import repr_templates_env - basenames = [ - os.path.basename(f) for f in self._filenames if f is not None - ] + + basenames = [os.path.basename(f) for f in self._filenames if f is not None] # https://stackoverflow.com/a/10981895 duration = timedelta(seconds=self.times[-1]) @@ -1760,11 +2075,12 @@ def _repr_html_(self, caption=None): seconds += duration.microseconds / 1e6 seconds = np.ceil(seconds) # always take full seconds - duration = f'{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}' - raw_template = repr_templates_env.get_template('raw.html.jinja') + duration = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}" + raw_template = repr_templates_env.get_template("raw.html.jinja") return raw_template.render( info_repr=self.info._repr_html_(caption=caption), - filenames=basenames, duration=duration + filenames=basenames, + duration=duration, ) def add_events(self, events, stim_channel=None, replace=False): @@ -1790,23 +2106,25 @@ def add_events(self, events, stim_channel=None, replace=False): ----- Data must be preloaded in order to add events. 
""" - _check_preload(self, 'Adding events') + _check_preload(self, "Adding events") events = np.asarray(events) if events.ndim != 2 or events.shape[1] != 3: - raise ValueError('events must be shape (n_events, 3)') + raise ValueError("events must be shape (n_events, 3)") stim_channel = _get_stim_channel(stim_channel, self.info) pick = pick_channels(self.ch_names, stim_channel, ordered=False) if len(pick) == 0: - raise ValueError('Channel %s not found' % stim_channel) + raise ValueError("Channel %s not found" % stim_channel) pick = pick[0] idx = events[:, 0].astype(int) if np.any(idx < self.first_samp) or np.any(idx > self.last_samp): - raise ValueError('event sample numbers must be between %s and %s' - % (self.first_samp, self.last_samp)) + raise ValueError( + "event sample numbers must be between %s and %s" + % (self.first_samp, self.last_samp) + ) if not all(idx == events[:, 0]): - raise ValueError('event sample numbers must be integers') + raise ValueError("event sample numbers must be integers") if replace: - self._data[pick, :] = 0. + self._data[pick, :] = 0.0 self._data[pick, idx - self.first_samp] += events[:, 2] def _get_buffer_size(self, buffer_size_sec=None): @@ -1814,13 +2132,24 @@ def _get_buffer_size(self, buffer_size_sec=None): if buffer_size_sec is None: buffer_size_sec = self.buffer_size_sec buffer_size_sec = float(buffer_size_sec) - return int(np.ceil(buffer_size_sec * self.info['sfreq'])) + return int(np.ceil(buffer_size_sec * self.info["sfreq"])) @verbose - def compute_psd(self, method='welch', fmin=0, fmax=np.inf, tmin=None, - tmax=None, picks=None, proj=False, - reject_by_annotation=True, *, n_jobs=1, verbose=None, - **method_kw): + def compute_psd( + self, + method="welch", + fmin=0, + fmax=np.inf, + tmin=None, + tmax=None, + picks=None, + proj=False, + reject_by_annotation=True, + *, + n_jobs=1, + verbose=None, + **method_kw, + ): """Perform spectral analysis on sensor data. Parameters @@ -1853,15 +2182,34 @@ def compute_psd(self, method='welch', fmin=0, fmax=np.inf, tmin=None, self._set_legacy_nfft_default(tmin, tmax, method, method_kw) return Spectrum( - self, method=method, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, - picks=picks, proj=proj, reject_by_annotation=reject_by_annotation, - n_jobs=n_jobs, verbose=verbose, **method_kw) + self, + method=method, + fmin=fmin, + fmax=fmax, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + reject_by_annotation=reject_by_annotation, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) @verbose - def to_data_frame(self, picks=None, index=None, - scalings=None, copy=True, start=None, stop=None, - long_format=False, time_format=None, *, - verbose=None): + def to_data_frame( + self, + picks=None, + index=None, + scalings=None, + copy=True, + start=None, + stop=None, + long_format=False, + time_format=None, + *, + verbose=None, + ): """Export data in tabular structure as a pandas DataFrame. Channels are converted to columns in the DataFrame. 
By default, an @@ -1895,13 +2243,14 @@ def to_data_frame(self, picks=None, index=None, # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa # arg checking - valid_index_args = ['time'] - valid_time_formats = ['ms', 'timedelta', 'datetime'] + valid_index_args = ["time"] + valid_time_formats = ["ms", "timedelta", "datetime"] index = _check_pandas_index_arguments(index, valid_index_args) - time_format = _check_time_format(time_format, valid_time_formats, - self.info['meas_date']) + time_format = _check_time_format( + time_format, valid_time_formats, self.info["meas_date"] + ) # get data - picks = _picks_to_idx(self.info, picks, 'all', exclude=()) + picks = _picks_to_idx(self.info, picks, "all", exclude=()) data, times = self[picks, start:stop] data = data.T if copy: @@ -1910,10 +2259,11 @@ def to_data_frame(self, picks=None, index=None, # prepare extra columns / multiindex mindex = list() times = _convert_times(self, times, time_format) - mindex.append(('time', times)) + mindex.append(("time", times)) # build DataFrame - df = _build_data_frame(self, data, picks, long_format, mindex, index, - default_index=['time']) + df = _build_data_frame( + self, data, picks, long_format, mindex, index, default_index=["time"] + ) return df def describe(self, data_frame=False): @@ -1936,6 +2286,7 @@ def describe(self, data_frame=False): results in a pandas.DataFrame (requires pandas). """ from scipy.stats import scoreatpercentile as q + nchan = self.info["nchan"] # describe each channel @@ -1954,6 +2305,7 @@ def describe(self, data_frame=False): if data_frame: # return data frame import pandas as pd + df = pd.DataFrame(cols) df.index.name = "ch" return df @@ -1962,35 +2314,41 @@ def describe(self, data_frame=False): scalings = _handle_default("scalings") units = _handle_default("units") for i in range(nchan): - unit = units.get(cols['type'][i]) - scaling = scalings.get(cols['type'][i], 1) + unit = units.get(cols["type"][i]) + scaling = scalings.get(cols["type"][i], 1) if scaling != 1: - cols['unit'][i] = unit + cols["unit"][i] = unit for col in ["min", "Q1", "median", "Q3", "max"]: cols[col][i] *= scaling - lens = {"ch": max(2, len(str(nchan))), - "name": max(4, max([len(n) for n in cols["name"]])), - "type": max(4, max([len(t) for t in cols["type"]])), - "unit": max(4, max([len(u) for u in cols["unit"]]))} + lens = { + "ch": max(2, len(str(nchan))), + "name": max(4, max([len(n) for n in cols["name"]])), + "type": max(4, max([len(t) for t in cols["type"]])), + "unit": max(4, max([len(u) for u in cols["unit"]])), + } # print description, start with header print(self) - print(f"{'ch':>{lens['ch']}} " - f"{'name':<{lens['name']}} " - f"{'type':<{lens['type']}} " - f"{'unit':<{lens['unit']}} " - f"{'min':>9} " - f"{'Q1':>9} " - f"{'median':>9} " - f"{'Q3':>9} " - f"{'max':>9}") + print( + f"{'ch':>{lens['ch']}} " + f"{'name':<{lens['name']}} " + f"{'type':<{lens['type']}} " + f"{'unit':<{lens['unit']}} " + f"{'min':>9} " + f"{'Q1':>9} " + f"{'median':>9} " + f"{'Q3':>9} " + f"{'max':>9}" + ) # print description for each channel for i in range(nchan): - msg = (f"{i:>{lens['ch']}} " - f"{cols['name'][i]:<{lens['name']}} " - f"{cols['type'][i].upper():<{lens['type']}} " - f"{cols['unit'][i]:<{lens['unit']}} ") + msg = ( + f"{i:>{lens['ch']}} " + f"{cols['name'][i]:<{lens['name']}} " + f"{cols['type'][i].upper():<{lens['type']}} " + f"{cols['unit'][i]:<{lens['unit']}} " + ) for col in ["min", "Q1", "median", "Q3"]: msg += f"{cols[col][i]:>9.2f} " msg += 
f"{cols['max'][i]:>9.2f}" @@ -2002,8 +2360,8 @@ def _allocate_data(preload, shape, dtype): if preload in (None, True): # None comes from _read_segment data = np.zeros(shape, dtype) else: - _validate_type(preload, 'path-like', 'preload') - data = np.memmap(str(preload), mode='w+', dtype=dtype, shape=shape) + _validate_type(preload, "path-like", "preload") + data = np.memmap(str(preload), mode="w+", dtype=dtype, shape=shape) return data @@ -2054,16 +2412,18 @@ def _get_ch_factors(inst, units, picks_idxs): """ _validate_type(units, types=(None, str, dict), item_name="units") ch_factors = np.ones(len(picks_idxs)) - si_units = _handle_default('si_units') + si_units = _handle_default("si_units") ch_types = inst.get_channel_types(picks=picks_idxs) # Convert to dict if str units if isinstance(units, str): # Check that there is only one channel type unit_ch_type = list(set(ch_types) & set(si_units.keys())) if len(unit_ch_type) > 1: - raise ValueError('"units" cannot be str if there is more than ' - 'one channel type with a unit ' - f'{unit_ch_type}.') + raise ValueError( + '"units" cannot be str if there is more than ' + "one channel type with a unit " + f"{unit_ch_type}." + ) units = {unit_ch_type[0]: units} # make the str argument a dict # Loop over the dict to get channel factors if isinstance(units, dict): @@ -2071,8 +2431,7 @@ def _get_ch_factors(inst, units, picks_idxs): # Get the scaling factors scaling = _get_scaling(ch_type, ch_unit) if scaling != 1: - indices = [i_ch for i_ch, ch in enumerate(ch_types) - if ch == ch_type] + indices = [i_ch for i_ch, ch in enumerate(ch_types) if ch == ch_type] ch_factors[indices] *= scaling return ch_factors @@ -2094,43 +2453,44 @@ def _get_scaling(ch_type, target_unit): The scaling factor to convert from the si_unit (used by default for MNE objects) to the target unit. """ - scaling = 1. - si_units = _handle_default('si_units') - si_units_splitted = {key: si_units[key].split('/') for key in si_units} - prefixes = _handle_default('prefixes') + scaling = 1.0 + si_units = _handle_default("si_units") + si_units_splitted = {key: si_units[key].split("/") for key in si_units} + prefixes = _handle_default("prefixes") prefix_list = list(prefixes.keys()) # Check that the provided unit exists for the ch_type - unit_list = target_unit.split('/') + unit_list = target_unit.split("/") if ch_type not in si_units.keys(): raise KeyError( - f'{ch_type} is not a channel type that can be scaled ' - 'from units.') + f"{ch_type} is not a channel type that can be scaled " "from units." + ) si_unit_list = si_units_splitted[ch_type] if len(unit_list) != len(si_unit_list): raise ValueError( - f'{target_unit} is not a valid unit for {ch_type}, use a ' - f'sub-multiple of {si_units[ch_type]} instead.') + f"{target_unit} is not a valid unit for {ch_type}, use a " + f"sub-multiple of {si_units[ch_type]} instead." + ) for i, unit in enumerate(unit_list): - valid = [prefix + si_unit_list[i] - for prefix in prefix_list] + valid = [prefix + si_unit_list[i] for prefix in prefix_list] if unit not in valid: raise ValueError( - f'{target_unit} is not a valid unit for {ch_type}, use a ' - f'sub-multiple of {si_units[ch_type]} instead.') + f"{target_unit} is not a valid unit for {ch_type}, use a " + f"sub-multiple of {si_units[ch_type]} instead." + ) # Get the scaling factors for i, unit in enumerate(unit_list): has_square = False # XXX power normally not used as csd cannot get_data() - if unit[-1] == '²': + if unit[-1] == "²": has_square = True - if unit == 'm' or unit == 'm²': - factor = 1. 
+ if unit == "m" or unit == "m²": + factor = 1.0 elif unit[0] in prefixes.keys(): factor = prefixes[unit[0]] else: - factor = 1. + factor = 1.0 if factor != 1: if has_square: factor *= factor @@ -2146,13 +2506,14 @@ class _ReadSegmentFileProtector: def __init__(self, raw): self.__raw = raw - assert hasattr(raw, '_projector') + assert hasattr(raw, "_projector") self._filenames = raw._filenames self._raw_extras = raw._raw_extras def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): return self.__raw.__class__._read_segment_file( - self, data, idx, fi, start, stop, cals, mult) + self, data, idx, fi, start, stop, cals, mult + ) class _RawShell: @@ -2183,72 +2544,105 @@ def set_annotations(self, annotations): ############################################################################### # Writing -def _write_raw(fname, raw, info, picks, fmt, data_type, reset_range, start, - stop, buffer_size, projector, drop_small_buffer, - split_size, split_naming, part_idx, prev_fname, overwrite): +def _write_raw( + fname, + raw, + info, + picks, + fmt, + data_type, + reset_range, + start, + stop, + buffer_size, + projector, + drop_small_buffer, + split_size, + split_naming, + part_idx, + prev_fname, + overwrite, +): """Write raw file with splitting.""" # we've done something wrong if we hit this n_times_max = len(raw.times) if start >= stop or stop > n_times_max: - raise RuntimeError('Cannot write raw file with no data: %s -> %s ' - '(max: %s) requested' % (start, stop, n_times_max)) + raise RuntimeError( + "Cannot write raw file with no data: %s -> %s " + "(max: %s) requested" % (start, stop, n_times_max) + ) # Expand `~` if present fname = str(_check_fname(fname=fname, overwrite=overwrite)) base, ext = op.splitext(fname) if part_idx > 0: - if split_naming == 'neuromag': + if split_naming == "neuromag": # insert index in filename - use_fname = '%s-%d%s' % (base, part_idx, ext) + use_fname = "%s-%d%s" % (base, part_idx, ext) else: - assert split_naming == 'bids' + assert split_naming == "bids" use_fname = _construct_bids_filename(base, ext, part_idx + 1) # check for file existence _check_fname(use_fname, overwrite) else: use_fname = fname # reserve our BIDS split fname in case we need to split - if split_naming == 'bids' and part_idx == 0: + if split_naming == "bids" and part_idx == 0: # reserve our possible split name reserved_fname = _construct_bids_filename(base, ext, part_idx + 1) - logger.info( - f'Reserving possible split file {op.basename(reserved_fname)}') + logger.info(f"Reserving possible split file {op.basename(reserved_fname)}") _check_fname(reserved_fname, overwrite) ctx = _ReservedFilename(reserved_fname) else: reserved_fname = use_fname ctx = nullcontext() - logger.info('Writing %s' % use_fname) + logger.info("Writing %s" % use_fname) - picks = _picks_to_idx(info, picks, 'all', ()) + picks = _picks_to_idx(info, picks, "all", ()) with start_and_end_file(use_fname) as fid: - cals = _start_writing_raw(fid, info, picks, data_type, - reset_range, raw.annotations) + cals = _start_writing_raw( + fid, info, picks, data_type, reset_range, raw.annotations + ) with ctx: final_fname = _write_raw_fid( - raw, info, picks, fid, cals, part_idx, start, stop, - buffer_size, prev_fname, split_size, use_fname, - projector, drop_small_buffer, fmt, fname, reserved_fname, - data_type, reset_range, split_naming, - overwrite=True # we've started writing already above + raw, + info, + picks, + fid, + cals, + part_idx, + start, + stop, + buffer_size, + prev_fname, + split_size, + use_fname, + projector, 
+ drop_small_buffer, + fmt, + fname, + reserved_fname, + data_type, + reset_range, + split_naming, + overwrite=True, # we've started writing already above ) if final_fname != use_fname: - assert split_naming == 'bids' - logger.info(f'Renaming BIDS split file {op.basename(final_fname)}') + assert split_naming == "bids" + logger.info(f"Renaming BIDS split file {op.basename(final_fname)}") ctx.remove = False shutil.move(use_fname, final_fname) if part_idx == 0: - logger.info('[done]') + logger.info("[done]") return final_fname, part_idx class _ReservedFilename: - def __init__(self, fname): self.fname = fname assert op.isdir(op.dirname(fname)), fname - with open(fname, 'w'): + with open(fname, "w"): pass self.remove = True @@ -2260,10 +2654,29 @@ def __exit__(self, exc_type, exc_value, traceback): os.remove(self.fname) -def _write_raw_fid(raw, info, picks, fid, cals, part_idx, start, stop, - buffer_size, prev_fname, split_size, use_fname, - projector, drop_small_buffer, fmt, fname, reserved_fname, - data_type, reset_range, split_naming, overwrite): +def _write_raw_fid( + raw, + info, + picks, + fid, + cals, + part_idx, + start, + stop, + buffer_size, + prev_fname, + split_size, + use_fname, + projector, + drop_small_buffer, + fmt, + fname, + reserved_fname, + data_type, + reset_range, + split_naming, + overwrite, +): first_samp = raw.first_samp + start if first_samp != 0: write_int(fid, FIFF.FIFF_FIRST_SAMPLE, first_samp) @@ -2273,17 +2686,19 @@ def _write_raw_fid(raw, info, picks, fid, cals, part_idx, start, stop, start_block(fid, FIFF.FIFFB_REF) write_int(fid, FIFF.FIFF_REF_ROLE, FIFF.FIFFV_ROLE_PREV_FILE) write_string(fid, FIFF.FIFF_REF_FILE_NAME, prev_fname) - if info['meas_id'] is not None: - write_id(fid, FIFF.FIFF_REF_FILE_ID, info['meas_id']) + if info["meas_id"] is not None: + write_id(fid, FIFF.FIFF_REF_FILE_ID, info["meas_id"]) write_int(fid, FIFF.FIFF_REF_FILE_NUM, part_idx - 1) end_block(fid, FIFF.FIFFB_REF) pos_prev = fid.tell() if pos_prev > split_size: - raise ValueError('file is larger than "split_size" after writing ' - 'measurement information, you must use a larger ' - 'value for split size: %s plus enough bytes for ' - 'the chosen buffer_size' % pos_prev) + raise ValueError( + 'file is larger than "split_size" after writing ' + "measurement information, you must use a larger " + "value for split size: %s plus enough bytes for " + "the chosen buffer_size" % pos_prev + ) # Check to see if this has acquisition skips and, if so, if we can # write out empty buffers instead of zeroes @@ -2291,15 +2706,17 @@ def _write_raw_fid(raw, info, picks, fid, cals, part_idx, start, stop, lasts = np.array(firsts) + buffer_size if lasts[-1] > stop: lasts[-1] = stop - sk_onsets, sk_ends = _annotations_starts_stops(raw, 'bad_acq_skip') + sk_onsets, sk_ends = _annotations_starts_stops(raw, "bad_acq_skip") do_skips = False if len(sk_onsets) > 0: if np.in1d(sk_onsets, firsts).all() and np.in1d(sk_ends, lasts).all(): do_skips = True else: if part_idx == 0: - warn('Acquisition skips detected but did not fit evenly into ' - 'output buffer_size, will be written as zeroes.') + warn( + "Acquisition skips detected but did not fit evenly into " + "output buffer_size, will be written as zeroes." 
+ ) n_current_skip = 0 final_fname = use_fname @@ -2323,12 +2740,10 @@ def _write_raw_fid(raw, info, picks, fid, cals, part_idx, start, stop, if projector is not None: data = np.dot(projector, data) - if ((drop_small_buffer and (first > start) and - (len(times) < buffer_size))): - logger.info('Skipping data chunk due to small buffer ... ' - '[done]') + if drop_small_buffer and (first > start) and (len(times) < buffer_size): + logger.info("Skipping data chunk due to small buffer ... " "[done]") break - logger.debug(f'Writing FIF {first:6d} ... {last:6d} ...') + logger.debug(f"Writing FIF {first:6d} ... {last:6d} ...") _write_raw_buffer(fid, data, cals, fmt) pos = fid.tell() @@ -2338,36 +2753,59 @@ def _write_raw_fid(raw, info, picks, fid, cals, part_idx, start, stop, # This should occur on the first buffer write of the file, so # we should mention the space required for the meas info raise ValueError( - 'buffer size (%s) is too large for the given split size (%s) ' - 'by %s bytes after writing info (%s) and leaving enough space ' + "buffer size (%s) is too large for the given split size (%s) " + "by %s bytes after writing info (%s) and leaving enough space " 'for end tags (%s): decrease "buffer_size_sec" or increase ' - '"split_size".' % (this_buff_size_bytes, split_size, overage, - pos_prev, _NEXT_FILE_BUFFER)) + '"split_size".' + % ( + this_buff_size_bytes, + split_size, + overage, + pos_prev, + _NEXT_FILE_BUFFER, + ) + ) # Split files if necessary, leave some space for next file info # make sure we check to make sure we actually *need* another buffer # with the "and" check - if pos >= split_size - this_buff_size_bytes - _NEXT_FILE_BUFFER and \ - first + buffer_size < stop: + if ( + pos >= split_size - this_buff_size_bytes - _NEXT_FILE_BUFFER + and first + buffer_size < stop + ): final_fname = reserved_fname next_fname, next_idx = _write_raw( - fname, raw, info, picks, fmt, - data_type, reset_range, first + buffer_size, stop, buffer_size, - projector, drop_small_buffer, split_size, split_naming, - part_idx + 1, final_fname, overwrite) + fname, + raw, + info, + picks, + fmt, + data_type, + reset_range, + first + buffer_size, + stop, + buffer_size, + projector, + drop_small_buffer, + split_size, + split_naming, + part_idx + 1, + final_fname, + overwrite, + ) start_block(fid, FIFF.FIFFB_REF) write_int(fid, FIFF.FIFF_REF_ROLE, FIFF.FIFFV_ROLE_NEXT_FILE) write_string(fid, FIFF.FIFF_REF_FILE_NAME, op.basename(next_fname)) - if info['meas_id'] is not None: - write_id(fid, FIFF.FIFF_REF_FILE_ID, info['meas_id']) + if info["meas_id"] is not None: + write_id(fid, FIFF.FIFF_REF_FILE_ID, info["meas_id"]) write_int(fid, FIFF.FIFF_REF_FILE_NUM, next_idx) end_block(fid, FIFF.FIFFB_REF) break pos_prev = pos - logger.info('Closing %s' % use_fname) - if info.get('maxshield', False): + logger.info("Closing %s" % use_fname) + if info.get("maxshield", False): end_block(fid, FIFF.FIFFB_IAS_RAW_DATA) else: end_block(fid, FIFF.FIFFB_RAW_DATA) @@ -2376,8 +2814,7 @@ def _write_raw_fid(raw, info, picks, fid, cals, part_idx, start, stop, @fill_doc -def _start_writing_raw(fid, info, sel, data_type, - reset_range, annotations): +def _start_writing_raw(fid, info, sel, data_type, reset_range, annotations): """Start write raw data in file. 
Parameters @@ -2413,18 +2850,18 @@ def _start_writing_raw(fid, info, sel, data_type, # start_block(fid, FIFF.FIFFB_MEAS) write_id(fid, FIFF.FIFF_BLOCK_ID) - if info['meas_id'] is not None: - write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info['meas_id']) + if info["meas_id"] is not None: + write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info["meas_id"]) cals = [] - for k in range(info['nchan']): + for k in range(info["nchan"]): # # Scan numbers may have been messed up # - info['chs'][k]['scanno'] = k + 1 # scanno starts at 1 in FIF format + info["chs"][k]["scanno"] = k + 1 # scanno starts at 1 in FIF format if reset_range is True: - info['chs'][k]['range'] = 1.0 - cals.append(info['chs'][k]['cal'] * info['chs'][k]['range']) + info["chs"][k]["range"] = 1.0 + cals.append(info["chs"][k]["cal"] * info["chs"][k]["range"]) write_meas_info(fid, info, data_type=data_type, reset_range=reset_range) @@ -2437,7 +2874,7 @@ def _start_writing_raw(fid, info, sel, data_type, # # Start the raw data # - if info.get('maxshield', False): + if info.get("maxshield", False): start_block(fid, FIFF.FIFFB_IAS_RAW_DATA) else: start_block(fid, FIFF.FIFFB_RAW_DATA) @@ -2462,30 +2899,31 @@ def _write_raw_buffer(fid, buf, cals, fmt): that short and int formats cannot be used for complex data. """ if buf.shape[0] != len(cals): - raise ValueError('buffer and calibration sizes do not match') + raise ValueError("buffer and calibration sizes do not match") - _check_option('fmt', fmt, ['short', 'int', 'single', 'double']) + _check_option("fmt", fmt, ["short", "int", "single", "double"]) cast_int = False # allow unsafe cast if np.isrealobj(buf): - if fmt == 'short': + if fmt == "short": write_function = write_dau_pack16 cast_int = True - elif fmt == 'int': + elif fmt == "int": write_function = write_int cast_int = True - elif fmt == 'single': + elif fmt == "single": write_function = write_float else: write_function = write_double else: - if fmt == 'single': + if fmt == "single": write_function = write_complex64 - elif fmt == 'double': + elif fmt == "double": write_function = write_complex128 else: - raise ValueError('only "single" and "double" supported for ' - 'writing complex data') + raise ValueError( + 'only "single" and "double" supported for ' "writing complex data" + ) buf = buf / np.ravel(cals)[:, None] if cast_int: @@ -2497,36 +2935,42 @@ def _check_raw_compatibility(raw): """Ensure all instances of Raw have compatible parameters.""" for ri in range(1, len(raw)): if not isinstance(raw[ri], type(raw[0])): - raise ValueError(f'raw[{ri}] type must match') - for key in ('nchan', 'sfreq'): + raise ValueError(f"raw[{ri}] type must match") + for key in ("nchan", "sfreq"): a, b = raw[ri].info[key], raw[0].info[key] if a != b: raise ValueError( - f'raw[{ri}].info[{key}] must match:\n' - f'{repr(a)} != {repr(b)}') - for kind in ('bads', 'ch_names'): + f"raw[{ri}].info[{key}] must match:\n" f"{repr(a)} != {repr(b)}" + ) + for kind in ("bads", "ch_names"): set1 = set(raw[0].info[kind]) set2 = set(raw[ri].info[kind]) mismatch = set1.symmetric_difference(set2) if mismatch: - raise ValueError(f'raw[{ri}][\'info\'][{kind}] do not match: ' - f'{sorted(mismatch)}') + raise ValueError( + f"raw[{ri}]['info'][{kind}] do not match: " f"{sorted(mismatch)}" + ) if any(raw[ri]._cals != raw[0]._cals): - raise ValueError('raw[%d]._cals must match' % ri) - if len(raw[0].info['projs']) != len(raw[ri].info['projs']): - raise ValueError('SSP projectors in raw files must be the same') - if not all(_proj_equal(p1, p2) for p1, p2 in - zip(raw[0].info['projs'], 
raw[ri].info['projs'])): - raise ValueError('SSP projectors in raw files must be the same') + raise ValueError("raw[%d]._cals must match" % ri) + if len(raw[0].info["projs"]) != len(raw[ri].info["projs"]): + raise ValueError("SSP projectors in raw files must be the same") + if not all( + _proj_equal(p1, p2) + for p1, p2 in zip(raw[0].info["projs"], raw[ri].info["projs"]) + ): + raise ValueError("SSP projectors in raw files must be the same") if any(r.orig_format != raw[0].orig_format for r in raw): - warn('raw files do not all have the same data format, could result in ' - 'precision mismatch. Setting raw.orig_format="unknown"') - raw[0].orig_format = 'unknown' + warn( + "raw files do not all have the same data format, could result in " + 'precision mismatch. Setting raw.orig_format="unknown"' + ) + raw[0].orig_format = "unknown" @verbose -def concatenate_raws(raws, preload=None, events_list=None, *, - on_mismatch='raise', verbose=None): +def concatenate_raws( + raws, preload=None, events_list=None, *, on_mismatch="raise", verbose=None +): """Concatenate `~mne.io.Raw` instances as if they were continuous. .. note:: ``raws[0]`` is modified in-place to achieve the concatenation. @@ -2553,13 +2997,18 @@ def concatenate_raws(raws, preload=None, events_list=None, *, The events. Only returned if ``event_list`` is not None. """ for idx, raw in enumerate(raws[1:], start=1): - _ensure_infos_match(info1=raws[0].info, info2=raw.info, - name=f'raws[{idx}]', on_mismatch=on_mismatch) + _ensure_infos_match( + info1=raws[0].info, + info2=raw.info, + name=f"raws[{idx}]", + on_mismatch=on_mismatch, + ) if events_list is not None: if len(events_list) != len(raws): - raise ValueError('`raws` and `event_list` are required ' - 'to be of the same length') + raise ValueError( + "`raws` and `event_list` are required " "to be of the same length" + ) first, last = zip(*[(r.first_samp, r.last_samp) for r in raws]) events = concatenate_events(events_list, first, last) raws[0].append(raws[1:], preload) @@ -2594,16 +3043,19 @@ def match_channel_orders(raws, copy=True): def _check_maxshield(allow_maxshield): """Warn or error about MaxShield.""" - msg = ('This file contains raw Internal Active ' - 'Shielding data. It may be distorted. Elekta ' - 'recommends it be run through MaxFilter to ' - 'produce reliable results. Consider closing ' - 'the file and running MaxFilter on the data.') + msg = ( + "This file contains raw Internal Active " + "Shielding data. It may be distorted. Elekta " + "recommends it be run through MaxFilter to " + "produce reliable results. Consider closing " + "the file and running MaxFilter on the data." + ) if allow_maxshield: - if not (isinstance(allow_maxshield, str) and - allow_maxshield == 'yes'): + if not (isinstance(allow_maxshield, str) and allow_maxshield == "yes"): warn(msg) else: - msg += (' Use allow_maxshield=True if you are sure you' - ' want to load the data despite this warning.') + msg += ( + " Use allow_maxshield=True if you are sure you" + " want to load the data despite this warning." + ) raise ValueError(msg) diff --git a/mne/io/besa/besa.py b/mne/io/besa/besa.py index 368506b6506..440bf89877c 100644 --- a/mne/io/besa/besa.py +++ b/mne/io/besa/besa.py @@ -27,12 +27,12 @@ def read_evoked_besa(fname, verbose=None): The evoked data in the .avr or .mul file. 
""" fname = Path(fname) - if fname.suffix == '.avr': + if fname.suffix == ".avr": return _read_evoked_besa_avr(fname, verbose) - elif fname.suffix == '.mul': + elif fname.suffix == ".mul": return _read_evoked_besa_mul(fname, verbose) else: - raise ValueError('Filename must end in either .avr or .mul') + raise ValueError("Filename must end in either .avr or .mul") @verbose @@ -44,7 +44,7 @@ def _read_evoked_besa_avr(fname, verbose): # There are two versions of .avr files. The old style, generated by # BESA 1, 2 and 3 does not define Nchan and does not have channel names # in the file. - new_style = 'Nchan=' in header + new_style = "Nchan=" in header if new_style: ch_names = f.readline().strip().split() else: @@ -58,62 +58,76 @@ def _read_evoked_besa_avr(fname, verbose): if new_style: if len(ch_names) != len(data): raise RuntimeError( - 'Mismatch between the number of channel names defined in ' - f'the .avr file ({len(ch_names)}) and the number of rows ' - f'in the data matrix ({len(data)}).') + "Mismatch between the number of channel names defined in " + f"the .avr file ({len(ch_names)}) and the number of rows " + f"in the data matrix ({len(data)})." + ) else: # Determine channel names from the .elp sidecar file if ch_types is not None: ch_names = list(ch_types.keys()) if len(ch_names) != len(data): - raise RuntimeError('Mismatch between the number of channels ' - f'defined in the .avr file ({len(data)}) ' - f'and .elp file ({len(ch_names)}).') + raise RuntimeError( + "Mismatch between the number of channels " + f"defined in the .avr file ({len(data)}) " + f"and .elp file ({len(ch_names)})." + ) else: - logger.info('No .elp file found and no channel names present in ' - 'the .avr file. Falling back to generic names. ') - ch_names = [f'CH{i + 1:02d}' for i in range(len(data))] + logger.info( + "No .elp file found and no channel names present in " + "the .avr file. Falling back to generic names. " + ) + ch_names = [f"CH{i + 1:02d}" for i in range(len(data))] # Consolidate channel types if ch_types is None: - logger.info('Marking all channels as EEG.') - ch_types = ['eeg'] * len(ch_names) + logger.info("Marking all channels as EEG.") + ch_types = ["eeg"] * len(ch_names) else: ch_types = [ch_types[ch] for ch in ch_names] # Go over all the header fields and make sure they are all defined to # something sensible. - if 'Npts' in fields: - fields['Npts'] = int(fields['Npts']) - if fields['Npts'] != data.shape[1]: - logger.warn(f'The size of the data matrix ({data.shape}) does not ' - f'match the "Npts" field ({fields["Npts"]}).') - if 'Nchan' in fields: - fields['Nchan'] = int(fields['Nchan']) - if fields['Nchan'] != data.shape[0]: - logger.warn(f'The size of the data matrix ({data.shape}) does not ' - f'match the "Nchan" field ({fields["Nchan"]}).') - if 'DI' in fields: - fields['DI'] = float(fields['DI']) + if "Npts" in fields: + fields["Npts"] = int(fields["Npts"]) + if fields["Npts"] != data.shape[1]: + logger.warn( + f"The size of the data matrix ({data.shape}) does not " + f'match the "Npts" field ({fields["Npts"]}).' + ) + if "Nchan" in fields: + fields["Nchan"] = int(fields["Nchan"]) + if fields["Nchan"] != data.shape[0]: + logger.warn( + f"The size of the data matrix ({data.shape}) does not " + f'match the "Nchan" field ({fields["Nchan"]}).' + ) + if "DI" in fields: + fields["DI"] = float(fields["DI"]) else: - raise RuntimeError('No "DI" field present. 
Could not determine ' - 'sampling frequency.') - if 'TSB' in fields: - fields['TSB'] = float(fields['TSB']) + raise RuntimeError( + 'No "DI" field present. Could not determine ' "sampling frequency." + ) + if "TSB" in fields: + fields["TSB"] = float(fields["TSB"]) else: - fields['TSB'] = 0 - if 'SB' in fields: - fields['SB'] = float(fields['SB']) + fields["TSB"] = 0 + if "SB" in fields: + fields["SB"] = float(fields["SB"]) else: - fields['SB'] = 1.0 - if 'SegmentName' not in fields: - fields['SegmentName'] = '' + fields["SB"] = 1.0 + if "SegmentName" not in fields: + fields["SegmentName"] = "" # Build the Evoked object based on the header fields. - info = create_info(ch_names, sfreq=1000 / fields['DI'], ch_types='eeg') - return EvokedArray(data / fields['SB'] / 1E6, info, - tmin=fields['TSB'] / 1000, - comment=fields['SegmentName'], verbose=verbose) + info = create_info(ch_names, sfreq=1000 / fields["DI"], ch_types="eeg") + return EvokedArray( + data / fields["SB"] / 1e6, + info, + tmin=fields["TSB"] / 1000, + comment=fields["SegmentName"], + verbose=verbose, + ) @verbose @@ -127,54 +141,66 @@ def _read_evoked_besa_mul(fname, verbose): data = np.loadtxt(fname, skiprows=2, ndmin=2) if len(ch_names) != data.shape[1]: - raise RuntimeError('Mismatch between the number of channel names ' - f'defined in the .mul file ({len(ch_names)}) ' - 'and the number of columns in the data matrix ' - f'({data.shape[1]}).') + raise RuntimeError( + "Mismatch between the number of channel names " + f"defined in the .mul file ({len(ch_names)}) " + "and the number of columns in the data matrix " + f"({data.shape[1]})." + ) # Consolidate channel types ch_types = _read_elp_sidecar(fname) if ch_types is None: - logger.info('Marking all channels as EEG.') - ch_types = ['eeg'] * len(ch_names) + logger.info("Marking all channels as EEG.") + ch_types = ["eeg"] * len(ch_names) else: ch_types = [ch_types[ch] for ch in ch_names] # Go over all the header fields and make sure they are all defined to # something sensible. - if 'TimePoints' in fields: - fields['TimePoints'] = int(fields['TimePoints']) - if fields['TimePoints'] != data.shape[0]: + if "TimePoints" in fields: + fields["TimePoints"] = int(fields["TimePoints"]) + if fields["TimePoints"] != data.shape[0]: + logger.warn( + f"The size of the data matrix ({data.shape}) does not " + f'match the "TimePoints" field ({fields["TimePoints"]}).' + ) + if "Channels" in fields: + fields["Channels"] = int(fields["Channels"]) + if fields["Channels"] != data.shape[1]: logger.warn( - f'The size of the data matrix ({data.shape}) does not ' - f'match the "TimePoints" field ({fields["TimePoints"]}).') - if 'Channels' in fields: - fields['Channels'] = int(fields['Channels']) - if fields['Channels'] != data.shape[1]: - logger.warn(f'The size of the data matrix ({data.shape}) does not ' - f'match the "Channels" field ({fields["Channels"]}).') - if 'SamplingInterval[ms]' in fields: - fields['SamplingInterval[ms]'] = float(fields['SamplingInterval[ms]']) + f"The size of the data matrix ({data.shape}) does not " + f'match the "Channels" field ({fields["Channels"]}).' + ) + if "SamplingInterval[ms]" in fields: + fields["SamplingInterval[ms]"] = float(fields["SamplingInterval[ms]"]) else: - raise RuntimeError('No "SamplingInterval[ms]" field present. Could ' - 'not determine sampling frequency.') - if 'BeginSweep[ms]' in fields: - fields['BeginSweep[ms]'] = float(fields['BeginSweep[ms]']) + raise RuntimeError( + 'No "SamplingInterval[ms]" field present. 
Could ' + "not determine sampling frequency." + ) + if "BeginSweep[ms]" in fields: + fields["BeginSweep[ms]"] = float(fields["BeginSweep[ms]"]) else: - fields['BeginSweep[ms]'] = 0.0 - if 'Bins/uV' in fields: - fields['Bins/uV'] = float(fields['Bins/uV']) + fields["BeginSweep[ms]"] = 0.0 + if "Bins/uV" in fields: + fields["Bins/uV"] = float(fields["Bins/uV"]) else: - fields['Bins/uV'] = 1 - if 'SegmentName' not in fields: - fields['SegmentName'] = '' + fields["Bins/uV"] = 1 + if "SegmentName" not in fields: + fields["SegmentName"] = "" # Build the Evoked object based on the header fields. - info = create_info(ch_names, sfreq=1000 / fields['SamplingInterval[ms]'], - ch_types=ch_types) - return EvokedArray(data.T / fields['Bins/uV'] / 1E6, info, - tmin=fields['BeginSweep[ms]'] / 1000, - comment=fields['SegmentName'], verbose=verbose) + info = create_info( + ch_names, sfreq=1000 / fields["SamplingInterval[ms]"], ch_types=ch_types + ) + return EvokedArray( + data.T / fields["Bins/uV"] / 1e6, + info, + tmin=fields["BeginSweep[ms]"] / 1000, + comment=fields["SegmentName"], + verbose=verbose, + ) def _parse_header(header): @@ -196,7 +222,7 @@ def _parse_header(header): """ parts = header.split() # Splits on one or more spaces name_val_pairs = zip(parts[::2], parts[1::2]) - return dict((name.replace('=', ''), val) for name, val in name_val_pairs) + return dict((name.replace("=", ""), val) for name, val in name_val_pairs) def _read_elp_sidecar(fname): @@ -218,13 +244,12 @@ def _read_elp_sidecar(fname): If the sidecar file exists, return a dictionary mapping channel names to channel types. Otherwise returns ``None``. """ - fname_elp = fname.parent / (fname.stem + '.elp') + fname_elp = fname.parent / (fname.stem + ".elp") if not fname_elp.exists(): - logger.info(f'No {fname_elp} file present containing electrode ' - 'information.') + logger.info(f"No {fname_elp} file present containing electrode " "information.") return None - logger.info(f'Reading electrode names and types from {fname_elp}') + logger.info(f"Reading electrode names and types from {fname_elp}") ch_types = OrderedDict() with open(fname_elp) as f: lines = f.readlines() @@ -235,9 +260,10 @@ def _read_elp_sidecar(fname): ch_types[ch_name] = ch_type.lower() else: # No channel types present - logger.info('No channel types present in .elp file. Marking all ' - 'channels as EEG.') + logger.info( + "No channel types present in .elp file. Marking all " "channels as EEG." 
+ )
 for line in lines:
 ch_name = line.split()[:1]
- ch_types[ch_name] = 'eeg'
+ ch_types[ch_name] = "eeg"
 return ch_types
diff --git a/mne/io/besa/tests/test_besa.py b/mne/io/besa/tests/test_besa.py
index 23097dd497f..fcaf32d651c 100644
--- a/mne/io/besa/tests/test_besa.py
+++ b/mne/io/besa/tests/test_besa.py
@@ -8,67 +8,67 @@
 FILE = Path(inspect.getfile(inspect.currentframe()))
-data_dir = FILE.parent / 'data'
-avr_file = data_dir / 'simulation.avr'
-avr_file_oldstyle = data_dir / 'simulation_oldstyle.avr'
-mul_file = data_dir / 'simulation.mul'
-montage = read_custom_montage(data_dir / 'simulation.elp')
+data_dir = FILE.parent / "data"
+avr_file = data_dir / "simulation.avr"
+avr_file_oldstyle = data_dir / "simulation_oldstyle.avr"
+mul_file = data_dir / "simulation.mul"
+montage = read_custom_montage(data_dir / "simulation.elp")
 @pytest.mark.filterwarnings("ignore:Fiducial point nasion not found")
-@pytest.mark.parametrize('fname', (avr_file, avr_file_oldstyle, mul_file))
+@pytest.mark.parametrize("fname", (avr_file, avr_file_oldstyle, mul_file))
 def test_read_evoked_besa(fname):
 """Test reading BESA .avr and .mul files."""
 ev = read_evoked_besa(fname)
 assert len(ev.ch_names) == len(ev.data) == 33
- assert ev.info['sfreq'] == 200
+ assert ev.info["sfreq"] == 200
 assert ev.tmin == -0.1
 assert len(ev.times) == 200
 assert ev.ch_names == montage.ch_names
- assert ev.comment == 'simulation'
+ assert ev.comment == "simulation"
 def test_read_evoked_besa_avr_incomplete(tmp_path):
 """Test reading incomplete BESA .avr files."""
 # Check old style .avr file without an .elp sidecar
- with open(f'{tmp_path}/missing.avr', 'w') as f:
- f.write('Npts= 1 TSB= 0 SB= 1.00 SC= 500.0 DI= 5\n0\n1\n2\n')
- ev = read_evoked_besa(f'{tmp_path}/missing.avr')
- assert ev.ch_names == ['CH01', 'CH02', 'CH03']
+ with open(f"{tmp_path}/missing.avr", "w") as f:
+ f.write("Npts= 1 TSB= 0 SB= 1.00 SC= 500.0 DI= 5\n0\n1\n2\n")
+ ev = read_evoked_besa(f"{tmp_path}/missing.avr")
+ assert ev.ch_names == ["CH01", "CH02", "CH03"]
 # Create BESA file with missing header fields and verify things don't break
- with open(f'{tmp_path}/missing.avr', 'w') as f:
- f.write('DI= 5\n0\n')
- ev = read_evoked_besa(f'{tmp_path}/missing.avr')
+ with open(f"{tmp_path}/missing.avr", "w") as f:
+ f.write("DI= 5\n0\n")
+ ev = read_evoked_besa(f"{tmp_path}/missing.avr")
 assert len(ev.ch_names) == len(ev.data) == 1
- assert ev.info['sfreq'] == 200
+ assert ev.info["sfreq"] == 200
 assert ev.tmin == 0
 assert len(ev.times) == 1
- assert ev.ch_names == ['CH01']
- assert ev.comment == ''
+ assert ev.ch_names == ["CH01"]
+ assert ev.comment == ""
 # The DI field (sample frequency) must exist
- with open(f'{tmp_path}/missing.avr', 'w') as f:
- f.write('Npts= 1 TSB= 0 SB= 1.00 SC= 500.0\n0\n')
+ with open(f"{tmp_path}/missing.avr", "w") as f:
+ f.write("Npts= 1 TSB= 0 SB= 1.00 SC= 500.0\n0\n")
 with pytest.raises(RuntimeError, match='No "DI" field present'):
- ev = read_evoked_besa(f'{tmp_path}/missing.avr')
+ ev = read_evoked_besa(f"{tmp_path}/missing.avr")
 def test_read_evoked_besa_mul_incomplete(tmp_path):
 """Test reading incomplete BESA .mul files."""
 # Create BESA file with missing header fields and verify things don't break
- with open(f'{tmp_path}/missing.mul', 'w') as f:
- f.write('SamplingInterval[ms]= 5\nCH1\n0\n')
- ev = read_evoked_besa(f'{tmp_path}/missing.mul')
+ with open(f"{tmp_path}/missing.mul", "w") as f:
+ f.write("SamplingInterval[ms]= 5\nCH1\n0\n")
+ ev = read_evoked_besa(f"{tmp_path}/missing.mul")
 assert len(ev.ch_names) == len(ev.data)
== 1 - assert ev.info['sfreq'] == 200 + assert ev.info["sfreq"] == 200 assert ev.tmin == 0 assert len(ev.times) == 1 - assert ev.ch_names == ['CH1'] - assert ev.comment == '' + assert ev.ch_names == ["CH1"] + assert ev.comment == "" # The SamplingInterval[ms] field (sample frequency) must exist - with open(f'{tmp_path}/missing.mul', 'w') as f: - f.write('TimePoints= 1 Channels= 1\nCH1\n0\n') + with open(f"{tmp_path}/missing.mul", "w") as f: + f.write("TimePoints= 1 Channels= 1\nCH1\n0\n") with pytest.raises(RuntimeError, match=r'No "SamplingInterval\[ms\]"'): - ev = read_evoked_besa(f'{tmp_path}/missing.mul') + ev = read_evoked_besa(f"{tmp_path}/missing.mul") diff --git a/mne/io/boxy/boxy.py b/mne/io/boxy/boxy.py index 3f3cdcdfbcf..fbe98723e34 100644 --- a/mne/io/boxy/boxy.py +++ b/mne/io/boxy/boxy.py @@ -58,16 +58,16 @@ class RawBOXY(BaseRaw): @verbose def __init__(self, fname, preload=False, verbose=None): - logger.info('Loading %s' % fname) + logger.info("Loading %s" % fname) # Read header file and grab some info. start_line = np.inf col_names = mrk_col = filetype = mrk_data = end_line = None raw_extras = dict() - raw_extras['offsets'] = list() # keep track of our offsets + raw_extras["offsets"] = list() # keep track of our offsets sfreq = None fname = str(_check_fname(fname, "read", True, "fname")) - with open(fname, 'r') as fid: + with open(fname, "r") as fid: line_num = 0 i_line = fid.readline() while i_line: @@ -75,67 +75,71 @@ def __init__(self, fname, preload=False, verbose=None): if line_num >= start_line: assert col_names is not None assert filetype is not None - if '#DATA ENDS' in i_line: + if "#DATA ENDS" in i_line: # Data ends just before this. end_line = line_num break if mrk_col is not None: - if filetype == 'non-parsed': + if filetype == "non-parsed": # Non-parsed files have different lines lengths. 
- crnt_line = i_line.rsplit(' ')[0] - temp_data = re.findall( - r'[-+]?\d*\.?\d+', crnt_line) + crnt_line = i_line.rsplit(" ")[0] + temp_data = re.findall(r"[-+]?\d*\.?\d+", crnt_line) if len(temp_data) == len(col_names): - mrk_data.append(float( - re.findall(r'[-+]?\d*\.?\d+', crnt_line) - [mrk_col])) + mrk_data.append( + float( + re.findall(r"[-+]?\d*\.?\d+", crnt_line)[ + mrk_col + ] + ) + ) else: - crnt_line = i_line.rsplit(' ')[0] - mrk_data.append(float(re.findall( - r'[-+]?\d*\.?\d+', crnt_line)[mrk_col])) - raw_extras['offsets'].append(fid.tell()) + crnt_line = i_line.rsplit(" ")[0] + mrk_data.append( + float(re.findall(r"[-+]?\d*\.?\d+", crnt_line)[mrk_col]) + ) + raw_extras["offsets"].append(fid.tell()) # now proceed with more standard header parsing - elif 'BOXY.EXE:' in i_line: - boxy_ver = re.findall(r'\d*\.\d+', - i_line.rsplit(' ')[-1])[0] + elif "BOXY.EXE:" in i_line: + boxy_ver = re.findall(r"\d*\.\d+", i_line.rsplit(" ")[-1])[0] # Check that the BOXY version is supported - if boxy_ver not in ['0.40', '0.84']: - raise RuntimeError('MNE has not been tested with BOXY ' - 'version (%s)' % boxy_ver) - elif 'Detector Channels' in i_line: - raw_extras['detect_num'] = int(i_line.rsplit(' ')[0]) - elif 'External MUX Channels' in i_line: - raw_extras['source_num'] = int(i_line.rsplit(' ')[0]) - elif 'Update Rate (Hz)' in i_line or \ - 'Updata Rate (Hz)' in i_line: + if boxy_ver not in ["0.40", "0.84"]: + raise RuntimeError( + "MNE has not been tested with BOXY " + "version (%s)" % boxy_ver + ) + elif "Detector Channels" in i_line: + raw_extras["detect_num"] = int(i_line.rsplit(" ")[0]) + elif "External MUX Channels" in i_line: + raw_extras["source_num"] = int(i_line.rsplit(" ")[0]) + elif "Update Rate (Hz)" in i_line or "Updata Rate (Hz)" in i_line: # Version 0.40 of the BOXY recording software # (and possibly other versions lower than 0.84) contains a # typo in the raw data file where 'Update Rate' is spelled # "Updata Rate. This will account for this typo. - sfreq = float(i_line.rsplit(' ')[0]) - elif '#DATA BEGINS' in i_line: + sfreq = float(i_line.rsplit(" ")[0]) + elif "#DATA BEGINS" in i_line: # Data should start a couple lines later. start_line = line_num + 3 elif line_num == start_line - 2: # Grab names for each column of data. - raw_extras['col_names'] = col_names = re.findall( - r'\w+\-\w+|\w+\-\d+|\w+', i_line.rsplit(' ')[0]) - if 'exmux' in col_names: + raw_extras["col_names"] = col_names = re.findall( + r"\w+\-\w+|\w+\-\d+|\w+", i_line.rsplit(" ")[0] + ) + if "exmux" in col_names: # Change filetype based on data organisation. - filetype = 'non-parsed' + filetype = "non-parsed" else: - filetype = 'parsed' - if 'digaux' in col_names: - mrk_col = col_names.index('digaux') + filetype = "parsed" + if "digaux" in col_names: + mrk_col = col_names.index("digaux") mrk_data = list() # raw_extras['offsets'].append(fid.tell()) elif line_num == start_line - 1: - raw_extras['offsets'].append(fid.tell()) + raw_extras["offsets"].append(fid.tell()) line_num += 1 i_line = fid.readline() assert sfreq is not None - raw_extras.update( - filetype=filetype, start_line=start_line, end_line=end_line) + raw_extras.update(filetype=filetype, start_line=start_line, end_line=end_line) # Label each channel in our data, for each data type (DC, AC, Ph). 
# Data is organised by channels x timepoint, where the first @@ -144,30 +148,36 @@ def __init__(self, fname, preload=False, verbose=None): ch_names = list() ch_types = list() cals = list() - for det_num in range(raw_extras['detect_num']): - for src_num in range(raw_extras['source_num']): + for det_num in range(raw_extras["detect_num"]): + for src_num in range(raw_extras["source_num"]): for i_type, ch_type in [ - ('DC', 'fnirs_cw_amplitude'), - ('AC', 'fnirs_fd_ac_amplitude'), - ('Ph', 'fnirs_fd_phase')]: - ch_names.append( - f'S{src_num + 1}_D{det_num + 1} {i_type}') + ("DC", "fnirs_cw_amplitude"), + ("AC", "fnirs_fd_ac_amplitude"), + ("Ph", "fnirs_fd_phase"), + ]: + ch_names.append(f"S{src_num + 1}_D{det_num + 1} {i_type}") ch_types.append(ch_type) - cals.append(np.pi / 180. if i_type == 'Ph' else 1.) + cals.append(np.pi / 180.0 if i_type == "Ph" else 1.0) # Create info structure. info = create_info(ch_names, sfreq, ch_types) - for ch, cal in zip(info['chs'], cals): - ch['cal'] = cal + for ch, cal in zip(info["chs"], cals): + ch["cal"] = cal # Determine how long our data is. delta = end_line - start_line - assert len(raw_extras['offsets']) == delta + 1 - if filetype == 'non-parsed': - delta //= (raw_extras['source_num']) + assert len(raw_extras["offsets"]) == delta + 1 + if filetype == "non-parsed": + delta //= raw_extras["source_num"] super(RawBOXY, self).__init__( - info, preload, filenames=[fname], first_samps=[0], - last_samps=[delta - 1], raw_extras=[raw_extras], verbose=verbose) + info, + preload, + filenames=[fname], + first_samps=[0], + last_samps=[delta - 1], + raw_extras=[raw_extras], + verbose=verbose, + ) # Now let's grab our markers, if they are present. if mrk_data is not None: @@ -198,33 +208,33 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): Regardless of type, output has (n_montages x n_sources x n_detectors + n_marker_channels) rows, and (n_timepoints x n_blocks) columns. """ - source_num = self._raw_extras[fi]['source_num'] - detect_num = self._raw_extras[fi]['detect_num'] - start_line = self._raw_extras[fi]['start_line'] - end_line = self._raw_extras[fi]['end_line'] - filetype = self._raw_extras[fi]['filetype'] - col_names = self._raw_extras[fi]['col_names'] - offsets = self._raw_extras[fi]['offsets'] + source_num = self._raw_extras[fi]["source_num"] + detect_num = self._raw_extras[fi]["detect_num"] + start_line = self._raw_extras[fi]["start_line"] + end_line = self._raw_extras[fi]["end_line"] + filetype = self._raw_extras[fi]["filetype"] + col_names = self._raw_extras[fi]["col_names"] + offsets = self._raw_extras[fi]["offsets"] boxy_file = self._filenames[fi] # Non-parsed multiplexes sources, so we need source_num times as many # lines in that case - if filetype == 'parsed': + if filetype == "parsed": start_read = start_line + start stop_read = start_read + (stop - start) else: - assert filetype == 'non-parsed' + assert filetype == "non-parsed" start_read = start_line + start * source_num stop_read = start_read + (stop - start) * source_num assert start_read >= start_line assert stop_read <= end_line # Possible detector names. - detectors = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'[:detect_num] + detectors = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[:detect_num] # Loop through our data. 
one = np.zeros((len(col_names), stop_read - start_read)) - with open(boxy_file, 'r') as fid: + with open(boxy_file, "r") as fid: # Just a more efficient version of this: # ii = 0 # for line_num, i_line in enumerate(fid): @@ -238,26 +248,36 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): fid.seek(offsets[start_read - start_line], 0) for oo in one.T: i_data = fid.readline().strip().split() - oo[:len(i_data)] = i_data + oo[: len(i_data)] = i_data # in theory we could index in the loop above, but it's painfully slow, # so let's just take a hopefully minor memory hit - if filetype == 'non-parsed': - ch_idxs = [col_names.index(f'{det}-{i_type}') - for det in detectors - for i_type in ['DC', 'AC', 'Ph']] - one = one[ch_idxs].reshape( # each "time point" multiplexes srcs - len(detectors), 3, -1, source_num - ).transpose( # reorganize into (det, source, DC/AC/Ph, t) order - 0, 3, 1, 2 - ).reshape( # reshape the way we store it (det x source x DAP, t) - len(detectors) * source_num * 3, -1) + if filetype == "non-parsed": + ch_idxs = [ + col_names.index(f"{det}-{i_type}") + for det in detectors + for i_type in ["DC", "AC", "Ph"] + ] + one = ( + one[ch_idxs] + .reshape( # each "time point" multiplexes srcs + len(detectors), 3, -1, source_num + ) + .transpose( # reorganize into (det, source, DC/AC/Ph, t) order + 0, 3, 1, 2 + ) + .reshape( # reshape the way we store it (det x source x DAP, t) + len(detectors) * source_num * 3, -1 + ) + ) else: - assert filetype == 'parsed' - ch_idxs = [col_names.index(f'{det}-{i_type}{si + 1}') - for det in detectors - for si in range(source_num) - for i_type in ['DC', 'AC', 'Ph']] + assert filetype == "parsed" + ch_idxs = [ + col_names.index(f"{det}-{i_type}{si + 1}") + for det in detectors + for si in range(source_num) + for i_type in ["DC", "AC", "Ph"] + ] one = one[ch_idxs] # Place our data into the data object in place. diff --git a/mne/io/boxy/tests/test_boxy.py b/mne/io/boxy/tests/test_boxy.py index e9cf09ee4cf..0058075f107 100644 --- a/mne/io/boxy/tests/test_boxy.py +++ b/mne/io/boxy/tests/test_boxy.py @@ -4,8 +4,7 @@ import pytest import numpy as np -from numpy.testing import (assert_allclose, assert_array_equal, - assert_array_less) +from numpy.testing import assert_allclose, assert_array_equal, assert_array_less import scipy.io as spio from mne import pick_types @@ -15,10 +14,7 @@ data_path = testing.data_path(download=False) boxy_0_40 = ( - data_path - / "BOXY" - / "boxy_0_40_recording" - / "boxy_0_40_notriggers_unparsed.txt" + data_path / "BOXY" / "boxy_0_40_recording" / "boxy_0_40_notriggers_unparsed.txt" ) p_pod_0_40 = ( data_path @@ -34,10 +30,7 @@ / "boxy_0_84_triggers_unparsed.txt" ) boxy_0_84_parsed = ( - data_path - / "BOXY" - / "boxy_0_84_digaux_recording" - / "boxy_0_84_triggers_parsed.txt" + data_path / "BOXY" / "boxy_0_84_digaux_recording" / "boxy_0_84_triggers_parsed.txt" ) p_pod_0_84 = ( data_path @@ -50,21 +43,22 @@ def _assert_ppod(raw, p_pod_file): have_types = raw.get_channel_types(unique=True) - assert 'fnirs_fd_phase' in raw, have_types - assert 'fnirs_cw_amplitude' in raw, have_types - assert 'fnirs_fd_ac_amplitude' in raw, have_types + assert "fnirs_fd_phase" in raw, have_types + assert "fnirs_cw_amplitude" in raw, have_types + assert "fnirs_fd_ac_amplitude" in raw, have_types ppod_data = spio.loadmat(p_pod_file) # Compare MNE loaded data to p_pod loaded data. 
- map_ = dict(dc='fnirs_cw_amplitude', ac='fnirs_fd_ac_amplitude', - ph='fnirs_fd_phase') + map_ = dict( + dc="fnirs_cw_amplitude", ac="fnirs_fd_ac_amplitude", ph="fnirs_fd_phase" + ) for key, value in map_.items(): ppod = ppod_data[key].T m = np.median(np.abs(ppod)) assert 1e-1 < m < 1e5, key # our atol is meaningful atol = m * 1e-10 py = raw.get_data(value) - if key == 'ph': # radians + if key == "ph": # radians assert_array_less(-np.pi, py) assert_array_less(py, 3 * np.pi) py = np.rad2deg(py) @@ -75,33 +69,59 @@ def _assert_ppod(raw, p_pod_file): def test_boxy_load(): """Test reading BOXY files.""" raw = read_raw_boxy(boxy_0_40, verbose=True) - assert raw.info['sfreq'] == 62.5 + assert raw.info["sfreq"] == 62.5 _assert_ppod(raw, p_pod_0_40) # Grab our different data types. - mne_ph = raw.copy().pick(picks='fnirs_fd_phase') - mne_dc = raw.copy().pick(picks='fnirs_cw_amplitude') - mne_ac = raw.copy().pick(picks='fnirs_fd_ac_amplitude') + mne_ph = raw.copy().pick(picks="fnirs_fd_phase") + mne_dc = raw.copy().pick(picks="fnirs_cw_amplitude") + mne_ac = raw.copy().pick(picks="fnirs_fd_ac_amplitude") # Check channel names. - first_chans = ['S1_D1', 'S2_D1', 'S3_D1', 'S4_D1', 'S5_D1', - 'S6_D1', 'S7_D1', 'S8_D1', 'S9_D1', 'S10_D1'] - last_chans = ['S1_D8', 'S2_D8', 'S3_D8', 'S4_D8', 'S5_D8', - 'S6_D8', 'S7_D8', 'S8_D8', 'S9_D8', 'S10_D8'] - - assert mne_dc.info['ch_names'][:10] == [i_chan + ' ' + 'DC' - for i_chan in first_chans] - assert mne_ac.info['ch_names'][:10] == [i_chan + ' ' + 'AC' - for i_chan in first_chans] - assert mne_ph.info['ch_names'][:10] == [i_chan + ' ' + 'Ph' - for i_chan in first_chans] - - assert mne_dc.info['ch_names'][70::] == [i_chan + ' ' + 'DC' - for i_chan in last_chans] - assert mne_ac.info['ch_names'][70::] == [i_chan + ' ' + 'AC' - for i_chan in last_chans] - assert mne_ph.info['ch_names'][70::] == [i_chan + ' ' + 'Ph' - for i_chan in last_chans] + first_chans = [ + "S1_D1", + "S2_D1", + "S3_D1", + "S4_D1", + "S5_D1", + "S6_D1", + "S7_D1", + "S8_D1", + "S9_D1", + "S10_D1", + ] + last_chans = [ + "S1_D8", + "S2_D8", + "S3_D8", + "S4_D8", + "S5_D8", + "S6_D8", + "S7_D8", + "S8_D8", + "S9_D8", + "S10_D8", + ] + + assert mne_dc.info["ch_names"][:10] == [ + i_chan + " " + "DC" for i_chan in first_chans + ] + assert mne_ac.info["ch_names"][:10] == [ + i_chan + " " + "AC" for i_chan in first_chans + ] + assert mne_ph.info["ch_names"][:10] == [ + i_chan + " " + "Ph" for i_chan in first_chans + ] + + assert mne_dc.info["ch_names"][70::] == [ + i_chan + " " + "DC" for i_chan in last_chans + ] + assert mne_ac.info["ch_names"][70::] == [ + i_chan + " " + "AC" for i_chan in last_chans + ] + assert mne_ph.info["ch_names"][70::] == [ + i_chan + " " + "Ph" for i_chan in last_chans + ] # Since this data set has no 'digaux' for creating trigger annotations, # let's make sure our Raw object has no annotations. @@ -109,7 +129,7 @@ def test_boxy_load(): @testing.requires_testing_data -@pytest.mark.parametrize('fname', (boxy_0_84, boxy_0_84_parsed)) +@pytest.mark.parametrize("fname", (boxy_0_84, boxy_0_84_parsed)) def test_boxy_filetypes(fname): """Test reading parsed and unparsed BOXY data files.""" # BOXY data files can be saved in two formats (parsed and unparsed) which @@ -127,54 +147,49 @@ def test_boxy_filetypes(fname): # files are comparable, then we will compare the MNE loaded data between # parsed and unparsed files. 
raw = read_raw_boxy(fname, verbose=True) - assert raw.info['sfreq'] == 79.4722 + assert raw.info["sfreq"] == 79.4722 _assert_ppod(raw, p_pod_0_84) # Grab our different data types. - unp_dc = raw.copy().pick('fnirs_cw_amplitude') - unp_ac = raw.copy().pick('fnirs_fd_ac_amplitude') - unp_ph = raw.copy().pick('fnirs_fd_phase') + unp_dc = raw.copy().pick("fnirs_cw_amplitude") + unp_ac = raw.copy().pick("fnirs_fd_ac_amplitude") + unp_ph = raw.copy().pick("fnirs_fd_phase") # Check channel names. - chans = ['S1_D1', 'S2_D1', 'S3_D1', 'S4_D1', - 'S5_D1', 'S6_D1', 'S7_D1', 'S8_D1'] + chans = ["S1_D1", "S2_D1", "S3_D1", "S4_D1", "S5_D1", "S6_D1", "S7_D1", "S8_D1"] - assert unp_dc.info['ch_names'] == [i_chan + ' ' + 'DC' - for i_chan in chans] - assert unp_ac.info['ch_names'] == [i_chan + ' ' + 'AC' - for i_chan in chans] - assert unp_ph.info['ch_names'] == [i_chan + ' ' + 'Ph' - for i_chan in chans] + assert unp_dc.info["ch_names"] == [i_chan + " " + "DC" for i_chan in chans] + assert unp_ac.info["ch_names"] == [i_chan + " " + "AC" for i_chan in chans] + assert unp_ph.info["ch_names"] == [i_chan + " " + "Ph" for i_chan in chans] @testing.requires_testing_data -@pytest.mark.parametrize('fname', (boxy_0_84, boxy_0_84_parsed)) +@pytest.mark.parametrize("fname", (boxy_0_84, boxy_0_84_parsed)) def test_boxy_digaux(fname): """Test reading BOXY files and generating annotations from digaux.""" srate = 79.4722 raw = read_raw_boxy(fname, verbose=True) # Grab our different data types. - picks_dc = pick_types(raw.info, fnirs='fnirs_cw_amplitude') - picks_ac = pick_types(raw.info, fnirs='fnirs_fd_ac_amplitude') - picks_ph = pick_types(raw.info, fnirs='fnirs_fd_phase') + picks_dc = pick_types(raw.info, fnirs="fnirs_cw_amplitude") + picks_ac = pick_types(raw.info, fnirs="fnirs_fd_ac_amplitude") + picks_ph = pick_types(raw.info, fnirs="fnirs_fd_phase") assert_array_equal(picks_dc, np.arange(0, 8) * 3 + 0) assert_array_equal(picks_ac, np.arange(0, 8) * 3 + 1) assert_array_equal(picks_ph, np.arange(0, 8) * 3 + 2) # Check that our event order matches what we expect. - event_list = ['1.0', '2.0', '3.0', '4.0', '5.0'] + event_list = ["1.0", "2.0", "3.0", "4.0", "5.0"] assert_array_equal(raw.annotations.description, event_list) # Check that our event timings are what we expect. - event_onset = [i_time * (1.0 / srate) for i_time in - [105, 185, 265, 344, 424]] + event_onset = [i_time * (1.0 / srate) for i_time in [105, 185, 265, 344, 424]] assert_allclose(raw.annotations.onset, event_onset, atol=1e-6) # Now let's compare parsed and unparsed events to p_pod loaded digaux. # Load our p_pod data. ppod_data = spio.loadmat(p_pod_0_84) - ppod_digaux = np.transpose(ppod_data['digaux'])[0] + ppod_digaux = np.transpose(ppod_data["digaux"])[0] # Now let's get our triggers from the p_pod digaux. # We only want the first instance of each trigger. 
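The hunk that follows reformats the loop that keeps only the first sample of each digaux trigger pulse. As a minimal, self-contained sketch of that rising-edge idea (the toy digaux trace, the srate value, and the variable names below are illustrative assumptions, not code from the patch):

import numpy as np

srate = 79.4722  # sampling rate of the 0.84 test recordings
digaux = np.array([0, 0, 3, 3, 3, 0, 0, 5, 5, 0])  # toy trigger trace

# A sample marks a trigger onset when it is nonzero and differs from the
# previous sample, i.e. the rising edge of each trigger pulse.
prev = np.concatenate(([0], digaux[:-1]))
mrk_idx = np.flatnonzero((digaux != 0) & (digaux != prev))

onset = mrk_idx * (1.0 / srate)  # sample index -> seconds
description = [str(float(v)) for v in digaux[mrk_idx]]
print(onset, description)  # two onsets, labeled '3.0' and '5.0'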
@@ -192,14 +207,13 @@ def test_boxy_digaux(fname): tmp_dur = 0 prev_mrk = i_mrk onset = np.asarray([i_mrk * (1.0 / srate) for i_mrk in mrk_idx]) - description = np.asarray([str(float(i_mrk))for i_mrk in - ppod_digaux[mrk_idx]]) + description = np.asarray([str(float(i_mrk)) for i_mrk in ppod_digaux[mrk_idx]]) assert_array_equal(raw.annotations.description, description) assert_allclose(raw.annotations.onset, onset, atol=1e-6) @testing.requires_testing_data -@pytest.mark.parametrize('fname', (boxy_0_40, boxy_0_84, boxy_0_84_parsed)) +@pytest.mark.parametrize("fname", (boxy_0_40, boxy_0_84, boxy_0_84_parsed)) def test_raw_properties(fname): """Test raw reader properties.""" _test_raw_reader(read_raw_boxy, fname=fname, boundary_decimal=1) diff --git a/mne/io/brainvision/brainvision.py b/mne/io/brainvision/brainvision.py index 892f189fca2..495721dfd85 100644 --- a/mne/io/brainvision/brainvision.py +++ b/mne/io/brainvision/brainvision.py @@ -62,29 +62,43 @@ class RawBrainVision(BaseRaw): """ @verbose - def __init__(self, vhdr_fname, - eog=('HEOGL', 'HEOGR', 'VEOGb'), misc='auto', - scale=1., preload=False, verbose=None): # noqa: D107 + def __init__( + self, + vhdr_fname, + eog=("HEOGL", "HEOGR", "VEOGb"), + misc="auto", + scale=1.0, + preload=False, + verbose=None, + ): # noqa: D107 # Channel info and events - logger.info('Extracting parameters from %s...' % vhdr_fname) + logger.info("Extracting parameters from %s..." % vhdr_fname) hdr_fname = op.abspath(vhdr_fname) ext = op.splitext(hdr_fname)[-1] - ahdr_format = True if ext == '.ahdr' else False - (info, data_fname, fmt, order, n_samples, mrk_fname, montage, - orig_units) = _get_hdr_info(hdr_fname, eog, misc, scale) - - with open(data_fname, 'rb') as f: + ahdr_format = True if ext == ".ahdr" else False + ( + info, + data_fname, + fmt, + order, + n_samples, + mrk_fname, + montage, + orig_units, + ) = _get_hdr_info(hdr_fname, eog, misc, scale) + + with open(data_fname, "rb") as f: if isinstance(fmt, dict): # ASCII, this will be slow :( - if order == 'F': # multiplexed, channels in columns + if order == "F": # multiplexed, channels in columns n_skip = 0 - for ii in range(int(fmt['skiplines'])): + for ii in range(int(fmt["skiplines"])): n_skip += len(f.readline()) offsets = np.cumsum([n_skip] + [len(line) for line in f]) n_samples = len(offsets) - 1 - elif order == 'C': # vectorized, channels, in rows + elif order == "C": # vectorized, channels, in rows raise NotImplementedError() else: - n_data_ch = int(info['nchan']) + n_data_ch = int(info["nchan"]) f.seek(0, os.SEEK_END) n_samples = f.tell() dtype_bytes = _fmt_byte_dict[fmt] @@ -92,22 +106,26 @@ def __init__(self, vhdr_fname, n_samples = n_samples // (dtype_bytes * n_data_ch) orig_format = "single" if isinstance(fmt, dict) else fmt - raw_extras = dict( - offsets=offsets, fmt=fmt, order=order, n_samples=n_samples) + raw_extras = dict(offsets=offsets, fmt=fmt, order=order, n_samples=n_samples) super(RawBrainVision, self).__init__( - info, last_samps=[n_samples - 1], filenames=[data_fname], - orig_format=orig_format, preload=preload, verbose=verbose, - raw_extras=[raw_extras], orig_units=orig_units) + info, + last_samps=[n_samples - 1], + filenames=[data_fname], + orig_format=orig_format, + preload=preload, + verbose=verbose, + raw_extras=[raw_extras], + orig_units=orig_units, + ) self.set_montage(montage) settings, cfg, cinfo, _ = _aux_hdr_info(hdr_fname) split_settings = settings.splitlines() - self.impedances = _parse_impedance(split_settings, - self.info['meas_date']) + self.impedances = 
_parse_impedance(split_settings, self.info["meas_date"]) # Get annotations from marker file - annots = read_annotations(mrk_fname, info['sfreq']) + annots = read_annotations(mrk_fname, info["sfreq"]) self.set_annotations(annots) # Drop the fake ahdr channel if needed @@ -117,40 +135,52 @@ def __init__(self, vhdr_fname, def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of raw data.""" # read data - n_data_ch = self._raw_extras[fi]['orig_nchan'] - fmt = self._raw_extras[fi]['fmt'] - if self._raw_extras[fi]['order'] == 'C': + n_data_ch = self._raw_extras[fi]["orig_nchan"] + fmt = self._raw_extras[fi]["fmt"] + if self._raw_extras[fi]["order"] == "C": _read_segments_c(self, data, idx, fi, start, stop, cals, mult) elif isinstance(fmt, str): dtype = _fmt_dtype_dict[fmt] - _read_segments_file(self, data, idx, fi, start, stop, cals, mult, - dtype=dtype, n_channels=n_data_ch) + _read_segments_file( + self, + data, + idx, + fi, + start, + stop, + cals, + mult, + dtype=dtype, + n_channels=n_data_ch, + ) else: - offsets = self._raw_extras[fi]['offsets'] - with open(self._filenames[fi], 'rb') as fid: + offsets = self._raw_extras[fi]["offsets"] + with open(self._filenames[fi], "rb") as fid: fid.seek(offsets[start]) block = np.empty((n_data_ch, stop - start)) for ii in range(stop - start): - line = fid.readline().decode('ASCII') + line = fid.readline().decode("ASCII") line = line.strip() # Not sure why we special-handle the "," character here, # but let's just keep this for historical and backward- # compat reasons - if (isinstance(fmt, dict) and - 'decimalsymbol' in fmt and - fmt['decimalsymbol'] != '.'): - line = line.replace(',', '.') - - if ' ' in line: + if ( + isinstance(fmt, dict) + and "decimalsymbol" in fmt + and fmt["decimalsymbol"] != "." + ): + line = line.replace(",", ".") + + if " " in line: line_data = line.split() - elif ',' in line: + elif "," in line: # likely exported from BrainVision Analyzer? - line_data = line.split(',') + line_data = line.split(",") else: raise RuntimeError( - 'Unknown BrainVision data format encountered. ' - 'Please contact the MNE-Python developers.' + "Unknown BrainVision data format encountered. " + "Please contact the MNE-Python developers." ) block[:n_data_ch, ii] = [float(part) for part in line_data] @@ -159,13 +189,13 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): def _read_segments_c(raw, data, idx, fi, start, stop, cals, mult): """Read chunk of vectorized raw data.""" - n_samples = raw._raw_extras[fi]['n_samples'] - fmt = raw._raw_extras[fi]['fmt'] + n_samples = raw._raw_extras[fi]["n_samples"] + fmt = raw._raw_extras[fi]["fmt"] dtype = _fmt_dtype_dict[fmt] n_bytes = _fmt_byte_dict[fmt] - n_channels = raw._raw_extras[fi]['orig_nchan'] + n_channels = raw._raw_extras[fi]["orig_nchan"] block = np.zeros((n_channels, stop - start)) - with open(raw._filenames[fi], 'rb', buffering=0) as fid: + with open(raw._filenames[fi], "rb", buffering=0) as fid: ids = np.arange(idx.start, idx.stop) if isinstance(idx, slice) else idx for ch_id in ids: fid.seek(start * n_bytes + ch_id * n_bytes * n_samples) @@ -194,14 +224,14 @@ def _read_mrk(fname): recording time is found. """ # read marker file - with open(fname, 'rb') as fid: + with open(fname, "rb") as fid: txt = fid.read() # we don't actually need to know the coding for the header line. 
# the characters in it all belong to ASCII and are thus the # same in Latin-1 and UTF-8 - header = txt.decode('ascii', 'ignore').split('\n')[0].strip() - _check_bv_version(header, 'marker') + header = txt.decode("ascii", "ignore").split("\n")[0].strip() + _check_bv_version(header, "marker") # although the markers themselves are guaranteed to be ASCII (they # consist of numbers and a few reserved words), we should still @@ -212,60 +242,59 @@ def _read_mrk(fname): try: # if there is an explicit codepage set, use it # we pretend like it's ascii when searching for the codepage - cp_setting = re.search('Codepage=(.+)', - txt.decode('ascii', 'ignore'), - re.IGNORECASE & re.MULTILINE) - codepage = 'utf-8' + cp_setting = re.search( + "Codepage=(.+)", txt.decode("ascii", "ignore"), re.IGNORECASE & re.MULTILINE + ) + codepage = "utf-8" if cp_setting: codepage = cp_setting.group(1).strip() # BrainAmp Recorder also uses ANSI codepage # an ANSI codepage raises a LookupError exception # python recognize ANSI decoding as cp1252 - if codepage == 'ANSI': - codepage = 'cp1252' + if codepage == "ANSI": + codepage = "cp1252" txt = txt.decode(codepage) except UnicodeDecodeError: # if UTF-8 (new standard) or explicit codepage setting fails, # fallback to Latin-1, which is Windows default and implicit # standard in older recordings - txt = txt.decode('latin-1') + txt = txt.decode("latin-1") # extract Marker Infos block m = re.search(r"\[Marker Infos\]", txt, re.IGNORECASE) if not m: - return np.array(list()), np.array(list()), np.array(list()), '' + return np.array(list()), np.array(list()), np.array(list()), "" - mk_txt = txt[m.end():] + mk_txt = txt[m.end() :] m = re.search(r"^\[.*\]$", mk_txt) if m: - mk_txt = mk_txt[:m.start()] + mk_txt = mk_txt[: m.start()] # extract event information items = re.findall(r"^Mk\d+=(.*)", mk_txt, re.MULTILINE) onset, duration, description = list(), list(), list() - date_str = '' + date_str = "" for info in items: - info_data = info.split(',') + info_data = info.split(",") mtype, mdesc, this_onset, this_duration = info_data[:4] # commas in mtype and mdesc are handled as "\1". convert back to comma - mtype = mtype.replace(r'\1', ',') - mdesc = mdesc.replace(r'\1', ',') - if date_str == '' and len(info_data) == 5 and mtype == 'New Segment': + mtype = mtype.replace(r"\1", ",") + mdesc = mdesc.replace(r"\1", ",") + if date_str == "" and len(info_data) == 5 and mtype == "New Segment": # to handle the origin of time and handle the presence of multiple # New Segment annotations. We only keep the first one that is # different from an empty string for date_str. date_str = info_data[-1] - this_duration = (int(this_duration) - if this_duration.isdigit() else 0) + this_duration = int(this_duration) if this_duration.isdigit() else 0 duration.append(this_duration) onset.append(int(this_onset) - 1) # BV is 1-indexed, not 0-indexed - description.append(mtype + '/' + mdesc) + description.append(mtype + "/" + mdesc) return np.array(onset), np.array(duration), np.array(description), date_str -def _read_annotations_brainvision(fname, sfreq='auto'): +def _read_annotations_brainvision(fname, sfreq="auto"): """Create Annotations from BrainVision vmrk/amrk. 
This function reads a .vmrk or .amrk file and makes an
@@ -292,20 +321,20 @@ def _read_annotations_brainvision(fname, sfreq="auto"):
 onset, duration, description, date_str = _read_mrk(fname)
 orig_time = _str_to_meas_date(date_str)
- if sfreq == 'auto':
- hdr_fname = op.splitext(fname)[0] + '.vhdr'
+ if sfreq == "auto":
+ hdr_fname = op.splitext(fname)[0] + ".vhdr"
 # if vhdr file does not exist assume that the format is ahdr
 if not op.exists(hdr_fname):
- hdr_fname = op.splitext(fname)[0] + '.ahdr'
+ hdr_fname = op.splitext(fname)[0] + ".ahdr"
 logger.info("Finding 'sfreq' from header file: %s" % hdr_fname)
 _, _, _, info = _aux_hdr_info(hdr_fname)
- sfreq = info['sfreq']
+ sfreq = info["sfreq"]
 onset = np.array(onset, dtype=float) / sfreq
 duration = np.array(duration, dtype=float) / sfreq
- annotations = Annotations(onset=onset, duration=duration,
- description=description,
- orig_time=orig_time)
+ annotations = Annotations(
+ onset=onset, duration=duration, description=description, orig_time=orig_time
+ )
 return annotations
@@ -316,13 +345,14 @@ def _check_bv_version(header, kind):
 %r. Contact MNE-Python developers for support."""
 # optional space, optional Core or V-Amp, optional Exchange,
 # Version/Header, optional comma, 1/2
- _data_re = (r"Brain ?Vision( Core| V-Amp)? Data( Exchange)? "
- r"%s File,? Version %s\.0")
+ _data_re = (
+ r"Brain ?Vision( Core| V-Amp)? Data( Exchange)? " r"%s File,? Version %s\.0"
+ )
- assert kind in ('header', 'marker')
+ assert kind in ("header", "marker")
- if header == '':
- warn(f'Missing header in {kind} file.')
+ if header == "":
+ warn(f"Missing header in {kind} file.")
 for version in range(1, 3):
 this_re = _data_re % (kind.capitalize(), version)
 if re.search(this_re, header) is not None:
@@ -331,37 +361,39 @@ def _check_bv_version(header, kind):
 warn(_data_err % (kind, header))
-_orientation_dict = dict(MULTIPLEXED='F', VECTORIZED='C')
-_fmt_dict = dict(INT_16='short', INT_32='int', IEEE_FLOAT_32='single')
+_orientation_dict = dict(MULTIPLEXED="F", VECTORIZED="C")
+_fmt_dict = dict(INT_16="short", INT_32="int", IEEE_FLOAT_32="single")
 _fmt_byte_dict = dict(short=2, int=4, single=4)
-_fmt_dtype_dict = dict(short='<i2', int='<i4', single='<f4')
+_fmt_dtype_dict = dict(short="<i2", int="<i4", single="<f4")
 if len(to_misc) > 0:
 misc += to_misc
- warn('No coordinate information found for channels {}. '
- 'Setting channel types to misc. To avoid this warning, set '
- 'channel types explicitly.'.format(to_misc))
+ warn(
+ "No coordinate information found for channels {}. "
+ "Setting channel types to misc. To avoid this warning, set "
+ "channel types explicitly.".format(to_misc)
+ )
 if np.isnan(cals).any():
- raise RuntimeError('Missing channel units')
+ raise RuntimeError("Missing channel units")
 # Attempts to extract filtering info from header. If not found, both are
 # set to zero.
settings = settings.splitlines() idx = None - if 'Channels' in settings: - idx = settings.index('Channels') - settings = settings[idx + 1:] + if "Channels" in settings: + idx = settings.index("Channels") + settings = settings[idx + 1 :] hp_col, lp_col = 4, 5 for idx, setting in enumerate(settings): - if re.match(r'#\s+Name', setting): + if re.match(r"#\s+Name", setting): break else: idx = None @@ -644,14 +683,16 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): idx_amp = idx filter_list_has_ch_name = True - if 'S o f t w a r e F i l t e r s' in settings: - idx = settings.index('S o f t w a r e F i l t e r s') - for idx, setting in enumerate(settings[idx + 1:], idx + 1): - if re.match(r'#\s+Low Cutoff', setting): + if "S o f t w a r e F i l t e r s" in settings: + idx = settings.index("S o f t w a r e F i l t e r s") + for idx, setting in enumerate(settings[idx + 1 :], idx + 1): + if re.match(r"#\s+Low Cutoff", setting): hp_col, lp_col = 1, 2 filter_list_has_ch_name = False - warn('Online software filter detected. Using software ' - 'filter settings and ignoring hardware values') + warn( + "Online software filter detected. Using software " + "filter settings and ignoring hardware values" + ) break else: idx = idx_amp @@ -663,8 +704,8 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): # for newer BV files, the unit is specified for every channel # separated by a single space, while for older files, the unit is # specified in the column headers - divider = r'\s+' - if 'Resolution / Unit' in settings[idx]: + divider = r"\s+" + if "Resolution / Unit" in settings[idx]: shift = 1 # shift for unit else: shift = 0 @@ -678,9 +719,9 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): # `Ebersole, J. S., & Pedley, T. A. (Eds.). (2003). # Current practice of clinical electroencephalography. # Lippincott Williams & Wilkins.`, page 40-41 - header = re.split(r'\s\s+', settings[idx]) - hp_s = '[s]' in header[hp_col] - lp_s = '[s]' in header[lp_col] + header = re.split(r"\s\s+", settings[idx]) + hp_s = "[s]" in header[hp_col] + lp_s = "[s]" in header[lp_col] for i, ch in enumerate(ch_names, 1): if ahdr_format and i == len(ch_names) and ch == _AHDR_CHANNEL_NAME: @@ -707,77 +748,85 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): if len(highpass) == 0: pass elif len(set(highpass)) == 1: - if highpass[0] in ('NaN', 'Off'): + if highpass[0] in ("NaN", "Off"): pass # Placeholder for future use. Highpass set in _empty_info - elif highpass[0] == 'DC': - info['highpass'] = 0. + elif highpass[0] == "DC": + info["highpass"] = 0.0 else: - info['highpass'] = float(highpass[0]) + info["highpass"] = float(highpass[0]) if hp_s: # filter time constant t [secs] to Hz conversion: 1/2*pi*t - info['highpass'] = 1. / (2 * np.pi * info['highpass']) + info["highpass"] = 1.0 / (2 * np.pi * info["highpass"]) else: heterogeneous_hp_filter = True if hp_s: # We convert channels with disabled filters to having # highpass relaxed / no filters - highpass = [float(filt) if filt not in ('NaN', 'Off', 'DC') - else np.Inf for filt in highpass] - info['highpass'] = np.max(np.array(highpass, dtype=np.float64)) + highpass = [ + float(filt) if filt not in ("NaN", "Off", "DC") else np.Inf + for filt in highpass + ] + info["highpass"] = np.max(np.array(highpass, dtype=np.float64)) # Coveniently enough 1 / np.Inf = 0.0, so this works for # DC / no highpass filter # filter time constant t [secs] to Hz conversion: 1/2*pi*t - info['highpass'] = 1. 
/ (2 * np.pi * info['highpass']) + info["highpass"] = 1.0 / (2 * np.pi * info["highpass"]) # not exactly the cleanest use of FP, but this makes us # more conservative in *not* warning. - if info['highpass'] == 0.0 and len(set(highpass)) == 1: + if info["highpass"] == 0.0 and len(set(highpass)) == 1: # not actually heterogeneous in effect # ... just heterogeneously disabled heterogeneous_hp_filter = False else: - highpass = [float(filt) if filt not in ('NaN', 'Off', 'DC') - else 0.0 for filt in highpass] - info['highpass'] = np.min(np.array(highpass, dtype=np.float64)) - if info['highpass'] == 0.0 and len(set(highpass)) == 1: + highpass = [ + float(filt) if filt not in ("NaN", "Off", "DC") else 0.0 + for filt in highpass + ] + info["highpass"] = np.min(np.array(highpass, dtype=np.float64)) + if info["highpass"] == 0.0 and len(set(highpass)) == 1: # not actually heterogeneous in effect # ... just heterogeneously disabled heterogeneous_hp_filter = False if heterogeneous_hp_filter: - warn('Channels contain different highpass filters. ' - 'Lowest (weakest) filter setting (%0.2f Hz) ' - 'will be stored.' % info['highpass']) + warn( + "Channels contain different highpass filters. " + "Lowest (weakest) filter setting (%0.2f Hz) " + "will be stored." % info["highpass"] + ) if len(lowpass) == 0: pass elif len(set(lowpass)) == 1: - if lowpass[0] in ('NaN', 'Off', '0'): + if lowpass[0] in ("NaN", "Off", "0"): pass # Placeholder for future use. Lowpass set in _empty_info else: - info['lowpass'] = float(lowpass[0]) + info["lowpass"] = float(lowpass[0]) if lp_s: # filter time constant t [secs] to Hz conversion: 1/2*pi*t - info['lowpass'] = 1. / (2 * np.pi * info['lowpass']) + info["lowpass"] = 1.0 / (2 * np.pi * info["lowpass"]) else: heterogeneous_lp_filter = True if lp_s: # We convert channels with disabled filters to having # infinitely relaxed / no filters - lowpass = [float(filt) if filt not in ('NaN', 'Off', '0') - else 0.0 for filt in lowpass] - info['lowpass'] = np.min(np.array(lowpass, dtype=np.float64)) + lowpass = [ + float(filt) if filt not in ("NaN", "Off", "0") else 0.0 + for filt in lowpass + ] + info["lowpass"] = np.min(np.array(lowpass, dtype=np.float64)) try: # filter time constant t [secs] to Hz conversion: 1/2*pi*t - info['lowpass'] = 1. / (2 * np.pi * info['lowpass']) + info["lowpass"] = 1.0 / (2 * np.pi * info["lowpass"]) except ZeroDivisionError: if len(set(lowpass)) == 1: # No lowpass actually set for the weakest setting # so we set lowpass to the Nyquist frequency - info['lowpass'] = info['sfreq'] / 2. + info["lowpass"] = info["sfreq"] / 2.0 # not actually heterogeneous in effect # ... just heterogeneously disabled heterogeneous_lp_filter = False @@ -788,14 +837,16 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): else: # We convert channels with disabled filters to having # infinitely relaxed / no filters - lowpass = [float(filt) if filt not in ('NaN', 'Off', '0') - else np.Inf for filt in lowpass] - info['lowpass'] = np.max(np.array(lowpass, dtype=np.float64)) + lowpass = [ + float(filt) if filt not in ("NaN", "Off", "0") else np.Inf + for filt in lowpass + ] + info["lowpass"] = np.max(np.array(lowpass, dtype=np.float64)) - if np.isinf(info['lowpass']): + if np.isinf(info["lowpass"]): # No lowpass actually set for the weakest setting # so we set lowpass to the Nyquist frequency - info['lowpass'] = info['sfreq'] / 2. + info["lowpass"] = info["sfreq"] / 2.0 if len(set(lowpass)) == 1: # not actually heterogeneous in effect # ... 
just heterogeneously disabled @@ -806,17 +857,19 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): # the Nyquist hint when the lowpass filter was actually # calculated from dividing the sampling frequency by 2, so the # exact/direct comparison (instead of tolerance) makes sense - if info['lowpass'] == info['sfreq'] / 2.0: - nyquist = ', Nyquist limit' + if info["lowpass"] == info["sfreq"] / 2.0: + nyquist = ", Nyquist limit" else: nyquist = "" - warn('Channels contain different lowpass filters. ' - 'Highest (weakest) filter setting (%0.2f Hz%s) ' - 'will be stored.' % (info['lowpass'], nyquist)) + warn( + "Channels contain different lowpass filters. " + "Highest (weakest) filter setting (%0.2f Hz%s) " + "will be stored." % (info["lowpass"], nyquist) + ) # Creates a list of dicts of eeg channels for raw.info - logger.info('Setting channel info structure...') - info['chs'] = [] + logger.info("Setting channel info structure...") + info["chs"] = [] for idx, ch_name in enumerate(ch_names): if ch_name in eog or idx in eog or idx - nchan in eog: kind = FIFF.FIFFV_EOG_CH @@ -829,7 +882,7 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): unit = misc_chs[ch_name] else: unit = FIFF.FIFF_UNIT_NONE - elif ch_name == 'STI 014': + elif ch_name == "STI 014": kind = FIFF.FIFFV_STIM_CH coil_type = FIFF.FIFFV_COIL_NONE unit = FIFF.FIFF_UNIT_NONE @@ -837,23 +890,36 @@ def _get_hdr_info(hdr_fname, eog, misc, scale): kind = FIFF.FIFFV_EEG_CH coil_type = FIFF.FIFFV_COIL_EEG unit = FIFF.FIFF_UNIT_V - info['chs'].append(dict( - ch_name=ch_name, coil_type=coil_type, kind=kind, logno=idx + 1, - scanno=idx + 1, cal=cals[idx], range=ranges[idx], - loc=np.full(12, np.nan), - unit=unit, unit_mul=FIFF.FIFF_UNITM_NONE, - coord_frame=FIFF.FIFFV_COORD_HEAD)) + info["chs"].append( + dict( + ch_name=ch_name, + coil_type=coil_type, + kind=kind, + logno=idx + 1, + scanno=idx + 1, + cal=cals[idx], + range=ranges[idx], + loc=np.full(12, np.nan), + unit=unit, + unit_mul=FIFF.FIFF_UNITM_NONE, + coord_frame=FIFF.FIFFV_COORD_HEAD, + ) + ) info._unlocked = False info._update_redundant() - return (info, data_fname, fmt, order, n_samples, mrk_fname, montage, - orig_units) + return (info, data_fname, fmt, order, n_samples, mrk_fname, montage, orig_units) @fill_doc -def read_raw_brainvision(vhdr_fname, - eog=('HEOGL', 'HEOGR', 'VEOGb'), misc='auto', - scale=1., preload=False, verbose=None): +def read_raw_brainvision( + vhdr_fname, + eog=("HEOGL", "HEOGR", "VEOGb"), + misc="auto", + scale=1.0, + preload=False, + verbose=None, +): """Reader for Brain Vision EEG file. Parameters @@ -885,16 +951,23 @@ def read_raw_brainvision(vhdr_fname, -------- mne.io.Raw : Documentation of attributes and methods of RawBrainVision. 
""" - return RawBrainVision(vhdr_fname=vhdr_fname, eog=eog, - misc=misc, scale=scale, preload=preload, - verbose=verbose) - - -_BV_EVENT_IO_OFFSETS = {'Event/': 0, 'Stimulus/S': 0, 'Response/R': 1000, - 'Optic/O': 2000} -_OTHER_ACCEPTED_MARKERS = { - 'New Segment/': 99999, 'SyncStatus/Sync On': 99998 + return RawBrainVision( + vhdr_fname=vhdr_fname, + eog=eog, + misc=misc, + scale=scale, + preload=preload, + verbose=verbose, + ) + + +_BV_EVENT_IO_OFFSETS = { + "Event/": 0, + "Stimulus/S": 0, + "Response/R": 1000, + "Optic/O": 2000, } +_OTHER_ACCEPTED_MARKERS = {"New Segment/": 99999, "SyncStatus/Sync On": 99998} _OTHER_OFFSET = 10001 # where to start "unknown" event_ids _AHDR_CHANNEL_NAME = "AHDR_CHANNEL" @@ -913,15 +986,17 @@ def __call__(self, description): elif description in _OTHER_ACCEPTED_MARKERS: code = _OTHER_ACCEPTED_MARKERS[description] else: - code = (super(_BVEventParser, self) - .__call__(description, offset=_OTHER_OFFSET)) + code = super(_BVEventParser, self).__call__( + description, offset=_OTHER_OFFSET + ) return code def _check_bv_annot(descriptions): - markers_basename = set([dd.rstrip('0123456789 ') for dd in descriptions]) - bv_markers = (set(_BV_EVENT_IO_OFFSETS.keys()) - .union(set(_OTHER_ACCEPTED_MARKERS.keys()))) + markers_basename = set([dd.rstrip("0123456789 ") for dd in descriptions]) + bv_markers = set(_BV_EVENT_IO_OFFSETS.keys()).union( + set(_OTHER_ACCEPTED_MARKERS.keys()) + ) return len(markers_basename - bv_markers) == 0 @@ -941,46 +1016,47 @@ def _parse_impedance(settings, recording_date=None): A dictionary of all electrodes and their impedances. """ ranges = _parse_impedance_ranges(settings) - impedance_setting_lines = [i for i in settings if - i.startswith('Impedance [') and - i.endswith(' :')] + impedance_setting_lines = [ + i for i in settings if i.startswith("Impedance [") and i.endswith(" :") + ] impedances = dict() if len(impedance_setting_lines) > 0: idx = settings.index(impedance_setting_lines[0]) impedance_setting = impedance_setting_lines[0].split() - impedance_unit = impedance_setting[1].lstrip('[').rstrip(']') + impedance_unit = impedance_setting[1].lstrip("[").rstrip("]") impedance_time = None # If we have a recording date, we can update it with the time of # impedance measurement if recording_date is not None: - meas_time = [int(i) for i in impedance_setting[3].split(':')] - impedance_time = recording_date.replace(hour=meas_time[0], - minute=meas_time[1], - second=meas_time[2], - microsecond=0) - for setting in settings[idx + 1:]: + meas_time = [int(i) for i in impedance_setting[3].split(":")] + impedance_time = recording_date.replace( + hour=meas_time[0], + minute=meas_time[1], + second=meas_time[2], + microsecond=0, + ) + for setting in settings[idx + 1 :]: # Parse channel impedances until we find a line that doesn't start # with a channel name and optional +/- polarity for passive elecs - match = re.match(r'[ a-zA-Z0-9_+-]+:', setting) + match = re.match(r"[ a-zA-Z0-9_+-]+:", setting) if match: - channel_name = match.group().rstrip(':') + channel_name = match.group().rstrip(":") channel_imp_line = setting.split() - imp_as_number = re.findall(r"[-+]?\d*\.\d+|\d+", - channel_imp_line[-1]) + imp_as_number = re.findall(r"[-+]?\d*\.\d+|\d+", channel_imp_line[-1]) channel_impedance = dict( imp=float(imp_as_number[0]) if imp_as_number else np.nan, imp_unit=impedance_unit, ) if impedance_time is not None: - channel_impedance.update({'imp_meas_time': impedance_time}) - - if channel_name == 'Ref' and 'Reference' in ranges: - 
channel_impedance.update(ranges['Reference']) - elif channel_name == 'Gnd' and 'Ground' in ranges: - channel_impedance.update(ranges['Ground']) - elif 'Data' in ranges: - channel_impedance.update(ranges['Data']) + channel_impedance.update({"imp_meas_time": impedance_time}) + + if channel_name == "Ref" and "Reference" in ranges: + channel_impedance.update(ranges["Reference"]) + elif channel_name == "Gnd" and "Ground" in ranges: + channel_impedance.update(ranges["Ground"]) + elif "Data" in ranges: + channel_impedance.update(ranges["Data"]) impedances[channel_name] = channel_impedance else: break @@ -1000,17 +1076,18 @@ def _parse_impedance_ranges(settings): electrode_imp_ranges : dict A dictionary of impedance ranges for each type of electrode. """ - impedance_ranges = [item for item in settings if - "Selected Impedance Measurement Range" in item] + impedance_ranges = [ + item for item in settings if "Selected Impedance Measurement Range" in item + ] electrode_imp_ranges = dict() if impedance_ranges: if len(impedance_ranges) == 1: img_range = impedance_ranges[0].split() - for electrode_type in ['Data', 'Reference', 'Ground']: + for electrode_type in ["Data", "Reference", "Ground"]: electrode_imp_ranges[electrode_type] = { "imp_lower_bound": float(img_range[-4]), "imp_upper_bound": float(img_range[-2]), - "imp_range_unit": img_range[-1] + "imp_range_unit": img_range[-1], } else: for electrode_range in impedance_ranges: @@ -1018,6 +1095,6 @@ def _parse_impedance_ranges(settings): electrode_imp_ranges[electrode_range[0]] = { "imp_lower_bound": float(electrode_range[6]), "imp_upper_bound": float(electrode_range[8]), - "imp_range_unit": electrode_range[9] + "imp_range_unit": electrode_range[9], } return electrode_imp_ranges diff --git a/mne/io/brainvision/tests/test_brainvision.py b/mne/io/brainvision/tests/test_brainvision.py index c78baa89027..9c94737b6b0 100644 --- a/mne/io/brainvision/tests/test_brainvision.py +++ b/mne/io/brainvision/tests/test_brainvision.py @@ -77,80 +77,88 @@ def test_orig_units(recwarn): raw = read_raw_brainvision(vhdr_path) orig_units = raw._orig_units assert len(orig_units) == 32 - assert orig_units['FP1'] == 'µV' + assert orig_units["FP1"] == "µV" # no unit specified in the vhdr, ensure we default to µV here - assert orig_units['FP2'] == 'µV' - assert orig_units['F3'] == 'µV' - - sum([v == 'µV' for v in orig_units.values()]) == 26 - - assert orig_units['CP5'] == 'n/a' # originally BS, not a valid unit - assert orig_units['CP6'] == 'µS' - assert orig_units['HL'] == 'n/a' # originally ARU, not a valid unit - assert orig_units['HR'] == 'n/a' # originally uS ... 
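The DATE_TEST_CASES table a few lines below pairs raw "New Segment" marker timestamps with the meas_date each should produce. The 20-digit stamp encodes YYYYMMDDhhmmssuuuuuu, so the first case can be checked independently of the private parser with the standard library alone:

    import datetime

    stamp = "20131113161403794232"  # from the first DATE_TEST_CASES entry
    dt = datetime.datetime.strptime(stamp, "%Y%m%d%H%M%S%f").replace(
        tzinfo=datetime.timezone.utc
    )
    # (seconds, microseconds) matches the expected internal
    # representation [1384359243, 794232]
    assert (int(dt.timestamp()), dt.microsecond) == (1384359243, 794232)
    assert str(dt).startswith("2013-11-13 16:14:03")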
- assert orig_units['Vb'] == 'S' - assert orig_units['ReRef'] == 'C' - - -DATE_TEST_CASES = np.array([ - ('Mk1=New Segment,,1,1,0,20131113161403794232\n', # content - [1384359243, 794232], # meas_date internal representation - '2013-11-13 16:14:03 UTC'), # meas_date representation - - (('Mk1=New Segment,,1,1,0,20070716122240937454\n' - 'Mk2=New Segment,,2,1,0,20070716122240937455\n'), - [1184588560, 937454], - '2007-07-16 12:22:40 UTC'), - - ('Mk1=New Segment,,1,1,0,\nMk2=New Segment,,2,1,0,20070716122240937454\n', - [1184588560, 937454], - '2007-07-16 12:22:40 UTC'), - - ('Mk1=STATUS,,1,1,0\n', None, 'unspecified'), - ('Mk1=New Segment,,1,1,0,\n', None, 'unspecified'), - ('Mk1=New Segment,,1,1,0\n', None, 'unspecified'), - ('Mk1=New Segment,,1,1,0,00000000000304125000', None, 'unspecified'), - -], dtype=np.dtype({ - 'names': ['content', 'meas_date', 'meas_date_repr'], - 'formats': [object, object, 'U22'] -})) + assert orig_units["FP2"] == "µV" + assert orig_units["F3"] == "µV" + + sum([v == "µV" for v in orig_units.values()]) == 26 + + assert orig_units["CP5"] == "n/a" # originally BS, not a valid unit + assert orig_units["CP6"] == "µS" + assert orig_units["HL"] == "n/a" # originally ARU, not a valid unit + assert orig_units["HR"] == "n/a" # originally uS ... + assert orig_units["Vb"] == "S" + assert orig_units["ReRef"] == "C" + + +DATE_TEST_CASES = np.array( + [ + ( + "Mk1=New Segment,,1,1,0,20131113161403794232\n", # content + [1384359243, 794232], # meas_date internal representation + "2013-11-13 16:14:03 UTC", + ), # meas_date representation + ( + ( + "Mk1=New Segment,,1,1,0,20070716122240937454\n" + "Mk2=New Segment,,2,1,0,20070716122240937455\n" + ), + [1184588560, 937454], + "2007-07-16 12:22:40 UTC", + ), + ( + "Mk1=New Segment,,1,1,0,\nMk2=New Segment,,2,1,0,20070716122240937454\n", + [1184588560, 937454], + "2007-07-16 12:22:40 UTC", + ), + ("Mk1=STATUS,,1,1,0\n", None, "unspecified"), + ("Mk1=New Segment,,1,1,0,\n", None, "unspecified"), + ("Mk1=New Segment,,1,1,0\n", None, "unspecified"), + ("Mk1=New Segment,,1,1,0,00000000000304125000", None, "unspecified"), + ], + dtype=np.dtype( + { + "names": ["content", "meas_date", "meas_date_repr"], + "formats": [object, object, "U22"], + } + ), +) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def _mocked_meas_date_data(tmp_path_factory): """Prepare files for mocked_meas_date_file fixture.""" # Prepare the files tmp_path = tmp_path_factory.mktemp("brainvision_mocked_meas_date") vhdr_fname, vmrk_fname, eeg_fname = [ - tmp_path / ff.name - for ff in [vhdr_path, vmrk_path, eeg_path] + tmp_path / ff.name for ff in [vhdr_path, vmrk_path, eeg_path] ] for orig, dest in zip([vhdr_path, eeg_path], [vhdr_fname, eeg_fname]): shutil.copyfile(orig, dest) # Get the marker information - with open(vmrk_path, 'r') as fin: + with open(vmrk_path, "r") as fin: lines = fin.readlines() return vhdr_fname, vmrk_fname, lines -@pytest.fixture(scope='session', params=[tt for tt in DATE_TEST_CASES]) +@pytest.fixture(scope="session", params=[tt for tt in DATE_TEST_CASES]) def mocked_meas_date_file(_mocked_meas_date_data, request): """Prepare a generator for use in test_meas_date.""" MEAS_DATE_LINE = 11 # see test.vmrk file vhdr_fname, vmrk_fname, lines = _mocked_meas_date_data - lines[MEAS_DATE_LINE] = request.param['content'] - with open(vmrk_fname, 'w') as fout: + lines[MEAS_DATE_LINE] = request.param["content"] + with open(vmrk_fname, "w") as fout: fout.writelines(lines) - meas_date = request.param['meas_date'] + meas_date = 
request.param["meas_date"] if meas_date is not None: meas_date = _stamp_to_dt(meas_date) - yield vhdr_fname, meas_date, request.param['meas_date_repr'] + yield vhdr_fname, meas_date, request.param["meas_date_repr"] def test_meas_date(mocked_meas_date_file): @@ -159,9 +167,9 @@ def test_meas_date(mocked_meas_date_file): raw = read_raw_brainvision(vhdr_f) assert expected_meas_repr in repr(raw.info) if expected_meas is None: - assert raw.info['meas_date'] is None + assert raw.info["meas_date"] is None else: - assert raw.info['meas_date'] == expected_meas + assert raw.info["meas_date"] == expected_meas def test_vhdr_codepage_ansi(tmp_path): @@ -174,20 +182,20 @@ def test_vhdr_codepage_ansi(tmp_path): # copy data file shutil.copy(eeg_path, ansi_eeg_path) # modify header file - with open(ansi_vhdr_path, 'wb') as fout: - with open(vhdr_path, 'rb') as fin: + with open(ansi_vhdr_path, "wb") as fout: + with open(vhdr_path, "rb") as fin: for line in fin: # Common Infos section - if line.startswith(b'Codepage'): - line = b'Codepage=ANSI\n' + if line.startswith(b"Codepage"): + line = b"Codepage=ANSI\n" fout.write(line) # modify marker file - with open(ansi_vmrk_path, 'wb') as fout: - with open(vmrk_path, 'rb') as fin: + with open(ansi_vmrk_path, "wb") as fout: + with open(vmrk_path, "rb") as fin: for line in fin: # Common Infos section - if line.startswith(b'Codepage'): - line = b'Codepage=ANSI\n' + if line.startswith(b"Codepage"): + line = b"Codepage=ANSI\n" fout.write(line) raw = read_raw_brainvision(ansi_vhdr_path) @@ -198,15 +206,18 @@ def test_vhdr_codepage_ansi(tmp_path): assert_allclose(times_new, times_expected, atol=1e-15) -@pytest.mark.parametrize('header', [ - b'BrainVision Data Exchange %s File Version 1.0\n', - # 2.0, space, core, comma - b'Brain Vision Core Data Exchange %s File, Version 2.0\n', - # unsupported version - b'Brain Vision Core Data Exchange %s File, Version 3.0\n', - # missing header - b'\n', -]) +@pytest.mark.parametrize( + "header", + [ + b"BrainVision Data Exchange %s File Version 1.0\n", + # 2.0, space, core, comma + b"Brain Vision Core Data Exchange %s File, Version 2.0\n", + # unsupported version + b"Brain Vision Core Data Exchange %s File, Version 3.0\n", + # missing header + b"\n", + ], +) def test_vhdr_versions(tmp_path, header): """Test BV reading with different header variants.""" raw_init = read_raw_brainvision(vhdr_path) @@ -215,33 +226,33 @@ def test_vhdr_versions(tmp_path, header): use_vmrk_path = tmp_path / vmrk_path.name use_eeg_path = tmp_path / eeg_path.name shutil.copy(eeg_path, use_eeg_path) - with open(use_vhdr_path, 'wb') as fout: - with open(vhdr_path, 'rb') as fin: + with open(use_vhdr_path, "wb") as fout: + with open(vhdr_path, "rb") as fin: for line in fin: # Common Infos section - if line.startswith(b'Brain'): - if header != b'\n': - line = header % b'Header' + if line.startswith(b"Brain"): + if header != b"\n": + line = header % b"Header" else: line = header fout.write(line) - with open(use_vmrk_path, 'wb') as fout: - with open(vmrk_path, 'rb') as fin: + with open(use_vmrk_path, "wb") as fout: + with open(vmrk_path, "rb") as fin: for line in fin: # Common Infos section - if line.startswith(b'Brain'): - if header != b'\n': - line = header % b'Marker' + if line.startswith(b"Brain"): + if header != b"\n": + line = header % b"Marker" else: line = header fout.write(line) - if (b'3.0' in header): # unsupported version - with pytest.warns(RuntimeWarning, match=r'3\.0.*Contact MNE-Python'): + if b"3.0" in header: # unsupported version + with 
pytest.warns(RuntimeWarning, match=r"3\.0.*Contact MNE-Python"): read_raw_brainvision(use_vhdr_path) return - elif header == b'\n': # no version header - with pytest.warns(RuntimeWarning, match='Missing header'): + elif header == b"\n": # no version header + with pytest.warns(RuntimeWarning, match="Missing header"): read_raw_brainvision(use_vhdr_path) return else: @@ -250,7 +261,7 @@ def test_vhdr_versions(tmp_path, header): assert_allclose(data_new, data_expected, atol=1e-15) -@pytest.mark.parametrize('data_sep', (b' ', b',', b'+')) +@pytest.mark.parametrize("data_sep", (b" ", b",", b"+")) def test_ascii(tmp_path, data_sep): """Test ASCII BV reading.""" raw = read_raw_brainvision(vhdr_path) @@ -262,33 +273,37 @@ def test_ascii(tmp_path, data_sep): ) # modify header file skipping = False - with open(ascii_vhdr_path, 'wb') as fout: - with open(vhdr_path, 'rb') as fin: + with open(ascii_vhdr_path, "wb") as fout: + with open(vhdr_path, "rb") as fin: for line in fin: # Common Infos section - if line.startswith(b'DataFormat'): - line = b'DataFormat=ASCII\n' - elif line.startswith(b'DataFile='): - line = b'DataFile=test.dat\n' + if line.startswith(b"DataFormat"): + line = b"DataFormat=ASCII\n" + elif line.startswith(b"DataFile="): + line = b"DataFile=test.dat\n" # Replace the "'Binary Infos'" section - elif line.startswith(b'[Binary Infos]'): + elif line.startswith(b"[Binary Infos]"): skipping = True - fout.write(b'[ASCII Infos]\nDecimalSymbol=.\nSkipLines=1\n' - b'SkipColumns=0\n\n') - elif skipping and line.startswith(b'['): + fout.write( + b"[ASCII Infos]\nDecimalSymbol=.\nSkipLines=1\n" + b"SkipColumns=0\n\n" + ) + elif skipping and line.startswith(b"["): skipping = False if not skipping: fout.write(line) # create the .dat file data, times = raw[:] with open(ascii_vhdr_path.with_suffix(".dat"), "wb") as fid: - fid.write(data_sep.join(ch_name.encode('ASCII') - for ch_name in raw.ch_names) + b'\n') - fid.write(b'\n'.join(b' '.join(b'%.3f' % dd for dd in d) - for d in data.T / raw._cals)) + fid.write( + data_sep.join(ch_name.encode("ASCII") for ch_name in raw.ch_names) + b"\n" + ) + fid.write( + b"\n".join(b" ".join(b"%.3f" % dd for dd in d) for d in data.T / raw._cals) + ) - if data_sep == b';': - with pytest.raises(RuntimeError, match='Unknown.*data format'): + if data_sep == b";": + with pytest.raises(RuntimeError, match="Unknown.*data format"): read_raw_brainvision(ascii_vhdr_path) return @@ -309,12 +324,13 @@ def test_ch_names_comma(tmp_path): # Copy existing vhdr file to tmp_path and manipulate to contain # a channel with comma - for src, dest in zip((vhdr_path, vmrk_path, eeg_path), - ('test.vhdr', 'test.vmrk', 'test.eeg')): + for src, dest in zip( + (vhdr_path, vmrk_path, eeg_path), ("test.vhdr", "test.vmrk", "test.eeg") + ): shutil.copyfile(src, tmp_path / dest) - comma_vhdr = tmp_path / 'test.vhdr' - with open(comma_vhdr, 'r') as fin: + comma_vhdr = tmp_path / "test.vhdr" + with open(comma_vhdr, "r") as fin: lines = fin.readlines() new_lines = [] @@ -331,7 +347,7 @@ def test_ch_names_comma(tmp_path): new_lines.append(line) assert nperformed_replacements == len(replace_dict) - with open(comma_vhdr, 'w') as fout: + with open(comma_vhdr, "w") as fout: fout.writelines(new_lines) # Read the line containing a "comma channel name" @@ -339,62 +355,58 @@ def test_ch_names_comma(tmp_path): assert "F4,foo" in raw.ch_names -@pytest.mark.filterwarnings('ignore:.*different.*:RuntimeWarning') +@pytest.mark.filterwarnings("ignore:.*different.*:RuntimeWarning") def 
test_brainvision_data_highpass_filters(): """Test reading raw Brain Vision files with amplifier filter settings.""" # Homogeneous highpass in seconds (default measurement unit) - raw = _test_raw_reader( - read_raw_brainvision, vhdr_fname=vhdr_highpass_path, eog=eog - ) + raw = _test_raw_reader(read_raw_brainvision, vhdr_fname=vhdr_highpass_path, eog=eog) - assert raw.info['highpass'] == 1. / (2 * np.pi * 10) - assert raw.info['lowpass'] == 250. + assert raw.info["highpass"] == 1.0 / (2 * np.pi * 10) + assert raw.info["lowpass"] == 250.0 # Heterogeneous highpass in seconds (default measurement unit) - with pytest.warns(RuntimeWarning, match='different .*pass filters') as w: + with pytest.warns(RuntimeWarning, match="different .*pass filters") as w: raw = _test_raw_reader( - read_raw_brainvision, vhdr_fname=vhdr_mixed_highpass_path, - eog=eog) + read_raw_brainvision, vhdr_fname=vhdr_mixed_highpass_path, eog=eog + ) w = [str(ww.message) for ww in w] - assert not any('different lowpass filters' in ww for ww in w), w - assert all('different highpass filters' in ww for ww in w), w + assert not any("different lowpass filters" in ww for ww in w), w + assert all("different highpass filters" in ww for ww in w), w - assert raw.info['highpass'] == 1. / (2 * np.pi * 10) - assert raw.info['lowpass'] == 250. + assert raw.info["highpass"] == 1.0 / (2 * np.pi * 10) + assert raw.info["lowpass"] == 250.0 # Homogeneous highpass in Hertz raw = _test_raw_reader( - read_raw_brainvision, vhdr_fname=vhdr_highpass_hz_path, - eog=eog) + read_raw_brainvision, vhdr_fname=vhdr_highpass_hz_path, eog=eog + ) - assert raw.info['highpass'] == 10. - assert raw.info['lowpass'] == 250. + assert raw.info["highpass"] == 10.0 + assert raw.info["lowpass"] == 250.0 # Heterogeneous highpass in Hertz - with pytest.warns(RuntimeWarning, match='different .*pass filters') as w: + with pytest.warns(RuntimeWarning, match="different .*pass filters") as w: raw = _test_raw_reader( - read_raw_brainvision, vhdr_fname=vhdr_mixed_highpass_hz_path, - eog=eog) + read_raw_brainvision, vhdr_fname=vhdr_mixed_highpass_hz_path, eog=eog + ) w = [str(ww.message) for ww in w] - assert not any('will be dropped' in ww for ww in w), w - assert not any('different lowpass filters' in ww for ww in w), w - assert all('different highpass filters' in ww for ww in w), w + assert not any("will be dropped" in ww for ww in w), w + assert not any("different lowpass filters" in ww for ww in w), w + assert all("different highpass filters" in ww for ww in w), w - assert raw.info['highpass'] == 5. - assert raw.info['lowpass'] == 250. + assert raw.info["highpass"] == 5.0 + assert raw.info["lowpass"] == 250.0 def test_brainvision_data_lowpass_filters(): """Test files with amplifier LP filter settings.""" # Homogeneous lowpass in Hertz (default measurement unit) - raw = _test_raw_reader( - read_raw_brainvision, vhdr_fname=vhdr_lowpass_path, eog=eog - ) + raw = _test_raw_reader(read_raw_brainvision, vhdr_fname=vhdr_lowpass_path, eog=eog) - assert raw.info['highpass'] == 1. / (2 * np.pi * 10) - assert raw.info['lowpass'] == 250. 
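The expected values in these filter tests come from the time-constant conversion in _get_hdr_info above: a cutoff given as a time constant of t seconds corresponds to 1 / (2 * pi * t) Hz, and disabled settings ("Off", "DC", "NaN") are mapped to an infinite time constant so the resulting cutoff is 0 Hz. Numerically:

    import numpy as np

    # Homogeneous case: a 10 s time constant, as asserted in these tests.
    assert np.isclose(1.0 / (2 * np.pi * 10), 0.015915494)

    # Heterogeneous case: disabled entries become an infinite time constant,
    # and 1 / inf == 0.0 gives the "no highpass" cutoff of 0 Hz.
    highpass = ["10", "Off", "DC"]  # hypothetical per-channel column
    vals = [float(f) if f not in ("NaN", "Off", "DC") else np.inf for f in highpass]
    assert 1.0 / (2 * np.pi * max(vals)) == 0.0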
+ assert raw.info["highpass"] == 1.0 / (2 * np.pi * 10) + assert raw.info["lowpass"] == 250.0 # Heterogeneous lowpass in Hertz (default measurement unit) with pytest.warns(RuntimeWarning) as w: # event parsing @@ -402,25 +414,23 @@ def test_brainvision_data_lowpass_filters(): read_raw_brainvision, vhdr_fname=vhdr_mixed_lowpass_path, eog=eog ) - lowpass_warning = ['different lowpass filters' in str(ww.message) - for ww in w] - highpass_warning = ['different highpass filters' in str(ww.message) - for ww in w] + lowpass_warning = ["different lowpass filters" in str(ww.message) for ww in w] + highpass_warning = ["different highpass filters" in str(ww.message) for ww in w] expected_warnings = zip(lowpass_warning, highpass_warning) - assert (all(any([lp, hp]) for lp, hp in expected_warnings)) + assert all(any([lp, hp]) for lp, hp in expected_warnings) - assert raw.info['highpass'] == 1. / (2 * np.pi * 10) - assert raw.info['lowpass'] == 250. + assert raw.info["highpass"] == 1.0 / (2 * np.pi * 10) + assert raw.info["lowpass"] == 250.0 # Homogeneous lowpass in seconds raw = _test_raw_reader( read_raw_brainvision, vhdr_fname=vhdr_lowpass_s_path, eog=eog ) - assert raw.info['highpass'] == 1. / (2 * np.pi * 10) - assert raw.info['lowpass'] == 1. / (2 * np.pi * 0.004) + assert raw.info["highpass"] == 1.0 / (2 * np.pi * 10) + assert raw.info["lowpass"] == 1.0 / (2 * np.pi * 0.004) # Heterogeneous lowpass in seconds with pytest.warns(RuntimeWarning) as w: # filter settings @@ -428,17 +438,15 @@ def test_brainvision_data_lowpass_filters(): read_raw_brainvision, vhdr_fname=vhdr_mixed_lowpass_s_path, eog=eog ) - lowpass_warning = ['different lowpass filters' in str(ww.message) - for ww in w] - highpass_warning = ['different highpass filters' in str(ww.message) - for ww in w] + lowpass_warning = ["different lowpass filters" in str(ww.message) for ww in w] + highpass_warning = ["different highpass filters" in str(ww.message) for ww in w] expected_warnings = zip(lowpass_warning, highpass_warning) - assert (all(any([lp, hp]) for lp, hp in expected_warnings)) + assert all(any([lp, hp]) for lp, hp in expected_warnings) - assert raw.info['highpass'] == 1. / (2 * np.pi * 10) - assert raw.info['lowpass'] == 1. / (2 * np.pi * 0.004) + assert raw.info["highpass"] == 1.0 / (2 * np.pi * 10) + assert raw.info["lowpass"] == 1.0 / (2 * np.pi * 0.004) def test_brainvision_data_partially_disabled_hw_filters(): @@ -446,148 +454,155 @@ def test_brainvision_data_partially_disabled_hw_filters(): with pytest.warns(RuntimeWarning) as w: # event parsing raw = _test_raw_reader( read_raw_brainvision, - vhdr_fname=vhdr_partially_disabled_hw_filter_path, eog=eog + vhdr_fname=vhdr_partially_disabled_hw_filter_path, + eog=eog, ) - trigger_warning = ['will be dropped' in str(ww.message) - for ww in w] - lowpass_warning = ['different lowpass filters' in str(ww.message) - for ww in w] - highpass_warning = ['different highpass filters' in str(ww.message) - for ww in w] + trigger_warning = ["will be dropped" in str(ww.message) for ww in w] + lowpass_warning = ["different lowpass filters" in str(ww.message) for ww in w] + highpass_warning = ["different highpass filters" in str(ww.message) for ww in w] expected_warnings = zip(trigger_warning, lowpass_warning, highpass_warning) - assert (all(any([trg, lp, hp]) for trg, lp, hp in expected_warnings)) + assert all(any([trg, lp, hp]) for trg, lp, hp in expected_warnings) - assert raw.info['highpass'] == 0. - assert raw.info['lowpass'] == 500. 
+ assert raw.info["highpass"] == 0.0 + assert raw.info["lowpass"] == 500.0 def test_brainvision_data_software_filters_latin1_global_units(): """Test reading raw Brain Vision files.""" - with pytest.warns(RuntimeWarning, match='software filter'): + with pytest.warns(RuntimeWarning, match="software filter"): raw = _test_raw_reader( - read_raw_brainvision, vhdr_fname=vhdr_old_path, - eog=("VEOGo", "VEOGu", "HEOGli", "HEOGre"), misc=("A2",)) + read_raw_brainvision, + vhdr_fname=vhdr_old_path, + eog=("VEOGo", "VEOGu", "HEOGli", "HEOGre"), + misc=("A2",), + ) - assert raw.info['highpass'] == 1. / (2 * np.pi * 0.9) - assert raw.info['lowpass'] == 50. + assert raw.info["highpass"] == 1.0 / (2 * np.pi * 0.9) + assert raw.info["lowpass"] == 50.0 # test sensor name with spaces (#9299) - with pytest.warns(RuntimeWarning, match='software filter'): + with pytest.warns(RuntimeWarning, match="software filter"): raw = _test_raw_reader( - read_raw_brainvision, vhdr_fname=vhdr_old_longname_path, - eog=("VEOGo", "VEOGu", "HEOGli", "HEOGre"), misc=("A2",)) + read_raw_brainvision, + vhdr_fname=vhdr_old_longname_path, + eog=("VEOGo", "VEOGu", "HEOGli", "HEOGre"), + misc=("A2",), + ) - assert raw.info['highpass'] == 1. / (2 * np.pi * 0.9) - assert raw.info['lowpass'] == 50. + assert raw.info["highpass"] == 1.0 / (2 * np.pi * 0.9) + assert raw.info["lowpass"] == 50.0 def test_brainvision_data(): """Test reading raw Brain Vision files.""" pytest.raises(OSError, read_raw_brainvision, vmrk_path) - pytest.raises(ValueError, read_raw_brainvision, vhdr_path, - preload=True, scale="foo") + pytest.raises( + ValueError, read_raw_brainvision, vhdr_path, preload=True, scale="foo" + ) raw_py = _test_raw_reader( - read_raw_brainvision, vhdr_fname=vhdr_path, eog=eog, misc='auto' + read_raw_brainvision, vhdr_fname=vhdr_path, eog=eog, misc="auto" ) - assert ('RawBrainVision' in repr(raw_py)) + assert "RawBrainVision" in repr(raw_py) - assert raw_py.info['highpass'] == 0. - assert raw_py.info['lowpass'] == 250. 
+ assert raw_py.info["highpass"] == 0.0 + assert raw_py.info["lowpass"] == 250.0 - picks = pick_types(raw_py.info, meg=False, eeg=True, exclude='bads') + picks = pick_types(raw_py.info, meg=False, eeg=True, exclude="bads") data_py, times_py = raw_py[picks] # compare with a file that was generated using MNE-C raw_bin = read_raw_fif(eeg_bin, preload=True) - picks = pick_types(raw_py.info, meg=False, eeg=True, exclude='bads') + picks = pick_types(raw_py.info, meg=False, eeg=True, exclude="bads") data_bin, times_bin = raw_bin[picks] assert_allclose(data_py, data_bin) assert_allclose(times_py, times_bin) # Make sure EOG channels are marked correctly - for ch in raw_py.info['chs']: - if ch['ch_name'] in eog: - assert ch['kind'] == FIFF.FIFFV_EOG_CH - elif ch['ch_name'] == 'STI 014': - assert ch['kind'] == FIFF.FIFFV_STIM_CH - elif ch['ch_name'] in ('CP5', 'CP6'): - assert ch['kind'] == FIFF.FIFFV_MISC_CH - assert ch['unit'] == FIFF.FIFF_UNIT_NONE - elif ch['ch_name'] == 'ReRef': - assert ch['kind'] == FIFF.FIFFV_MISC_CH - assert ch['unit'] == FIFF.FIFF_UNIT_CEL - elif ch['ch_name'] in raw_py.info['ch_names']: - assert ch['kind'] == FIFF.FIFFV_EEG_CH - assert ch['unit'] == FIFF.FIFF_UNIT_V + for ch in raw_py.info["chs"]: + if ch["ch_name"] in eog: + assert ch["kind"] == FIFF.FIFFV_EOG_CH + elif ch["ch_name"] == "STI 014": + assert ch["kind"] == FIFF.FIFFV_STIM_CH + elif ch["ch_name"] in ("CP5", "CP6"): + assert ch["kind"] == FIFF.FIFFV_MISC_CH + assert ch["unit"] == FIFF.FIFF_UNIT_NONE + elif ch["ch_name"] == "ReRef": + assert ch["kind"] == FIFF.FIFFV_MISC_CH + assert ch["unit"] == FIFF.FIFF_UNIT_CEL + elif ch["ch_name"] in raw_py.info["ch_names"]: + assert ch["kind"] == FIFF.FIFFV_EEG_CH + assert ch["unit"] == FIFF.FIFF_UNIT_V else: - raise RuntimeError("Unknown Channel: %s" % ch['ch_name']) + raise RuntimeError("Unknown Channel: %s" % ch["ch_name"]) # test loading v2 - read_raw_brainvision(vhdr_v2_path, eog=eog, preload=True, - verbose='error') + read_raw_brainvision(vhdr_v2_path, eog=eog, preload=True, verbose="error") # test different units with alternative header file raw_units = _test_raw_reader( - read_raw_brainvision, vhdr_fname=vhdr_units_path, eog=eog, misc='auto' + read_raw_brainvision, vhdr_fname=vhdr_units_path, eog=eog, misc="auto" ) - assert raw_units.info['chs'][0]['ch_name'] == 'FP1' - assert raw_units.info['chs'][0]['kind'] == FIFF.FIFFV_EEG_CH + assert raw_units.info["chs"][0]["ch_name"] == "FP1" + assert raw_units.info["chs"][0]["kind"] == FIFF.FIFFV_EEG_CH data_units, _ = raw_units[0] assert_allclose(data_py[0, :], data_units.squeeze()) - assert raw_units.info['chs'][1]['ch_name'] == 'FP2' - assert raw_units.info['chs'][1]['kind'] == FIFF.FIFFV_EEG_CH + assert raw_units.info["chs"][1]["ch_name"] == "FP2" + assert raw_units.info["chs"][1]["kind"] == FIFF.FIFFV_EEG_CH data_units, _ = raw_units[1] assert_allclose(data_py[1, :], data_units.squeeze()) - assert raw_units.info['chs'][2]['ch_name'] == 'F3' - assert raw_units.info['chs'][2]['kind'] == FIFF.FIFFV_EEG_CH + assert raw_units.info["chs"][2]["ch_name"] == "F3" + assert raw_units.info["chs"][2]["kind"] == FIFF.FIFFV_EEG_CH data_units, _ = raw_units[2] assert_allclose(data_py[2, :], data_units.squeeze()) def test_brainvision_vectorized_data(): """Test reading BrainVision data files with vectorized data.""" - with pytest.warns(RuntimeWarning, match='software filter'): + with pytest.warns(RuntimeWarning, match="software filter"): raw = read_raw_brainvision(vhdr_old_path, preload=True) assert_array_equal(raw._data.shape, (29, 
251)) - first_two_samples_all_chs = np.array([[+5.22000008e-06, +5.10000000e-06], - [+2.10000000e-06, +2.27000008e-06], - [+1.15000000e-06, +1.33000002e-06], - [+4.00000000e-07, +4.00000000e-07], - [-3.02999992e-06, -2.82000008e-06], - [+2.71000004e-06, +2.45000000e-06], - [+2.41000004e-06, +2.36000004e-06], - [+1.01999998e-06, +1.18000002e-06], - [-1.33999996e-06, -1.25000000e-06], - [-2.60000000e-06, -2.46000004e-06], - [+6.80000019e-07, +8.00000000e-07], - [+1.48000002e-06, +1.48999996e-06], - [+1.61000004e-06, +1.51000004e-06], - [+7.19999981e-07, +8.60000038e-07], - [-3.00000000e-07, -4.00000006e-08], - [-1.20000005e-07, +6.00000024e-08], - [+8.19999981e-07, +9.89999962e-07], - [+1.13000002e-06, +1.28000002e-06], - [+1.08000002e-06, +1.33999996e-06], - [+2.20000005e-07, +5.69999981e-07], - [-4.09999990e-07, +4.00000006e-08], - [+5.19999981e-07, +9.39999962e-07], - [+1.01000004e-06, +1.51999998e-06], - [+1.01000004e-06, +1.55000000e-06], - [-1.43000002e-06, -1.13999996e-06], - [+3.65000000e-06, +3.65999985e-06], - [+4.15999985e-06, +3.79000015e-06], - [+9.26999969e-06, +8.95999985e-06], - [-7.35999985e-06, -7.18000031e-06], - ]) + first_two_samples_all_chs = np.array( + [ + [+5.22000008e-06, +5.10000000e-06], + [+2.10000000e-06, +2.27000008e-06], + [+1.15000000e-06, +1.33000002e-06], + [+4.00000000e-07, +4.00000000e-07], + [-3.02999992e-06, -2.82000008e-06], + [+2.71000004e-06, +2.45000000e-06], + [+2.41000004e-06, +2.36000004e-06], + [+1.01999998e-06, +1.18000002e-06], + [-1.33999996e-06, -1.25000000e-06], + [-2.60000000e-06, -2.46000004e-06], + [+6.80000019e-07, +8.00000000e-07], + [+1.48000002e-06, +1.48999996e-06], + [+1.61000004e-06, +1.51000004e-06], + [+7.19999981e-07, +8.60000038e-07], + [-3.00000000e-07, -4.00000006e-08], + [-1.20000005e-07, +6.00000024e-08], + [+8.19999981e-07, +9.89999962e-07], + [+1.13000002e-06, +1.28000002e-06], + [+1.08000002e-06, +1.33999996e-06], + [+2.20000005e-07, +5.69999981e-07], + [-4.09999990e-07, +4.00000006e-08], + [+5.19999981e-07, +9.39999962e-07], + [+1.01000004e-06, +1.51999998e-06], + [+1.01000004e-06, +1.55000000e-06], + [-1.43000002e-06, -1.13999996e-06], + [+3.65000000e-06, +3.65999985e-06], + [+4.15999985e-06, +3.79000015e-06], + [+9.26999969e-06, +8.95999985e-06], + [-7.35999985e-06, -7.18000031e-06], + ] + ) assert_allclose(raw._data[:, :2], first_two_samples_all_chs) @@ -595,13 +610,13 @@ def test_brainvision_vectorized_data(): def test_coodinates_extraction(): """Test reading of [Coordinates] section if present.""" # vhdr 2 has a Coordinates section - with pytest.warns(RuntimeWarning, match='coordinate information'): + with pytest.warns(RuntimeWarning, match="coordinate information"): raw = read_raw_brainvision(vhdr_v2_path) # Basic check of extracted coordinates - assert raw.info['dig'] is not None - diglist = raw.info['dig'] - coords = np.array([dig['r'] for dig in diglist]) + assert raw.info["dig"] is not None + diglist = raw.info["dig"] + coords = np.array([dig["r"] for dig in diglist]) EXPECTED_SHAPE = ( # HL, HR, Vb, ReRef are not set in dig # but LPA, Nasion, RPA are estimated @@ -616,16 +631,16 @@ def test_coodinates_extraction(): # vhdr 1 does not have a Coordinates section raw2 = read_raw_brainvision(vhdr_path) - assert raw2.info['dig'] is None + assert raw2.info["dig"] is None @testing.requires_testing_data def test_brainvision_neuroone_export(): """Test Brainvision file exported with neuroone system.""" - raw = read_raw_brainvision(neuroone_vhdr, verbose='error') - assert raw.info['meas_date'] is None - assert 
len(raw.info['chs']) == 65 - assert raw.info['sfreq'] == 5000. + raw = read_raw_brainvision(neuroone_vhdr, verbose="error") + assert raw.info["meas_date"] is None + assert len(raw.info["chs"]) == 65 + assert raw.info["sfreq"] == 5000.0 @testing.requires_testing_data @@ -637,8 +652,8 @@ def test_read_vmrk_annotations(tmp_path): # delete=False is for Windows compatibility with open(vmrk_path) as myfile: head = [next(myfile) for x in range(6)] - fname = tmp_path / 'temp.vmrk' - with open(str(fname), 'w') as temp: + fname = tmp_path / "temp.vmrk" + with open(str(fname), "w") as temp: for item in head: temp.write(item) read_annotations(fname, sfreq=sfreq) @@ -648,38 +663,92 @@ def test_read_vmrk_annotations(tmp_path): def test_read_vhdr_annotations_and_events(tmp_path): """Test load brainvision annotations and parse them to events.""" # First we add a custom event that contains a comma in its description - for src, dest in zip((vhdr_path, vmrk_path, eeg_path), - ('test.vhdr', 'test.vmrk', 'test.eeg')): + for src, dest in zip( + (vhdr_path, vmrk_path, eeg_path), ("test.vhdr", "test.vmrk", "test.eeg") + ): shutil.copyfile(src, tmp_path / dest) # Commas are encoded as "\1" - with open(tmp_path / 'test.vmrk', 'a') as fout: + with open(tmp_path / "test.vmrk", "a") as fout: fout.write(r"Mk15=Comma\1Type,CommaValue\11,7800,1,0\n") sfreq = 1000.0 expected_orig_time = _stamp_to_dt((1384359243, 794232)) expected_onset_latency = np.array( - [0, 486., 496., 1769., 1779., 3252., 3262., 4935., 4945., 5999., 6619., - 6629., 7629., 7699., 7799.] + [ + 0, + 486.0, + 496.0, + 1769.0, + 1779.0, + 3252.0, + 3262.0, + 4935.0, + 4945.0, + 5999.0, + 6619.0, + 6629.0, + 7629.0, + 7699.0, + 7799.0, + ] ) expected_annot_description = [ - 'New Segment/', 'Stimulus/S253', 'Stimulus/S255', 'Event/254', - 'Stimulus/S255', 'Event/254', 'Stimulus/S255', 'Stimulus/S253', - 'Stimulus/S255', 'Response/R255', 'Event/254', 'Stimulus/S255', - 'SyncStatus/Sync On', 'Optic/O 1', 'Comma,Type/CommaValue,1' + "New Segment/", + "Stimulus/S253", + "Stimulus/S255", + "Event/254", + "Stimulus/S255", + "Event/254", + "Stimulus/S255", + "Stimulus/S253", + "Stimulus/S255", + "Response/R255", + "Event/254", + "Stimulus/S255", + "SyncStatus/Sync On", + "Optic/O 1", + "Comma,Type/CommaValue,1", ] - expected_events = np.stack([ - expected_onset_latency, - np.zeros_like(expected_onset_latency), - [99999, 253, 255, 254, 255, 254, 255, 253, 255, 1255, 254, 255, 99998, - 2001, 10001], - ]).astype('int64').T - expected_event_id = {'New Segment/': 99999, 'Stimulus/S253': 253, - 'Stimulus/S255': 255, 'Event/254': 254, - 'Response/R255': 1255, 'SyncStatus/Sync On': 99998, - 'Optic/O 1': 2001, 'Comma,Type/CommaValue,1': 10001} - - raw = read_raw_brainvision(tmp_path / 'test.vhdr', eog=eog) + expected_events = ( + np.stack( + [ + expected_onset_latency, + np.zeros_like(expected_onset_latency), + [ + 99999, + 253, + 255, + 254, + 255, + 254, + 255, + 253, + 255, + 1255, + 254, + 255, + 99998, + 2001, + 10001, + ], + ] + ) + .astype("int64") + .T + ) + expected_event_id = { + "New Segment/": 99999, + "Stimulus/S253": 253, + "Stimulus/S255": 255, + "Event/254": 254, + "Response/R255": 1255, + "SyncStatus/Sync On": 99998, + "Optic/O 1": 2001, + "Comma,Type/CommaValue,1": 10001, + } + + raw = read_raw_brainvision(tmp_path / "test.vhdr", eog=eog) # validate annotations assert raw.annotations.orig_time == expected_orig_time @@ -692,14 +761,15 @@ def test_read_vhdr_annotations_and_events(tmp_path): assert event_id == expected_event_id # validate that None gives us 
a sorted list - expected_none_event_id = {desc: idx + 1 for idx, desc in enumerate(sorted( - event_id.keys()))} + expected_none_event_id = { + desc: idx + 1 for idx, desc in enumerate(sorted(event_id.keys())) + } events, event_id = events_from_annotations(raw, event_id=None) assert event_id == expected_none_event_id # Add some custom ones, plus a 2-digit one - s_10 = 'Stimulus/S 10' - raw.annotations.append([1, 2, 3], 10, ['ZZZ', s_10, 'YYY']) + s_10 = "Stimulus/S 10" + raw.annotations.append([1, 2, 3], 10, ["ZZZ", s_10, "YYY"]) # others starting at 10001 ... # we already have "Comma,Type/CommaValue,1" as 10001 expected_event_id.update(YYY=10002, ZZZ=10003) @@ -709,7 +779,7 @@ def test_read_vhdr_annotations_and_events(tmp_path): # Concatenating two shouldn't change the resulting event_id # (BAD and EDGE should be ignored) - with pytest.warns(RuntimeWarning, match='expanding outside'): + with pytest.warns(RuntimeWarning, match="expanding outside"): raw_concat = concatenate_raws([raw.copy(), raw.copy()]) _, event_id = events_from_annotations(raw_concat) assert event_id == expected_event_id @@ -718,14 +788,16 @@ def test_read_vhdr_annotations_and_events(tmp_path): @testing.requires_testing_data def test_automatic_vmrk_sfreq_recovery(): """Test proper sfreq inference by checking the onsets.""" - assert_array_equal(read_annotations(vmrk_path, sfreq='auto'), - read_annotations(vmrk_path, sfreq=1000.0)) + assert_array_equal( + read_annotations(vmrk_path, sfreq="auto"), + read_annotations(vmrk_path, sfreq=1000.0), + ) @testing.requires_testing_data def test_event_id_stability_when_save_and_fif_reload(tmp_path): """Test load events from brainvision annotations when read_raw_fif.""" - fname = tmp_path / 'bv-raw.fif' + fname = tmp_path / "bv-raw.fif" raw = read_raw_brainvision(vhdr_path, eog=eog) original_events, original_event_id = events_from_annotations(raw) @@ -739,55 +811,109 @@ def test_event_id_stability_when_save_and_fif_reload(tmp_path): def test_parse_impedance(): """Test case for parsing the impedances from header.""" - expected_imp_meas_time = datetime.datetime(2013, 11, 13, 16, 12, 27, - tzinfo=datetime.timezone.utc) - expected_imp_unit = 'kOhm' + expected_imp_meas_time = datetime.datetime( + 2013, 11, 13, 16, 12, 27, tzinfo=datetime.timezone.utc + ) + expected_imp_unit = "kOhm" expected_electrodes = [ - 'FP1', 'FP2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', - 'F8', 'P7', 'P8', 'Fz', 'FCz', 'Cz', 'CPz', 'Pz', 'POz', 'FC1', 'FC2', - 'CP1', 'CP2', 'FC5', 'FC6', 'CP5', 'CP6', 'HL', 'HR', 'Vb', 'ReRef', - 'Ref', 'Gnd' + "FP1", + "FP2", + "F3", + "F4", + "C3", + "C4", + "P3", + "P4", + "O1", + "O2", + "F7", + "F8", + "P7", + "P8", + "Fz", + "FCz", + "Cz", + "CPz", + "Pz", + "POz", + "FC1", + "FC2", + "CP1", + "CP2", + "FC5", + "FC6", + "CP5", + "CP6", + "HL", + "HR", + "Vb", + "ReRef", + "Ref", + "Gnd", ] n_electrodes = len(expected_electrodes) - expected_imps = [np.nan] * (n_electrodes - 2) + [0., 4.] - expected_imp_lower_bound = 0. - expected_imp_upper_bound = [100.] * (n_electrodes - 2) + [10., 10.] 
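The per-channel lines that _parse_impedance consumes (defined earlier in this patch) are matched with a permissive channel-name regex, and the impedance value is pulled from the last whitespace-separated token, falling back to NaN when no number is present (e.g. out-of-range readings). A standalone sketch on a hypothetical header line:

    import re
    import numpy as np

    setting = "FP1:          4"  # hypothetical entry under an "Impedance [kOhm] ... :" block

    match = re.match(r"[ a-zA-Z0-9_+-]+:", setting)
    channel_name = match.group().rstrip(":")                    # "FP1"
    imp_as_number = re.findall(r"[-+]?\d*\.\d+|\d+", setting.split()[-1])
    imp = float(imp_as_number[0]) if imp_as_number else np.nan  # 4.0
    assert (channel_name, imp) == ("FP1", 4.0)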
- - expected_impedances = {elec: { - 'imp': expected_imps[i], - 'imp_unit': expected_imp_unit, - 'imp_meas_time': expected_imp_meas_time, - 'imp_lower_bound': expected_imp_lower_bound, - 'imp_upper_bound': expected_imp_upper_bound[i], - 'imp_range_unit': expected_imp_unit, - } for i, elec in enumerate(expected_electrodes)} + expected_imps = [np.nan] * (n_electrodes - 2) + [0.0, 4.0] + expected_imp_lower_bound = 0.0 + expected_imp_upper_bound = [100.0] * (n_electrodes - 2) + [10.0, 10.0] + + expected_impedances = { + elec: { + "imp": expected_imps[i], + "imp_unit": expected_imp_unit, + "imp_meas_time": expected_imp_meas_time, + "imp_lower_bound": expected_imp_lower_bound, + "imp_upper_bound": expected_imp_upper_bound[i], + "imp_range_unit": expected_imp_unit, + } + for i, elec in enumerate(expected_electrodes) + } raw = read_raw_brainvision(vhdr_path, eog=eog) - assert object_diff(expected_impedances, raw.impedances) == '' + assert object_diff(expected_impedances, raw.impedances) == "" # Test "Impedances Imported from actiCAP Control Software" - expected_imp_meas_time = expected_imp_meas_time.replace(hour=10, - minute=17, - second=2) - tmpidx = expected_electrodes.index('CP6') + expected_imp_meas_time = expected_imp_meas_time.replace( + hour=10, minute=17, second=2 + ) + tmpidx = expected_electrodes.index("CP6") expected_electrodes = expected_electrodes[:tmpidx] + [ - 'CP 6', 'ECG+', 'ECG-', 'HEOG+', 'HEOG-', 'VEOG+', 'VEOG-', 'ReRef', - 'Ref', 'Gnd' + "CP 6", + "ECG+", + "ECG-", + "HEOG+", + "HEOG-", + "VEOG+", + "VEOG-", + "ReRef", + "Ref", + "Gnd", ] n_electrodes = len(expected_electrodes) expected_imps = [np.nan] * (n_electrodes - 9) + [ - 35., 46., 6., 8., 3., 4., 0., 8., 2.5 + 35.0, + 46.0, + 6.0, + 8.0, + 3.0, + 4.0, + 0.0, + 8.0, + 2.5, ] - expected_impedances = {elec: { - 'imp': expected_imps[i], - 'imp_unit': expected_imp_unit, - 'imp_meas_time': expected_imp_meas_time, - } for i, elec in enumerate(expected_electrodes)} + expected_impedances = { + elec: { + "imp": expected_imps[i], + "imp_unit": expected_imp_unit, + "imp_meas_time": expected_imp_meas_time, + } + for i, elec in enumerate(expected_electrodes) + } - with pytest.warns(RuntimeWarning, match='different .*pass filters'): - raw = read_raw_brainvision(vhdr_mixed_lowpass_path, - eog=['HEOG', 'VEOG'], misc=['ECG']) - assert object_diff(expected_impedances, raw.impedances) == '' + with pytest.warns(RuntimeWarning, match="different .*pass filters"): + raw = read_raw_brainvision( + vhdr_mixed_lowpass_path, eog=["HEOG", "VEOG"], misc=["ECG"] + ) + assert object_diff(expected_impedances, raw.impedances) == "" @testing.requires_testing_data @@ -798,6 +924,6 @@ def test_ahdr_format(): expected_lp = 250.0 raw = read_raw_brainvision(vamp_ahdr) - assert raw.info['nchan'] == expected_num_channels - assert raw.info['highpass'] == expected_hp - assert raw.info['lowpass'] == expected_lp + assert raw.info["nchan"] == expected_num_channels + assert raw.info["highpass"] == expected_hp + assert raw.info["lowpass"] == expected_lp diff --git a/mne/io/bti/bti.py b/mne/io/bti/bti.py index 9f83d63b7c5..d3f4ea42f4c 100644 --- a/mne/io/bti/bti.py +++ b/mne/io/bti/bti.py @@ -16,44 +16,56 @@ import numpy as np from ...utils import logger, verbose, _stamp_to_dt, path_like -from ...transforms import (combine_transforms, invert_transform, - Transform) +from ...transforms import combine_transforms, invert_transform, Transform from .._digitization import _make_bti_dig_points from ..constants import FIFF from .. 
import BaseRaw, _coil_trans_to_loc, _loc_to_coil_trans, _empty_info from ..utils import _mult_cal_one, read_str from .constants import BTI -from .read import (read_int32, read_int16, read_float, read_double, - read_transform, read_char, read_int64, read_uint16, - read_uint32, read_double_matrix, read_float_matrix, - read_int16_matrix, read_dev_header) - -FIFF_INFO_DIG_FIELDS = ('kind', 'ident', 'r', 'coord_frame') +from .read import ( + read_int32, + read_int16, + read_float, + read_double, + read_transform, + read_char, + read_int64, + read_uint16, + read_uint32, + read_double_matrix, + read_float_matrix, + read_int16_matrix, + read_dev_header, +) + +FIFF_INFO_DIG_FIELDS = ("kind", "ident", "r", "coord_frame") FIFF_INFO_DIG_DEFAULTS = (None, None, None, FIFF.FIFFV_COORD_HEAD) -BTI_WH2500_REF_MAG = ('MxA', 'MyA', 'MzA', 'MxaA', 'MyaA', 'MzaA') -BTI_WH2500_REF_GRAD = ('GxxA', 'GyyA', 'GyxA', 'GzaA', 'GzyA') +BTI_WH2500_REF_MAG = ("MxA", "MyA", "MzA", "MxaA", "MyaA", "MzaA") +BTI_WH2500_REF_GRAD = ("GxxA", "GyyA", "GyxA", "GzaA", "GzyA") -dtypes = zip(list(range(1, 5)), ('>i2', '>i4', '>f4', '>f8')) +dtypes = zip(list(range(1, 5)), (">i2", ">i4", ">f4", ">f8")) DTYPES = {i: np.dtype(t) for i, t in dtypes} def _instantiate_default_info_chs(): """Populate entries in info['chs'] with default values.""" - return dict(loc=np.array([0, 0, 0, 1] * 3, dtype='f4'), - ch_name=None, - unit_mul=FIFF.FIFF_UNITM_NONE, - coord_frame=FIFF.FIFFV_COORD_UNKNOWN, - coil_type=FIFF.FIFFV_COIL_NONE, - range=1.0, - unit=FIFF.FIFF_UNIT_V, - cal=1.0, - scanno=None, - kind=FIFF.FIFFV_MISC_CH, - logno=None) - - -class _bytes_io_mock_context(): + return dict( + loc=np.array([0, 0, 0, 1] * 3, dtype="f4"), + ch_name=None, + unit_mul=FIFF.FIFF_UNITM_NONE, + coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + coil_type=FIFF.FIFFV_COIL_NONE, + range=1.0, + unit=FIFF.FIFF_UNIT_V, + cal=1.0, + scanno=None, + kind=FIFF.FIFFV_MISC_CH, + logno=None, + ) + + +class _bytes_io_mock_context: """Make a context for BytesIO.""" def __init__(self, target): # noqa: D102 @@ -73,10 +85,10 @@ def _bti_open(fname, *args, **kwargs): elif isinstance(fname, BytesIO): return _bytes_io_mock_context(fname) else: - raise RuntimeError('Cannot mock this.') + raise RuntimeError("Cannot mock this.") -def _get_bti_dev_t(adjust=0., translation=(0.0, 0.02, 0.11)): +def _get_bti_dev_t(adjust=0.0, translation=(0.0, 0.02, 0.11)): """Get the general Magnes3600WH to Neuromag coordinate transform. Parameters @@ -93,20 +105,22 @@ def _get_bti_dev_t(adjust=0., translation=(0.0, 0.02, 0.11)): m_nm_t : ndarray 4 x 4 rotation, translation, scaling matrix. """ - flip_t = np.array([[0., -1., 0.], - [1., 0., 0.], - [0., 0., 1.]]) + flip_t = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) rad = np.deg2rad(adjust) - adjust_t = np.array([[1., 0., 0.], - [0., np.cos(rad), -np.sin(rad)], - [0., np.sin(rad), np.cos(rad)]]) + adjust_t = np.array( + [ + [1.0, 0.0, 0.0], + [0.0, np.cos(rad), -np.sin(rad)], + [0.0, np.sin(rad), np.cos(rad)], + ] + ) m_nm_t = np.eye(4) m_nm_t[:3, :3] = np.dot(flip_t, adjust_t) m_nm_t[:3, 3] = translation return m_nm_t -def _rename_channels(names, ecg_ch='E31', eog_ch=('E63', 'E64')): +def _rename_channels(names, ecg_ch="E31", eog_ch=("E63", "E64")): """Rename appropriately ordered list of channel names. 
Parameters @@ -122,26 +136,26 @@ def _rename_channels(names, ecg_ch='E31', eog_ch=('E63', 'E64')): new = list() ref_mag, ref_grad, eog, eeg, ext = [count(1) for _ in range(5)] for i, name in enumerate(names, 1): - if name.startswith('A'): - name = 'MEG %3.3d' % i - elif name == 'RESPONSE': - name = 'STI 013' - elif name == 'TRIGGER': - name = 'STI 014' + if name.startswith("A"): + name = "MEG %3.3d" % i + elif name == "RESPONSE": + name = "STI 013" + elif name == "TRIGGER": + name = "STI 014" elif any(name == k for k in eog_ch): - name = 'EOG %3.3d' % next(eog) + name = "EOG %3.3d" % next(eog) elif name == ecg_ch: - name = 'ECG 001' - elif name.startswith('E'): - name = 'EEG %3.3d' % next(eeg) - elif name == 'UACurrent': - name = 'UTL 001' - elif name.startswith('M'): - name = 'RFM %3.3d' % next(ref_mag) - elif name.startswith('G'): - name = 'RFG %3.3d' % next(ref_grad) - elif name.startswith('X'): - name = 'EXT %3.3d' % next(ext) + name = "ECG 001" + elif name.startswith("E"): + name = "EEG %3.3d" % next(eeg) + elif name == "UACurrent": + name = "UTL 001" + elif name.startswith("M"): + name = "RFM %3.3d" % next(ref_mag) + elif name.startswith("G"): + name = "RFG %3.3d" % next(ref_grad) + elif name.startswith("X"): + name = "EXT %3.3d" % next(ext) new += [name] @@ -151,7 +165,7 @@ def _rename_channels(names, ecg_ch='E31', eog_ch=('E63', 'E64')): # read the points def _read_head_shape(fname): """Read the head shape.""" - with _bti_open(fname, 'rb') as fid: + with _bti_open(fname, "rb") as fid: fid.seek(BTI.FILE_HS_N_DIGPOINTS) _n_dig_points = read_int32(fid) idx_points = read_double_matrix(fid, BTI.DATA_N_IDX_POINTS, 3) @@ -159,32 +173,32 @@ def _read_head_shape(fname): # reorder to lpa, rpa, nasion so = is direct. nasion, lpa, rpa = [idx_points[_, :] for _ in [2, 0, 1]] - hpi = idx_points[3:len(idx_points), :] + hpi = idx_points[3 : len(idx_points), :] return nasion, lpa, rpa, hpi, dig_points def _check_nan_dev_head_t(dev_ctf_t): """Make sure we deal with nans.""" - has_nan = np.isnan(dev_ctf_t['trans']) + has_nan = np.isnan(dev_ctf_t["trans"]) if np.any(has_nan): - logger.info('Missing values BTI dev->head transform. ' - 'Replacing with identity matrix.') - dev_ctf_t['trans'] = np.identity(4) + logger.info( + "Missing values BTI dev->head transform. " "Replacing with identity matrix." + ) + dev_ctf_t["trans"] = np.identity(4) def _convert_coil_trans(coil_trans, dev_ctf_t, bti_dev_t): """Convert the coil trans.""" - t = combine_transforms(invert_transform(dev_ctf_t), bti_dev_t, - 'ctf_head', 'meg') - t = np.dot(t['trans'], coil_trans) + t = combine_transforms(invert_transform(dev_ctf_t), bti_dev_t, "ctf_head", "meg") + t = np.dot(t["trans"], coil_trans) return t def _correct_offset(fid): """Align fid pointer.""" current = fid.tell() - if ((current % BTI.FILE_CURPOS) != 0): + if (current % BTI.FILE_CURPOS) != 0: offset = current % BTI.FILE_CURPOS fid.seek(BTI.FILE_CURPOS - (offset), 1) @@ -202,346 +216,408 @@ def _read_config(fname): cfg : dict The config blocks found. 
""" - with _bti_open(fname, 'rb') as fid: + with _bti_open(fname, "rb") as fid: cfg = dict() - cfg['hdr'] = {'version': read_int16(fid), - 'site_name': read_str(fid, 32), - 'dap_hostname': read_str(fid, 16), - 'sys_type': read_int16(fid), - 'sys_options': read_int32(fid), - 'supply_freq': read_int16(fid), - 'total_chans': read_int16(fid), - 'system_fixed_gain': read_float(fid), - 'volts_per_bit': read_float(fid), - 'total_sensors': read_int16(fid), - 'total_user_blocks': read_int16(fid), - 'next_der_chan_no': read_int16(fid)} + cfg["hdr"] = { + "version": read_int16(fid), + "site_name": read_str(fid, 32), + "dap_hostname": read_str(fid, 16), + "sys_type": read_int16(fid), + "sys_options": read_int32(fid), + "supply_freq": read_int16(fid), + "total_chans": read_int16(fid), + "system_fixed_gain": read_float(fid), + "volts_per_bit": read_float(fid), + "total_sensors": read_int16(fid), + "total_user_blocks": read_int16(fid), + "next_der_chan_no": read_int16(fid), + } fid.seek(2, 1) - cfg['checksum'] = read_uint32(fid) - cfg['reserved'] = read_char(fid, 32) - cfg['transforms'] = [read_transform(fid) for t in - range(cfg['hdr']['total_sensors'])] + cfg["checksum"] = read_uint32(fid) + cfg["reserved"] = read_char(fid, 32) + cfg["transforms"] = [ + read_transform(fid) for t in range(cfg["hdr"]["total_sensors"]) + ] - cfg['user_blocks'] = dict() - for block in range(cfg['hdr']['total_user_blocks']): + cfg["user_blocks"] = dict() + for block in range(cfg["hdr"]["total_user_blocks"]): ub = dict() - ub['hdr'] = {'nbytes': read_uint32(fid), - 'kind': read_str(fid, 20), - 'checksum': read_int32(fid), - 'username': read_str(fid, 32), - 'timestamp': read_uint32(fid), - 'user_space_size': read_uint32(fid), - 'reserved': read_char(fid, 32)} + ub["hdr"] = { + "nbytes": read_uint32(fid), + "kind": read_str(fid, 20), + "checksum": read_int32(fid), + "username": read_str(fid, 32), + "timestamp": read_uint32(fid), + "user_space_size": read_uint32(fid), + "reserved": read_char(fid, 32), + } _correct_offset(fid) start_bytes = fid.tell() - kind = ub['hdr'].pop('kind') + kind = ub["hdr"].pop("kind") if not kind: # make sure reading goes right. Should never be empty - raise RuntimeError('Could not read user block. Probably you ' - 'acquired data using a BTi version ' - 'currently not supported. Please contact ' - 'the mne-python developers.') - dta, cfg['user_blocks'][kind] = dict(), ub - if kind in [v for k, v in BTI.items() if k[:5] == 'UB_B_']: + raise RuntimeError( + "Could not read user block. Probably you " + "acquired data using a BTi version " + "currently not supported. Please contact " + "the mne-python developers." 
+ ) + dta, cfg["user_blocks"][kind] = dict(), ub + if kind in [v for k, v in BTI.items() if k[:5] == "UB_B_"]: if kind == BTI.UB_B_MAG_INFO: - dta['version'] = read_int32(fid) + dta["version"] = read_int32(fid) fid.seek(20, 1) - dta['headers'] = list() + dta["headers"] = list() for hdr in range(6): - d = {'name': read_str(fid, 16), - 'transform': read_transform(fid), - 'units_per_bit': read_float(fid)} - dta['headers'] += [d] + d = { + "name": read_str(fid, 16), + "transform": read_transform(fid), + "units_per_bit": read_float(fid), + } + dta["headers"] += [d] fid.seek(20, 1) elif kind == BTI.UB_B_COH_POINTS: - dta['n_points'] = read_int32(fid) - dta['status'] = read_int32(fid) - dta['points'] = [] + dta["n_points"] = read_int32(fid) + dta["status"] = read_int32(fid) + dta["points"] = [] for pnt in range(16): - d = {'pos': read_double_matrix(fid, 1, 3), - 'direction': read_double_matrix(fid, 1, 3), - 'error': read_double(fid)} - dta['points'] += [d] + d = { + "pos": read_double_matrix(fid, 1, 3), + "direction": read_double_matrix(fid, 1, 3), + "error": read_double(fid), + } + dta["points"] += [d] elif kind == BTI.UB_B_CCP_XFM_BLOCK: - dta['method'] = read_int32(fid) + dta["method"] = read_int32(fid) # handle difference btw/ linux (0) and solaris (4) - size = 0 if ub['hdr']['user_space_size'] == 132 else 4 + size = 0 if ub["hdr"]["user_space_size"] == 132 else 4 fid.seek(size, 1) - dta['transform'] = read_transform(fid) + dta["transform"] = read_transform(fid) elif kind == BTI.UB_B_EEG_LOCS: - dta['electrodes'] = [] + dta["electrodes"] = [] while True: - d = {'label': read_str(fid, 16), - 'location': read_double_matrix(fid, 1, 3)} - if not d['label']: + d = { + "label": read_str(fid, 16), + "location": read_double_matrix(fid, 1, 3), + } + if not d["label"]: break - dta['electrodes'] += [d] + dta["electrodes"] += [d] - elif kind in [BTI.UB_B_WHC_CHAN_MAP_VER, - BTI.UB_B_WHS_SUBSYS_VER]: - dta['version'] = read_int16(fid) - dta['struct_size'] = read_int16(fid) - dta['entries'] = read_int16(fid) + elif kind in [BTI.UB_B_WHC_CHAN_MAP_VER, BTI.UB_B_WHS_SUBSYS_VER]: + dta["version"] = read_int16(fid) + dta["struct_size"] = read_int16(fid) + dta["entries"] = read_int16(fid) fid.seek(8, 1) elif kind == BTI.UB_B_WHC_CHAN_MAP: num_channels = None - for name, data in cfg['user_blocks'].items(): + for name, data in cfg["user_blocks"].items(): if name == BTI.UB_B_WHC_CHAN_MAP_VER: - num_channels = data['entries'] + num_channels = data["entries"] break if num_channels is None: - raise ValueError('Cannot find block %s to determine ' - 'number of channels' - % BTI.UB_B_WHC_CHAN_MAP_VER) + raise ValueError( + "Cannot find block %s to determine " + "number of channels" % BTI.UB_B_WHC_CHAN_MAP_VER + ) - dta['channels'] = list() + dta["channels"] = list() for i in range(num_channels): - d = {'subsys_type': read_int16(fid), - 'subsys_num': read_int16(fid), - 'card_num': read_int16(fid), - 'chan_num': read_int16(fid), - 'recdspnum': read_int16(fid)} - dta['channels'] += [d] + d = { + "subsys_type": read_int16(fid), + "subsys_num": read_int16(fid), + "card_num": read_int16(fid), + "chan_num": read_int16(fid), + "recdspnum": read_int16(fid), + } + dta["channels"] += [d] fid.seek(8, 1) elif kind == BTI.UB_B_WHS_SUBSYS: num_subsys = None - for name, data in cfg['user_blocks'].items(): + for name, data in cfg["user_blocks"].items(): if name == BTI.UB_B_WHS_SUBSYS_VER: - num_subsys = data['entries'] + num_subsys = data["entries"] break if num_subsys is None: - raise ValueError('Cannot find block %s to determine' - ' 
number of subsystems' - % BTI.UB_B_WHS_SUBSYS_VER) + raise ValueError( + "Cannot find block %s to determine" + " number of subsystems" % BTI.UB_B_WHS_SUBSYS_VER + ) - dta['subsys'] = list() + dta["subsys"] = list() for sub_key in range(num_subsys): - d = {'subsys_type': read_int16(fid), - 'subsys_num': read_int16(fid), - 'cards_per_sys': read_int16(fid), - 'channels_per_card': read_int16(fid), - 'card_version': read_int16(fid)} + d = { + "subsys_type": read_int16(fid), + "subsys_num": read_int16(fid), + "cards_per_sys": read_int16(fid), + "channels_per_card": read_int16(fid), + "card_version": read_int16(fid), + } fid.seek(2, 1) - d.update({'offsetdacgain': read_float(fid), - 'squid_type': read_int32(fid), - 'timesliceoffset': read_int16(fid), - 'padding': read_int16(fid), - 'volts_per_bit': read_float(fid)}) + d.update( + { + "offsetdacgain": read_float(fid), + "squid_type": read_int32(fid), + "timesliceoffset": read_int16(fid), + "padding": read_int16(fid), + "volts_per_bit": read_float(fid), + } + ) - dta['subsys'] += [d] + dta["subsys"] += [d] elif kind == BTI.UB_B_CH_LABELS: - dta['version'] = read_int32(fid) - dta['entries'] = read_int32(fid) + dta["version"] = read_int32(fid) + dta["entries"] = read_int32(fid) fid.seek(16, 1) - dta['labels'] = list() - for label in range(dta['entries']): - dta['labels'] += [read_str(fid, 16)] + dta["labels"] = list() + for label in range(dta["entries"]): + dta["labels"] += [read_str(fid, 16)] elif kind == BTI.UB_B_CALIBRATION: - dta['sensor_no'] = read_int16(fid) + dta["sensor_no"] = read_int16(fid) fid.seek(2, 1) - dta['timestamp'] = read_int32(fid) - dta['logdir'] = read_str(fid, 256) + dta["timestamp"] = read_int32(fid) + dta["logdir"] = read_str(fid, 256) elif kind == BTI.UB_B_SYS_CONFIG_TIME: # handle difference btw/ linux (256) and solaris (512) - size = 256 if ub['hdr']['user_space_size'] == 260 else 512 - dta['sysconfig_name'] = read_str(fid, size) - dta['timestamp'] = read_int32(fid) + size = 256 if ub["hdr"]["user_space_size"] == 260 else 512 + dta["sysconfig_name"] = read_str(fid, size) + dta["timestamp"] = read_int32(fid) elif kind == BTI.UB_B_DELTA_ENABLED: - dta['delta_enabled'] = read_int16(fid) + dta["delta_enabled"] = read_int16(fid) elif kind in [BTI.UB_B_E_TABLE_USED, BTI.UB_B_E_TABLE]: - dta['hdr'] = {'version': read_int32(fid), - 'entry_size': read_int32(fid), - 'n_entries': read_int32(fid), - 'filtername': read_str(fid, 16), - 'n_e_values': read_int32(fid), - 'reserved': read_str(fid, 28)} - - if dta['hdr']['version'] == 2: + dta["hdr"] = { + "version": read_int32(fid), + "entry_size": read_int32(fid), + "n_entries": read_int32(fid), + "filtername": read_str(fid, 16), + "n_e_values": read_int32(fid), + "reserved": read_str(fid, 28), + } + + if dta["hdr"]["version"] == 2: size = 16 - dta['ch_names'] = [read_str(fid, size) for ch in - range(dta['hdr']['n_entries'])] - dta['e_ch_names'] = [read_str(fid, size) for ch in - range(dta['hdr']['n_e_values'])] - - rows = dta['hdr']['n_entries'] - cols = dta['hdr']['n_e_values'] - dta['etable'] = read_float_matrix(fid, rows, cols) + dta["ch_names"] = [ + read_str(fid, size) for ch in range(dta["hdr"]["n_entries"]) + ] + dta["e_ch_names"] = [ + read_str(fid, size) + for ch in range(dta["hdr"]["n_e_values"]) + ] + + rows = dta["hdr"]["n_entries"] + cols = dta["hdr"]["n_e_values"] + dta["etable"] = read_float_matrix(fid, rows, cols) else: # handle MAGNES2500 naming scheme - dta['ch_names'] = ['WH2500'] * dta['hdr']['n_e_values'] - dta['hdr']['n_e_values'] = 6 - dta['e_ch_names'] = 
BTI_WH2500_REF_MAG - rows = dta['hdr']['n_entries'] - cols = dta['hdr']['n_e_values'] - dta['etable'] = read_float_matrix(fid, rows, cols) - - elif any([kind == BTI.UB_B_WEIGHTS_USED, - kind[:4] == BTI.UB_B_WEIGHT_TABLE]): - dta['hdr'] = dict( + dta["ch_names"] = ["WH2500"] * dta["hdr"]["n_e_values"] + dta["hdr"]["n_e_values"] = 6 + dta["e_ch_names"] = BTI_WH2500_REF_MAG + rows = dta["hdr"]["n_entries"] + cols = dta["hdr"]["n_e_values"] + dta["etable"] = read_float_matrix(fid, rows, cols) + + elif any( + [kind == BTI.UB_B_WEIGHTS_USED, kind[:4] == BTI.UB_B_WEIGHT_TABLE] + ): + dta["hdr"] = dict( version=read_int32(fid), n_bytes=read_uint32(fid), n_entries=read_uint32(fid), - name=read_str(fid, 32)) - if dta['hdr']['version'] == 2: - dta['hdr'].update( + name=read_str(fid, 32), + ) + if dta["hdr"]["version"] == 2: + dta["hdr"].update( description=read_str(fid, 80), n_anlg=read_uint32(fid), n_dsp=read_uint32(fid), - reserved=read_str(fid, 72)) - dta['ch_names'] = [read_str(fid, 16) for ch in - range(dta['hdr']['n_entries'])] - dta['anlg_ch_names'] = [read_str(fid, 16) for ch in - range(dta['hdr']['n_anlg'])] - - dta['dsp_ch_names'] = [read_str(fid, 16) for ch in - range(dta['hdr']['n_dsp'])] - dta['dsp_wts'] = read_float_matrix( - fid, dta['hdr']['n_entries'], dta['hdr']['n_dsp']) - dta['anlg_wts'] = read_int16_matrix( - fid, dta['hdr']['n_entries'], dta['hdr']['n_anlg']) + reserved=read_str(fid, 72), + ) + dta["ch_names"] = [ + read_str(fid, 16) for ch in range(dta["hdr"]["n_entries"]) + ] + dta["anlg_ch_names"] = [ + read_str(fid, 16) for ch in range(dta["hdr"]["n_anlg"]) + ] + + dta["dsp_ch_names"] = [ + read_str(fid, 16) for ch in range(dta["hdr"]["n_dsp"]) + ] + dta["dsp_wts"] = read_float_matrix( + fid, dta["hdr"]["n_entries"], dta["hdr"]["n_dsp"] + ) + dta["anlg_wts"] = read_int16_matrix( + fid, dta["hdr"]["n_entries"], dta["hdr"]["n_anlg"] + ) else: # handle MAGNES2500 naming scheme - fid.seek(start_bytes + ub['hdr']['user_space_size'] - - dta['hdr']['n_bytes'] * - dta['hdr']['n_entries'], 0) - - dta['hdr']['n_dsp'] = dta['hdr']['n_bytes'] // 4 - 2 - assert (dta['hdr']['n_dsp'] == - len(BTI_WH2500_REF_MAG) + - len(BTI_WH2500_REF_GRAD)) - dta['ch_names'] = ['WH2500'] * dta['hdr']['n_entries'] - dta['hdr']['n_anlg'] = 3 + fid.seek( + start_bytes + + ub["hdr"]["user_space_size"] + - dta["hdr"]["n_bytes"] * dta["hdr"]["n_entries"], + 0, + ) + + dta["hdr"]["n_dsp"] = dta["hdr"]["n_bytes"] // 4 - 2 + assert dta["hdr"]["n_dsp"] == len(BTI_WH2500_REF_MAG) + len( + BTI_WH2500_REF_GRAD + ) + dta["ch_names"] = ["WH2500"] * dta["hdr"]["n_entries"] + dta["hdr"]["n_anlg"] = 3 # These orders could be wrong, so don't set them # for now # dta['anlg_ch_names'] = BTI_WH2500_REF_MAG[:3] # dta['dsp_ch_names'] = (BTI_WH2500_REF_GRAD + # BTI_WH2500_REF_MAG) - dta['anlg_wts'] = np.zeros( - (dta['hdr']['n_entries'], dta['hdr']['n_anlg']), - dtype='i2') - dta['dsp_wts'] = np.zeros( - (dta['hdr']['n_entries'], dta['hdr']['n_dsp']), - dtype='f4') - for n in range(dta['hdr']['n_entries']): - dta['anlg_wts'][n] = read_int16_matrix( - fid, 1, dta['hdr']['n_anlg']) + dta["anlg_wts"] = np.zeros( + (dta["hdr"]["n_entries"], dta["hdr"]["n_anlg"]), dtype="i2" + ) + dta["dsp_wts"] = np.zeros( + (dta["hdr"]["n_entries"], dta["hdr"]["n_dsp"]), dtype="f4" + ) + for n in range(dta["hdr"]["n_entries"]): + dta["anlg_wts"][n] = read_int16_matrix( + fid, 1, dta["hdr"]["n_anlg"] + ) read_int16(fid) - dta['dsp_wts'][n] = read_float_matrix( - fid, 1, dta['hdr']['n_dsp']) + dta["dsp_wts"][n] = read_float_matrix( + fid, 1, 
dta["hdr"]["n_dsp"] + ) elif kind == BTI.UB_B_TRIG_MASK: - dta['version'] = read_int32(fid) - dta['entries'] = read_int32(fid) + dta["version"] = read_int32(fid) + dta["entries"] = read_int32(fid) fid.seek(16, 1) - dta['masks'] = [] - for entry in range(dta['entries']): - d = {'name': read_str(fid, 20), - 'nbits': read_uint16(fid), - 'shift': read_uint16(fid), - 'mask': read_uint32(fid)} - dta['masks'] += [d] + dta["masks"] = [] + for entry in range(dta["entries"]): + d = { + "name": read_str(fid, 20), + "nbits": read_uint16(fid), + "shift": read_uint16(fid), + "mask": read_uint32(fid), + } + dta["masks"] += [d] fid.seek(8, 1) else: - dta['unknown'] = {'hdr': read_char(fid, - ub['hdr']['user_space_size'])} + dta["unknown"] = {"hdr": read_char(fid, ub["hdr"]["user_space_size"])} n_read = fid.tell() - start_bytes - if n_read != ub['hdr']['user_space_size']: - raise RuntimeError('Internal MNE reading error, read size %d ' - '!= %d expected size for kind %s' - % (n_read, ub['hdr']['user_space_size'], - kind)) + if n_read != ub["hdr"]["user_space_size"]: + raise RuntimeError( + "Internal MNE reading error, read size %d " + "!= %d expected size for kind %s" + % (n_read, ub["hdr"]["user_space_size"], kind) + ) ub.update(dta) # finally update the userblock data _correct_offset(fid) # after reading. - cfg['chs'] = list() + cfg["chs"] = list() # prepare reading channels - for channel in range(cfg['hdr']['total_chans']): - ch = {'name': read_str(fid, 16), - 'chan_no': read_int16(fid), - 'ch_type': read_uint16(fid), - 'sensor_no': read_int16(fid), - 'data': dict()} + for channel in range(cfg["hdr"]["total_chans"]): + ch = { + "name": read_str(fid, 16), + "chan_no": read_int16(fid), + "ch_type": read_uint16(fid), + "sensor_no": read_int16(fid), + "data": dict(), + } fid.seek(2, 1) - ch.update({'gain': read_float(fid), - 'units_per_bit': read_float(fid), - 'yaxis_label': read_str(fid, 16), - 'aar_val': read_double(fid), - 'checksum': read_int32(fid), - 'reserved': read_str(fid, 32)}) - - cfg['chs'] += [ch] + ch.update( + { + "gain": read_float(fid), + "units_per_bit": read_float(fid), + "yaxis_label": read_str(fid, 16), + "aar_val": read_double(fid), + "checksum": read_int32(fid), + "reserved": read_str(fid, 32), + } + ) + + cfg["chs"] += [ch] _correct_offset(fid) # before and after dta = dict() - if ch['ch_type'] in [BTI.CHTYPE_MEG, BTI.CHTYPE_REFERENCE]: - dev = {'device_info': read_dev_header(fid), - 'inductance': read_float(fid), - 'padding': read_str(fid, 4), - 'transform': _correct_trans(read_transform(fid), False), - 'xform_flag': read_int16(fid), - 'total_loops': read_int16(fid)} + if ch["ch_type"] in [BTI.CHTYPE_MEG, BTI.CHTYPE_REFERENCE]: + dev = { + "device_info": read_dev_header(fid), + "inductance": read_float(fid), + "padding": read_str(fid, 4), + "transform": _correct_trans(read_transform(fid), False), + "xform_flag": read_int16(fid), + "total_loops": read_int16(fid), + } fid.seek(4, 1) - dev['reserved'] = read_str(fid, 32) - dta.update({'dev': dev, 'loops': []}) - for loop in range(dev['total_loops']): - d = {'position': read_double_matrix(fid, 1, 3), - 'orientation': read_double_matrix(fid, 1, 3), - 'radius': read_double(fid), - 'wire_radius': read_double(fid), - 'turns': read_int16(fid)} + dev["reserved"] = read_str(fid, 32) + dta.update({"dev": dev, "loops": []}) + for loop in range(dev["total_loops"]): + d = { + "position": read_double_matrix(fid, 1, 3), + "orientation": read_double_matrix(fid, 1, 3), + "radius": read_double(fid), + "wire_radius": read_double(fid), + "turns": 
read_int16(fid), + } fid.seek(2, 1) - d['checksum'] = read_int32(fid) - d['reserved'] = read_str(fid, 32) - dta['loops'] += [d] - - elif ch['ch_type'] == BTI.CHTYPE_EEG: - dta = {'device_info': read_dev_header(fid), - 'impedance': read_float(fid), - 'padding': read_str(fid, 4), - 'transform': read_transform(fid), - 'reserved': read_char(fid, 32)} - - elif ch['ch_type'] == BTI.CHTYPE_EXTERNAL: - dta = {'device_info': read_dev_header(fid), - 'user_space_size': read_int32(fid), - 'reserved': read_str(fid, 32)} - - elif ch['ch_type'] == BTI.CHTYPE_TRIGGER: - dta = {'device_info': read_dev_header(fid), - 'user_space_size': read_int32(fid)} + d["checksum"] = read_int32(fid) + d["reserved"] = read_str(fid, 32) + dta["loops"] += [d] + + elif ch["ch_type"] == BTI.CHTYPE_EEG: + dta = { + "device_info": read_dev_header(fid), + "impedance": read_float(fid), + "padding": read_str(fid, 4), + "transform": read_transform(fid), + "reserved": read_char(fid, 32), + } + + elif ch["ch_type"] == BTI.CHTYPE_EXTERNAL: + dta = { + "device_info": read_dev_header(fid), + "user_space_size": read_int32(fid), + "reserved": read_str(fid, 32), + } + + elif ch["ch_type"] == BTI.CHTYPE_TRIGGER: + dta = { + "device_info": read_dev_header(fid), + "user_space_size": read_int32(fid), + } fid.seek(2, 1) - dta['reserved'] = read_str(fid, 32) + dta["reserved"] = read_str(fid, 32) - elif ch['ch_type'] in [BTI.CHTYPE_UTILITY, BTI.CHTYPE_DERIVED]: - dta = {'device_info': read_dev_header(fid), - 'user_space_size': read_int32(fid), - 'reserved': read_str(fid, 32)} + elif ch["ch_type"] in [BTI.CHTYPE_UTILITY, BTI.CHTYPE_DERIVED]: + dta = { + "device_info": read_dev_header(fid), + "user_space_size": read_int32(fid), + "reserved": read_str(fid, 32), + } - elif ch['ch_type'] == BTI.CHTYPE_SHORTED: - dta = {'device_info': read_dev_header(fid), - 'reserved': read_str(fid, 32)} + elif ch["ch_type"] == BTI.CHTYPE_SHORTED: + dta = { + "device_info": read_dev_header(fid), + "reserved": read_str(fid, 32), + } ch.update(dta) # add data collected _correct_offset(fid) # after each reading @@ -551,13 +627,15 @@ def _read_config(fname): def _read_epoch(fid): """Read BTi PDF epoch.""" - out = {'pts_in_epoch': read_int32(fid), - 'epoch_duration': read_float(fid), - 'expected_iti': read_float(fid), - 'actual_iti': read_float(fid), - 'total_var_events': read_int32(fid), - 'checksum': read_int32(fid), - 'epoch_timestamp': read_int32(fid)} + out = { + "pts_in_epoch": read_int32(fid), + "epoch_duration": read_float(fid), + "expected_iti": read_float(fid), + "actual_iti": read_float(fid), + "total_var_events": read_int32(fid), + "checksum": read_int32(fid), + "epoch_timestamp": read_int32(fid), + } fid.seek(28, 1) @@ -566,20 +644,26 @@ def _read_epoch(fid): def _read_channel(fid): """Read BTi PDF channel.""" - out = {'chan_label': read_str(fid, 16), - 'chan_no': read_int16(fid), - 'attributes': read_int16(fid), - 'scale': read_float(fid), - 'yaxis_label': read_str(fid, 16), - 'valid_min_max': read_int16(fid)} + out = { + "chan_label": read_str(fid, 16), + "chan_no": read_int16(fid), + "attributes": read_int16(fid), + "scale": read_float(fid), + "yaxis_label": read_str(fid, 16), + "valid_min_max": read_int16(fid), + } fid.seek(6, 1) - out.update({'ymin': read_double(fid), - 'ymax': read_double(fid), - 'index': read_int32(fid), - 'checksum': read_int32(fid), - 'off_flag': read_str(fid, 4), - 'offset': read_float(fid)}) + out.update( + { + "ymin": read_double(fid), + "ymax": read_double(fid), + "index": read_int32(fid), + "checksum": read_int32(fid), + 
"off_flag": read_str(fid, 4), + "offset": read_float(fid), + } + ) fid.seek(24, 1) @@ -588,12 +672,14 @@ def _read_channel(fid): def _read_event(fid): """Read BTi PDF event.""" - out = {'event_name': read_str(fid, 16), - 'start_lat': read_float(fid), - 'end_lat': read_float(fid), - 'step_size': read_float(fid), - 'fixed_event': read_int16(fid), - 'checksum': read_int32(fid)} + out = { + "event_name": read_str(fid, 16), + "start_lat": read_float(fid), + "end_lat": read_float(fid), + "step_size": read_float(fid), + "fixed_event": read_int16(fid), + "checksum": read_int32(fid), + } fid.seek(32, 1) _correct_offset(fid) @@ -603,44 +689,48 @@ def _read_event(fid): def _read_process(fid): """Read BTi PDF process.""" - out = {'nbytes': read_int32(fid), - 'process_type': read_str(fid, 20), - 'checksum': read_int32(fid), - 'user': read_str(fid, 32), - 'timestamp': read_int32(fid), - 'filename': read_str(fid, 256), - 'total_steps': read_int32(fid)} + out = { + "nbytes": read_int32(fid), + "process_type": read_str(fid, 20), + "checksum": read_int32(fid), + "user": read_str(fid, 32), + "timestamp": read_int32(fid), + "filename": read_str(fid, 256), + "total_steps": read_int32(fid), + } fid.seek(32, 1) _correct_offset(fid) - out['processing_steps'] = list() - for step in range(out['total_steps']): - this_step = {'nbytes': read_int32(fid), - 'process_type': read_str(fid, 20), - 'checksum': read_int32(fid)} - ptype = this_step['process_type'] + out["processing_steps"] = list() + for step in range(out["total_steps"]): + this_step = { + "nbytes": read_int32(fid), + "process_type": read_str(fid, 20), + "checksum": read_int32(fid), + } + ptype = this_step["process_type"] if ptype == BTI.PROC_DEFAULTS: - this_step['scale_option'] = read_int32(fid) + this_step["scale_option"] = read_int32(fid) fid.seek(4, 1) - this_step['scale'] = read_double(fid) - this_step['dtype'] = read_int32(fid) - this_step['selected'] = read_int16(fid) - this_step['color_display'] = read_int16(fid) + this_step["scale"] = read_double(fid) + this_step["dtype"] = read_int32(fid) + this_step["selected"] = read_int16(fid) + this_step["color_display"] = read_int16(fid) fid.seek(32, 1) elif ptype in BTI.PROC_FILTER: - this_step['freq'] = read_float(fid) + this_step["freq"] = read_float(fid) fid.seek(32, 1) elif ptype in BTI.PROC_BPFILTER: - this_step['high_freq'] = read_float(fid) - this_step['low_freq'] = read_float(fid) + this_step["high_freq"] = read_float(fid) + this_step["low_freq"] = read_float(fid) else: - jump = this_step['user_space_size'] = read_int32(fid) + jump = this_step["user_space_size"] = read_int32(fid) fid.seek(32, 1) fid.seek(jump, 1) - out['processing_steps'] += [this_step] + out["processing_steps"] += [this_step] _correct_offset(fid) return out @@ -648,30 +738,32 @@ def _read_process(fid): def _read_assoc_file(fid): """Read BTi PDF assocfile.""" - out = {'file_id': read_int16(fid), - 'length': read_int16(fid)} + out = {"file_id": read_int16(fid), "length": read_int16(fid)} fid.seek(32, 1) - out['checksum'] = read_int32(fid) + out["checksum"] = read_int32(fid) return out def _read_pfid_ed(fid): """Read PDF ed file.""" - out = {'comment_size': read_int32(fid), - 'name': read_str(fid, 17)} + out = {"comment_size": read_int32(fid), "name": read_str(fid, 17)} fid.seek(9, 1) - out.update({'pdf_number': read_int16(fid), - 'total_events': read_int32(fid), - 'timestamp': read_int32(fid), - 'flags': read_int32(fid), - 'de_process': read_int32(fid), - 'checksum': read_int32(fid), - 'ed_id': read_int32(fid), - 'win_width': 
read_float(fid), - 'win_offset': read_float(fid)}) + out.update( + { + "pdf_number": read_int16(fid), + "total_events": read_int32(fid), + "timestamp": read_int32(fid), + "flags": read_int32(fid), + "de_process": read_int32(fid), + "checksum": read_int32(fid), + "ed_id": read_int32(fid), + "win_width": read_float(fid), + "win_offset": read_float(fid), + } + ) fid.seek(8, 1) @@ -680,61 +772,74 @@ def _read_pfid_ed(fid): def _read_coil_def(fid): """Read coil definition.""" - coildef = {'position': read_double_matrix(fid, 1, 3), - 'orientation': read_double_matrix(fid, 1, 3), - 'radius': read_double(fid), - 'wire_radius': read_double(fid), - 'turns': read_int16(fid)} + coildef = { + "position": read_double_matrix(fid, 1, 3), + "orientation": read_double_matrix(fid, 1, 3), + "radius": read_double(fid), + "wire_radius": read_double(fid), + "turns": read_int16(fid), + } fid.seek(fid, 2, 1) - coildef['checksum'] = read_int32(fid) - coildef['reserved'] = read_str(fid, 32) + coildef["checksum"] = read_int32(fid) + coildef["reserved"] = read_str(fid, 32) def _read_ch_config(fid): """Read BTi channel config.""" - cfg = {'name': read_str(fid, BTI.FILE_CONF_CH_NAME), - 'chan_no': read_int16(fid), - 'ch_type': read_uint16(fid), - 'sensor_no': read_int16(fid)} + cfg = { + "name": read_str(fid, BTI.FILE_CONF_CH_NAME), + "chan_no": read_int16(fid), + "ch_type": read_uint16(fid), + "sensor_no": read_int16(fid), + } fid.seek(fid, BTI.FILE_CONF_CH_NEXT, 1) - cfg.update({'gain': read_float(fid), - 'units_per_bit': read_float(fid), - 'yaxis_label': read_str(fid, BTI.FILE_CONF_CH_YLABEL), - 'aar_val': read_double(fid), - 'checksum': read_int32(fid), - 'reserved': read_str(fid, BTI.FILE_CONF_CH_RESERVED)}) + cfg.update( + { + "gain": read_float(fid), + "units_per_bit": read_float(fid), + "yaxis_label": read_str(fid, BTI.FILE_CONF_CH_YLABEL), + "aar_val": read_double(fid), + "checksum": read_int32(fid), + "reserved": read_str(fid, BTI.FILE_CONF_CH_RESERVED), + } + ) _correct_offset(fid) # Then the channel info - ch_type, chan = cfg['ch_type'], dict() - chan['dev'] = {'size': read_int32(fid), - 'checksum': read_int32(fid), - 'reserved': read_str(fid, 32)} + ch_type, chan = cfg["ch_type"], dict() + chan["dev"] = { + "size": read_int32(fid), + "checksum": read_int32(fid), + "reserved": read_str(fid, 32), + } if ch_type in [BTI.CHTYPE_MEG, BTI.CHTYPE_REF]: - chan['loops'] = [_read_coil_def(fid) for d in - range(chan['dev']['total_loops'])] + chan["loops"] = [_read_coil_def(fid) for d in range(chan["dev"]["total_loops"])] elif ch_type == BTI.CHTYPE_EEG: - chan['impedance'] = read_float(fid) - chan['padding'] = read_str(fid, BTI.FILE_CONF_CH_PADDING) - chan['transform'] = read_transform(fid) - chan['reserved'] = read_char(fid, BTI.FILE_CONF_CH_RESERVED) - - elif ch_type in [BTI.CHTYPE_TRIGGER, BTI.CHTYPE_EXTERNAL, - BTI.CHTYPE_UTILITY, BTI.CHTYPE_DERIVED]: - chan['user_space_size'] = read_int32(fid) + chan["impedance"] = read_float(fid) + chan["padding"] = read_str(fid, BTI.FILE_CONF_CH_PADDING) + chan["transform"] = read_transform(fid) + chan["reserved"] = read_char(fid, BTI.FILE_CONF_CH_RESERVED) + + elif ch_type in [ + BTI.CHTYPE_TRIGGER, + BTI.CHTYPE_EXTERNAL, + BTI.CHTYPE_UTILITY, + BTI.CHTYPE_DERIVED, + ]: + chan["user_space_size"] = read_int32(fid) if ch_type == BTI.CHTYPE_TRIGGER: fid.seek(2, 1) - chan['reserved'] = read_str(fid, BTI.FILE_CONF_CH_RESERVED) + chan["reserved"] = read_str(fid, BTI.FILE_CONF_CH_RESERVED) elif ch_type == BTI.CHTYPE_SHORTED: - chan['reserved'] = read_str(fid, 
BTI.FILE_CONF_CH_RESERVED) + chan["reserved"] = read_str(fid, BTI.FILE_CONF_CH_RESERVED) - cfg['chan'] = chan + cfg["chan"] = chan _correct_offset(fid) @@ -743,79 +848,88 @@ def _read_ch_config(fid): def _read_bti_header_pdf(pdf_fname): """Read header from pdf file.""" - with _bti_open(pdf_fname, 'rb') as fid: + with _bti_open(pdf_fname, "rb") as fid: fid.seek(-8, 2) start = fid.tell() header_position = read_int64(fid) check_value = header_position & BTI.FILE_MASK - if ((start + BTI.FILE_CURPOS - check_value) <= BTI.FILE_MASK): + if (start + BTI.FILE_CURPOS - check_value) <= BTI.FILE_MASK: header_position = check_value # Check header position for alignment issues - if ((header_position % 8) != 0): - header_position += (8 - (header_position % 8)) + if (header_position % 8) != 0: + header_position += 8 - (header_position % 8) fid.seek(header_position, 0) # actual header starts here - info = {'version': read_int16(fid), - 'file_type': read_str(fid, 5), - 'hdr_size': start - header_position, # add for convenience - 'start': start} + info = { + "version": read_int16(fid), + "file_type": read_str(fid, 5), + "hdr_size": start - header_position, # add for convenience + "start": start, + } fid.seek(1, 1) - info.update({'data_format': read_int16(fid), - 'acq_mode': read_int16(fid), - 'total_epochs': read_int32(fid), - 'input_epochs': read_int32(fid), - 'total_events': read_int32(fid), - 'total_fixed_events': read_int32(fid), - 'sample_period': read_float(fid), - 'xaxis_label': read_str(fid, 16), - 'total_processes': read_int32(fid), - 'total_chans': read_int16(fid)}) + info.update( + { + "data_format": read_int16(fid), + "acq_mode": read_int16(fid), + "total_epochs": read_int32(fid), + "input_epochs": read_int32(fid), + "total_events": read_int32(fid), + "total_fixed_events": read_int32(fid), + "sample_period": read_float(fid), + "xaxis_label": read_str(fid, 16), + "total_processes": read_int32(fid), + "total_chans": read_int16(fid), + } + ) fid.seek(2, 1) - info.update({'checksum': read_int32(fid), - 'total_ed_classes': read_int32(fid), - 'total_associated_files': read_int16(fid), - 'last_file_index': read_int16(fid), - 'timestamp': read_int32(fid)}) + info.update( + { + "checksum": read_int32(fid), + "total_ed_classes": read_int32(fid), + "total_associated_files": read_int16(fid), + "last_file_index": read_int16(fid), + "timestamp": read_int32(fid), + } + ) fid.seek(20, 1) _correct_offset(fid) # actual header ends here, so dar seems ok. 
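
# Sketch of the header-lookup trick used above (hypothetical file layout,
# standard library + NumPy only): a BTi PDF stores the header *offset* in
# the last 8 bytes of the file, which is masked against BTI.FILE_MASK and
# then rounded up to an 8-byte boundary before seeking.
import numpy as np

FILE_MASK = 2147483647  # same value as BTI.FILE_MASK (2 ** 31 - 1)
FILE_CURPOS = 8

def locate_header(raw_bytes):
    """Return the aligned header position encoded at the end of the file."""
    start = len(raw_bytes) - 8
    header_position = int(np.frombuffer(raw_bytes[-8:], ">u8")[0])
    check_value = header_position & FILE_MASK
    if (start + FILE_CURPOS - check_value) <= FILE_MASK:
        header_position = check_value
    if header_position % 8 != 0:  # align to the next 8-byte boundary
        header_position += 8 - (header_position % 8)
    return header_position

# e.g. a 100-byte file whose trailer says the header starts at byte 41:
blob = bytes(92) + np.array([41], ">u8").tobytes()
assert locate_header(blob) == 48  # 41 passes the mask check, aligned up to 48
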
- info['epochs'] = [_read_epoch(fid) for epoch in - range(info['total_epochs'])] + info["epochs"] = [_read_epoch(fid) for epoch in range(info["total_epochs"])] - info['chs'] = [_read_channel(fid) for ch in - range(info['total_chans'])] + info["chs"] = [_read_channel(fid) for ch in range(info["total_chans"])] - info['events'] = [_read_event(fid) for event in - range(info['total_events'])] + info["events"] = [_read_event(fid) for event in range(info["total_events"])] - info['processes'] = [_read_process(fid) for process in - range(info['total_processes'])] + info["processes"] = [ + _read_process(fid) for process in range(info["total_processes"]) + ] - info['assocfiles'] = [_read_assoc_file(fid) for af in - range(info['total_associated_files'])] + info["assocfiles"] = [ + _read_assoc_file(fid) for af in range(info["total_associated_files"]) + ] - info['edclasses'] = [_read_pfid_ed(fid) for ed_class in - range(info['total_ed_classes'])] + info["edclasses"] = [ + _read_pfid_ed(fid) for ed_class in range(info["total_ed_classes"]) + ] - info['extra_data'] = fid.read(start - fid.tell()) - info['pdf_fname'] = pdf_fname + info["extra_data"] = fid.read(start - fid.tell()) + info["pdf_fname"] = pdf_fname - info['total_slices'] = sum(e['pts_in_epoch'] for e in - info['epochs']) + info["total_slices"] = sum(e["pts_in_epoch"] for e in info["epochs"]) - info['dtype'] = DTYPES[info['data_format']] - bps = info['dtype'].itemsize * info['total_chans'] - info['bytes_per_slice'] = bps + info["dtype"] = DTYPES[info["data_format"]] + bps = info["dtype"].itemsize * info["total_chans"] + info["bytes_per_slice"] = bps return info @@ -823,71 +937,74 @@ def _read_bti_header(pdf_fname, config_fname, sort_by_ch_name=True): """Read bti PDF header.""" info = _read_bti_header_pdf(pdf_fname) if pdf_fname is not None else dict() cfg = _read_config(config_fname) - info['bti_transform'] = cfg['transforms'] + info["bti_transform"] = cfg["transforms"] # augment channel list by according info from config. # get channels from config present in PDF - chans = info.get('chs', None) + chans = info.get("chs", None) if chans is not None: - chans_cfg = [c for c in cfg['chs'] if c['chan_no'] - in [c_['chan_no'] for c_ in chans]] + chans_cfg = [ + c for c in cfg["chs"] if c["chan_no"] in [c_["chan_no"] for c_ in chans] + ] # sort chans_cfg and chans - chans = sorted(chans, key=lambda k: k['chan_no']) - chans_cfg = sorted(chans_cfg, key=lambda k: k['chan_no']) + chans = sorted(chans, key=lambda k: k["chan_no"]) + chans_cfg = sorted(chans_cfg, key=lambda k: k["chan_no"]) # check all pdf channels are present in config - match = [c['chan_no'] for c in chans_cfg] == \ - [c['chan_no'] for c in chans] + match = [c["chan_no"] for c in chans_cfg] == [c["chan_no"] for c in chans] if not match: - raise RuntimeError('Could not match raw data channels with' - ' config channels. Some of the channels' - ' found are not described in config.') + raise RuntimeError( + "Could not match raw data channels with" + " config channels. Some of the channels" + " found are not described in config." 
+ ) else: - chans_cfg = cfg['chs'] + chans_cfg = cfg["chs"] chans = [dict() for _ in chans_cfg] # transfer channel info from config to channel info for ch, ch_cfg in zip(chans, chans_cfg): - ch['upb'] = ch_cfg['units_per_bit'] - ch['gain'] = ch_cfg['gain'] - ch['name'] = ch_cfg['name'] - if ch_cfg.get('dev', dict()).get('transform', None) is not None: - ch['loc'] = _coil_trans_to_loc(ch_cfg['dev']['transform']) + ch["upb"] = ch_cfg["units_per_bit"] + ch["gain"] = ch_cfg["gain"] + ch["name"] = ch_cfg["name"] + if ch_cfg.get("dev", dict()).get("transform", None) is not None: + ch["loc"] = _coil_trans_to_loc(ch_cfg["dev"]["transform"]) else: - ch['loc'] = np.full(12, np.nan) + ch["loc"] = np.full(12, np.nan) if pdf_fname is not None: - if info['data_format'] <= 2: # see DTYPES, implies integer - ch['cal'] = ch['scale'] * ch['upb'] / float(ch['gain']) + if info["data_format"] <= 2: # see DTYPES, implies integer + ch["cal"] = ch["scale"] * ch["upb"] / float(ch["gain"]) else: # float - ch['cal'] = ch['scale'] * ch['gain'] + ch["cal"] = ch["scale"] * ch["gain"] else: # if we are in this mode we don't read data, only channel info. - ch['cal'] = ch['scale'] = 1.0 # so we put a trivial default value + ch["cal"] = ch["scale"] = 1.0 # so we put a trivial default value if sort_by_ch_name: - by_index = [(i, d['index']) for i, d in enumerate(chans)] + by_index = [(i, d["index"]) for i, d in enumerate(chans)] by_index.sort(key=lambda c: c[1]) by_index = [idx[0] for idx in by_index] chs = [chans[pos] for pos in by_index] - sort_by_name_idx = [(i, d['name']) for i, d in enumerate(chs)] - a_chs = [c for c in sort_by_name_idx if c[1].startswith('A')] - other_chs = [c for c in sort_by_name_idx if not c[1].startswith('A')] - sort_by_name_idx = sorted( - a_chs, key=lambda c: int(c[1][1:])) + sorted(other_chs) + sort_by_name_idx = [(i, d["name"]) for i, d in enumerate(chs)] + a_chs = [c for c in sort_by_name_idx if c[1].startswith("A")] + other_chs = [c for c in sort_by_name_idx if not c[1].startswith("A")] + sort_by_name_idx = sorted(a_chs, key=lambda c: int(c[1][1:])) + sorted( + other_chs + ) sort_by_name_idx = [idx[0] for idx in sort_by_name_idx] - info['chs'] = [chans[pos] for pos in sort_by_name_idx] - info['order'] = sort_by_name_idx + info["chs"] = [chans[pos] for pos in sort_by_name_idx] + info["order"] = sort_by_name_idx else: - info['chs'] = chans - info['order'] = np.arange(len(chans)) + info["chs"] = chans + info["order"] = np.arange(len(chans)) # finally add some important fields from the config - info['e_table'] = cfg['user_blocks'][BTI.UB_B_E_TABLE_USED] - info['weights'] = cfg['user_blocks'][BTI.UB_B_WEIGHTS_USED] + info["e_table"] = cfg["user_blocks"][BTI.UB_B_E_TABLE_USED] + info["weights"] = cfg["user_blocks"][BTI.UB_B_WEIGHTS_USED] return info @@ -896,11 +1013,11 @@ def _correct_trans(t, check=True): """Convert to a transformation matrix.""" t = np.array(t, np.float64) t[:3, :3] *= t[3, :3][:, np.newaxis] # apply scalings - t[3, :3] = 0. # remove them + t[3, :3] = 0.0 # remove them if check: - assert t[3, 3] == 1. + assert t[3, 3] == 1.0 else: - t[3, 3] = 1. 
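
# Worked sketch of _correct_trans (NumPy only, made-up numbers): BTi keeps
# per-axis scalings in the fourth row of the 4 x 4 matrix; they are folded
# into the rotation block row-wise and the row is then reset so the matrix
# is a proper homogeneous transform.
import numpy as np

t = np.eye(4)
t[:3, :3] = np.eye(3)                 # rotation part
t[3, :3] = [2.0, 2.0, 2.0]            # BTi-style row of scalings
t[:3, :3] *= t[3, :3][:, np.newaxis]  # fold scalings into the rotation
t[3, :3] = 0.0                        # restore the homogeneous row
assert np.allclose(t[:3, :3], 2.0 * np.eye(3)) and t[3, 3] == 1.0
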
+ t[3, 3] = 1.0 return t @@ -942,58 +1059,77 @@ class RawBTi(BaseRaw): """ @verbose - def __init__(self, pdf_fname, config_fname='config', - head_shape_fname='hs_file', rotation_x=0., - translation=(0.0, 0.02, 0.11), convert=True, - rename_channels=True, sort_by_ch_name=True, - ecg_ch='E31', eog_ch=('E63', 'E64'), - preload=False, verbose=None): # noqa: D102 + def __init__( + self, + pdf_fname, + config_fname="config", + head_shape_fname="hs_file", + rotation_x=0.0, + translation=(0.0, 0.02, 0.11), + convert=True, + rename_channels=True, + sort_by_ch_name=True, + ecg_ch="E31", + eog_ch=("E63", "E64"), + preload=False, + verbose=None, + ): # noqa: D102 info, bti_info = _get_bti_info( - pdf_fname=pdf_fname, config_fname=config_fname, - head_shape_fname=head_shape_fname, rotation_x=rotation_x, - translation=translation, convert=convert, ecg_ch=ecg_ch, + pdf_fname=pdf_fname, + config_fname=config_fname, + head_shape_fname=head_shape_fname, + rotation_x=rotation_x, + translation=translation, + convert=convert, + ecg_ch=ecg_ch, rename_channels=rename_channels, - sort_by_ch_name=sort_by_ch_name, eog_ch=eog_ch) - self.bti_ch_labels = [c['chan_label'] for c in bti_info['chs']] + sort_by_ch_name=sort_by_ch_name, + eog_ch=eog_ch, + ) + self.bti_ch_labels = [c["chan_label"] for c in bti_info["chs"]] # make Raw repr work if we have a BytesIO as input if isinstance(pdf_fname, BytesIO): pdf_fname = repr(pdf_fname) super(RawBTi, self).__init__( - info, preload, filenames=[pdf_fname], raw_extras=[bti_info], - last_samps=[bti_info['total_slices'] - 1], verbose=verbose) + info, + preload, + filenames=[pdf_fname], + raw_extras=[bti_info], + last_samps=[bti_info["total_slices"] - 1], + verbose=verbose, + ) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a segment of data from a file.""" bti_info = self._raw_extras[fi] - fname = bti_info['pdf_fname'] - dtype = bti_info['dtype'] - assert len(bti_info['chs']) == self._raw_extras[fi]['orig_nchan'] - n_channels = len(bti_info['chs']) + fname = bti_info["pdf_fname"] + dtype = bti_info["dtype"] + assert len(bti_info["chs"]) == self._raw_extras[fi]["orig_nchan"] + n_channels = len(bti_info["chs"]) n_bytes = np.dtype(dtype).itemsize data_left = (stop - start) * n_channels - read_cals = np.empty((bti_info['total_chans'],)) - for ch in bti_info['chs']: - read_cals[ch['index']] = ch['cal'] + read_cals = np.empty((bti_info["total_chans"],)) + for ch in bti_info["chs"]: + read_cals[ch["index"]] = ch["cal"] block_size = ((int(100e6) // n_bytes) // n_channels) * n_channels block_size = min(data_left, block_size) # extract data in chunks - with _bti_open(fname, 'rb') as fid: - fid.seek(bti_info['bytes_per_slice'] * start, 0) - for sample_start in np.arange(0, data_left, - block_size) // n_channels: + with _bti_open(fname, "rb") as fid: + fid.seek(bti_info["bytes_per_slice"] * start, 0) + for sample_start in np.arange(0, data_left, block_size) // n_channels: count = min(block_size, data_left - sample_start * n_channels) if isinstance(fid, BytesIO): block = np.frombuffer(fid.getvalue(), dtype, count) else: block = np.fromfile(fid, dtype, count) sample_stop = sample_start + count // n_channels - shape = (sample_stop - sample_start, bti_info['total_chans']) + shape = (sample_stop - sample_start, bti_info["total_chans"]) block.shape = shape data_view = data[:, sample_start:sample_stop] one = np.empty(block.shape[::-1]) - for ii, b_i_o in enumerate(bti_info['order']): + for ii, b_i_o in enumerate(bti_info["order"]): one[ii] = block[:, b_i_o] * 
read_cals[b_i_o] _mult_cal_one(data_view, one, idx, cals, mult) @@ -1001,47 +1137,64 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): @functools.lru_cache(1) def _1020_names(): from mne.channels import make_standard_montage - return set(ch_name.lower() - for ch_name in make_standard_montage('standard_1005').ch_names) + + return set( + ch_name.lower() for ch_name in make_standard_montage("standard_1005").ch_names + ) def _eeg_like(ch_name): # Some bti recordigs look like "F4-POz", so let's at least mark them # as EEG - if ch_name.count('-') != 1: + if ch_name.count("-") != 1: return - ch, ref = ch_name.split('-') + ch, ref = ch_name.split("-") eeg_names = _1020_names() return ch.lower() in eeg_names and ref.lower() in eeg_names def _make_bti_digitization( - info, head_shape_fname, convert, use_hpi, bti_dev_t, dev_ctf_t): + info, head_shape_fname, convert, use_hpi, bti_dev_t, dev_ctf_t +): with info._unlock(): if head_shape_fname: - logger.info('... Reading digitization points from %s' % - head_shape_fname) - - nasion, lpa, rpa, hpi, dig_points = _read_head_shape( - head_shape_fname) - info['dig'], dev_head_t, ctf_head_t = _make_bti_dig_points( - nasion, lpa, rpa, hpi, dig_points, - convert, use_hpi, bti_dev_t, dev_ctf_t) + logger.info("... Reading digitization points from %s" % head_shape_fname) + + nasion, lpa, rpa, hpi, dig_points = _read_head_shape(head_shape_fname) + info["dig"], dev_head_t, ctf_head_t = _make_bti_dig_points( + nasion, + lpa, + rpa, + hpi, + dig_points, + convert, + use_hpi, + bti_dev_t, + dev_ctf_t, + ) else: - logger.info('... no headshape file supplied, doing nothing.') - info['dig'] = None - dev_head_t = Transform('meg', 'head', trans=None) - ctf_head_t = Transform('ctf_head', 'head', trans=None) + logger.info("... no headshape file supplied, doing nothing.") + info["dig"] = None + dev_head_t = Transform("meg", "head", trans=None) + ctf_head_t = Transform("ctf_head", "head", trans=None) - info.update(dev_head_t=dev_head_t, dev_ctf_t=dev_ctf_t, - ctf_head_t=ctf_head_t) + info.update(dev_head_t=dev_head_t, dev_ctf_t=dev_ctf_t, ctf_head_t=ctf_head_t) return info -def _get_bti_info(pdf_fname, config_fname, head_shape_fname, rotation_x, - translation, convert, ecg_ch, eog_ch, rename_channels=True, - sort_by_ch_name=True): +def _get_bti_info( + pdf_fname, + config_fname, + head_shape_fname, + rotation_x, + translation, + convert, + ecg_ch, + eog_ch, + rename_channels=True, + sort_by_ch_name=True, +): """Read BTI info. Note. This helper supports partial construction of infos when `pdf_fname` @@ -1058,170 +1211,175 @@ def _get_bti_info(pdf_fname, config_fname, head_shape_fname, rotation_x, """ if pdf_fname is None: - logger.info('No pdf_fname passed, trying to construct partial info ' - 'from config') + logger.info( + "No pdf_fname passed, trying to construct partial info " "from config" + ) if pdf_fname is not None and not isinstance(pdf_fname, BytesIO): if not op.isabs(pdf_fname): pdf_fname = op.abspath(pdf_fname) if not isinstance(config_fname, BytesIO): if not op.isabs(config_fname): - config_tries = [op.abspath(config_fname), - op.abspath(op.join(op.dirname(pdf_fname), - config_fname))] + config_tries = [ + op.abspath(config_fname), + op.abspath(op.join(op.dirname(pdf_fname), config_fname)), + ] for config_try in config_tries: if op.isfile(config_try): config_fname = config_try break if not op.isfile(config_fname): - raise ValueError('Could not find the config file %s. 
Please check' - ' whether you are in the right directory ' - 'or pass the full name' % config_fname) + raise ValueError( + "Could not find the config file %s. Please check" + " whether you are in the right directory " + "or pass the full name" % config_fname + ) - if head_shape_fname is not None and not isinstance( - head_shape_fname, BytesIO): + if head_shape_fname is not None and not isinstance(head_shape_fname, BytesIO): orig_name = head_shape_fname if not op.isfile(head_shape_fname): - head_shape_fname = op.join(op.dirname(pdf_fname), - head_shape_fname) + head_shape_fname = op.join(op.dirname(pdf_fname), head_shape_fname) if not op.isfile(head_shape_fname): - raise ValueError('Could not find the head_shape file "%s". ' - 'You should check whether you are in the ' - 'right directory, pass the full file name, ' - 'or pass head_shape_fname=None.' - % orig_name) - - logger.info('Reading 4D PDF file %s...' % pdf_fname) + raise ValueError( + 'Could not find the head_shape file "%s". ' + "You should check whether you are in the " + "right directory, pass the full file name, " + "or pass head_shape_fname=None." % orig_name + ) + + logger.info("Reading 4D PDF file %s..." % pdf_fname) bti_info = _read_bti_header( - pdf_fname, config_fname, sort_by_ch_name=sort_by_ch_name) + pdf_fname, config_fname, sort_by_ch_name=sort_by_ch_name + ) - dev_ctf_t = Transform('ctf_meg', 'ctf_head', - _correct_trans(bti_info['bti_transform'][0])) + dev_ctf_t = Transform( + "ctf_meg", "ctf_head", _correct_trans(bti_info["bti_transform"][0]) + ) _check_nan_dev_head_t(dev_ctf_t) # for old backward compatibility and external processing - rotation_x = 0. if rotation_x is None else rotation_x + rotation_x = 0.0 if rotation_x is None else rotation_x bti_dev_t = _get_bti_dev_t(rotation_x, translation) if convert else None - bti_dev_t = Transform('ctf_meg', 'meg', bti_dev_t) + bti_dev_t = Transform("ctf_meg", "meg", bti_dev_t) use_hpi = False # hard coded, but marked as later option. - logger.info('Creating Neuromag info structure ...') - if 'sample_period' in bti_info.keys(): - sfreq = 1. / bti_info['sample_period'] + logger.info("Creating Neuromag info structure ...") + if "sample_period" in bti_info.keys(): + sfreq = 1.0 / bti_info["sample_period"] else: sfreq = None if pdf_fname is not None: info = _empty_info(sfreq) - date = bti_info['processes'][0]['timestamp'] - info['meas_date'] = _stamp_to_dt((date, 0)) + date = bti_info["processes"][0]["timestamp"] + info["meas_date"] = _stamp_to_dt((date, 0)) else: # these cannot be guessed from config, see docstring info = _empty_info(1.0) - info['sfreq'] = None - info['lowpass'] = None - info['highpass'] = None - info['meas_date'] = None - bti_info['processes'] = list() + info["sfreq"] = None + info["lowpass"] = None + info["highpass"] = None + info["meas_date"] = None + bti_info["processes"] = list() # browse processing info for filter specs. 
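
# Minimal sketch of the filter-spec scan that follows (pure Python, with a
# fabricated process list): "filt" steps carry either a band, keyed by
# "high_freq"/"low_freq", or a single corner frequency whose role is read
# from the "hp"/"lp" substring of the step's process_type.
processes = [  # hypothetical BTi processing chain
    {"process_type": "b_filt_hp", "processing_steps": [
        {"process_type": "b_filt_hp", "freq": 0.1}]},
    {"process_type": "b_filt_lp", "processing_steps": [
        {"process_type": "b_filt_lp", "freq": 100.0}]},
]

hp, lp = None, None
for proc in processes:
    if "filt" in proc["process_type"]:
        for step in proc["processing_steps"]:
            if "high_freq" in step:
                hp, lp = step["high_freq"], step["low_freq"]
            elif "hp" in step["process_type"]:
                hp = step["freq"]
            elif "lp" in step["process_type"]:
                lp = step["freq"]
assert (hp, lp) == (0.1, 100.0)
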
- hp, lp = info['highpass'], info['lowpass'] - for proc in bti_info['processes']: - if 'filt' in proc['process_type']: - for step in proc['processing_steps']: - if 'high_freq' in step: - hp, lp = step['high_freq'], step['low_freq'] - elif 'hp' in step['process_type']: - hp = step['freq'] - elif 'lp' in step['process_type']: - lp = step['freq'] - - info['highpass'] = hp - info['lowpass'] = lp + hp, lp = info["highpass"], info["lowpass"] + for proc in bti_info["processes"]: + if "filt" in proc["process_type"]: + for step in proc["processing_steps"]: + if "high_freq" in step: + hp, lp = step["high_freq"], step["low_freq"] + elif "hp" in step["process_type"]: + hp = step["freq"] + elif "lp" in step["process_type"]: + lp = step["freq"] + + info["highpass"] = hp + info["lowpass"] = lp chs = [] # Note that 'name' and 'chan_label' are not the same. # We want the configured label if out IO parsed it # except for the MEG channels for which we keep the config name bti_ch_names = list() - for ch in bti_info['chs']: + for ch in bti_info["chs"]: # we have always relied on 'A' as indicator of MEG data channels. - ch_name = ch['name'] - if not ch_name.startswith('A'): - ch_name = ch.get('chan_label', ch_name) + ch_name = ch["name"] + if not ch_name.startswith("A"): + ch_name = ch.get("chan_label", ch_name) bti_ch_names.append(ch_name) - neuromag_ch_names = _rename_channels( - bti_ch_names, ecg_ch=ecg_ch, eog_ch=eog_ch) + neuromag_ch_names = _rename_channels(bti_ch_names, ecg_ch=ecg_ch, eog_ch=eog_ch) ch_mapping = zip(bti_ch_names, neuromag_ch_names) - logger.info('... Setting channel info structure.') + logger.info("... Setting channel info structure.") for idx, (chan_4d, chan_neuromag) in enumerate(ch_mapping): chan_info = _instantiate_default_info_chs() - chan_info['ch_name'] = chan_neuromag if rename_channels else chan_4d - chan_info['logno'] = idx + BTI.FIFF_LOGNO - chan_info['scanno'] = idx + 1 - chan_info['cal'] = float(bti_info['chs'][idx]['scale']) + chan_info["ch_name"] = chan_neuromag if rename_channels else chan_4d + chan_info["logno"] = idx + BTI.FIFF_LOGNO + chan_info["scanno"] = idx + 1 + chan_info["cal"] = float(bti_info["chs"][idx]["scale"]) - if any(chan_4d.startswith(k) for k in ('A', 'M', 'G')): - loc = bti_info['chs'][idx]['loc'] + if any(chan_4d.startswith(k) for k in ("A", "M", "G")): + loc = bti_info["chs"][idx]["loc"] if loc is not None: if convert: if idx == 0: - logger.info('... putting coil transforms in Neuromag ' - 'coordinates') - t = _loc_to_coil_trans(bti_info['chs'][idx]['loc']) + logger.info( + "... 
putting coil transforms in Neuromag " "coordinates" + ) + t = _loc_to_coil_trans(bti_info["chs"][idx]["loc"]) t = _convert_coil_trans(t, dev_ctf_t, bti_dev_t) loc = _coil_trans_to_loc(t) - chan_info['loc'] = loc + chan_info["loc"] = loc # BTI sensors are natively stored in 4D head coords we believe - meg_frame = (FIFF.FIFFV_COORD_DEVICE if convert else - FIFF.FIFFV_MNE_COORD_4D_HEAD) - eeg_frame = (FIFF.FIFFV_COORD_HEAD if convert else - FIFF.FIFFV_MNE_COORD_4D_HEAD) - if chan_4d.startswith('A'): - chan_info['kind'] = FIFF.FIFFV_MEG_CH - chan_info['coil_type'] = FIFF.FIFFV_COIL_MAGNES_MAG - chan_info['coord_frame'] = meg_frame - chan_info['unit'] = FIFF.FIFF_UNIT_T - - elif chan_4d.startswith('M'): - chan_info['kind'] = FIFF.FIFFV_REF_MEG_CH - chan_info['coil_type'] = FIFF.FIFFV_COIL_MAGNES_REF_MAG - chan_info['coord_frame'] = meg_frame - chan_info['unit'] = FIFF.FIFF_UNIT_T - - elif chan_4d.startswith('G'): - chan_info['kind'] = FIFF.FIFFV_REF_MEG_CH - chan_info['coord_frame'] = meg_frame - chan_info['unit'] = FIFF.FIFF_UNIT_T_M - if chan_4d in ('GxxA', 'GyyA'): - chan_info['coil_type'] = FIFF.FIFFV_COIL_MAGNES_REF_GRAD - elif chan_4d in ('GyxA', 'GzxA', 'GzyA'): - chan_info['coil_type'] = \ - FIFF.FIFFV_COIL_MAGNES_OFFDIAG_REF_GRAD - - elif chan_4d.startswith('EEG') or _eeg_like(chan_4d): - chan_info['kind'] = FIFF.FIFFV_EEG_CH - chan_info['coil_type'] = FIFF.FIFFV_COIL_EEG - chan_info['coord_frame'] = eeg_frame - chan_info['unit'] = FIFF.FIFF_UNIT_V + meg_frame = FIFF.FIFFV_COORD_DEVICE if convert else FIFF.FIFFV_MNE_COORD_4D_HEAD + eeg_frame = FIFF.FIFFV_COORD_HEAD if convert else FIFF.FIFFV_MNE_COORD_4D_HEAD + if chan_4d.startswith("A"): + chan_info["kind"] = FIFF.FIFFV_MEG_CH + chan_info["coil_type"] = FIFF.FIFFV_COIL_MAGNES_MAG + chan_info["coord_frame"] = meg_frame + chan_info["unit"] = FIFF.FIFF_UNIT_T + + elif chan_4d.startswith("M"): + chan_info["kind"] = FIFF.FIFFV_REF_MEG_CH + chan_info["coil_type"] = FIFF.FIFFV_COIL_MAGNES_REF_MAG + chan_info["coord_frame"] = meg_frame + chan_info["unit"] = FIFF.FIFF_UNIT_T + + elif chan_4d.startswith("G"): + chan_info["kind"] = FIFF.FIFFV_REF_MEG_CH + chan_info["coord_frame"] = meg_frame + chan_info["unit"] = FIFF.FIFF_UNIT_T_M + if chan_4d in ("GxxA", "GyyA"): + chan_info["coil_type"] = FIFF.FIFFV_COIL_MAGNES_REF_GRAD + elif chan_4d in ("GyxA", "GzxA", "GzyA"): + chan_info["coil_type"] = FIFF.FIFFV_COIL_MAGNES_OFFDIAG_REF_GRAD + + elif chan_4d.startswith("EEG") or _eeg_like(chan_4d): + chan_info["kind"] = FIFF.FIFFV_EEG_CH + chan_info["coil_type"] = FIFF.FIFFV_COIL_EEG + chan_info["coord_frame"] = eeg_frame + chan_info["unit"] = FIFF.FIFF_UNIT_V # TODO: We should use 'electrodes' to fill this in, and make sure # we turn them into dig as well - chan_info['loc'][:3] = np.nan - - elif chan_4d == 'RESPONSE': - chan_info['kind'] = FIFF.FIFFV_STIM_CH - elif chan_4d == 'TRIGGER': - chan_info['kind'] = FIFF.FIFFV_STIM_CH - elif chan_4d.startswith('EOG') or \ - chan_4d[:4] in ('HEOG', 'VEOG') or chan_4d in eog_ch: - chan_info['kind'] = FIFF.FIFFV_EOG_CH - elif chan_4d.startswith('EMG'): - chan_info['kind'] = FIFF.FIFFV_EMG_CH - elif chan_4d == ecg_ch or chan_4d.startswith('ECG'): - chan_info['kind'] = FIFF.FIFFV_ECG_CH + chan_info["loc"][:3] = np.nan + + elif chan_4d == "RESPONSE": + chan_info["kind"] = FIFF.FIFFV_STIM_CH + elif chan_4d == "TRIGGER": + chan_info["kind"] = FIFF.FIFFV_STIM_CH + elif ( + chan_4d.startswith("EOG") + or chan_4d[:4] in ("HEOG", "VEOG") + or chan_4d in eog_ch + ): + chan_info["kind"] = FIFF.FIFFV_EOG_CH + elif 
chan_4d.startswith("EMG"): + chan_info["kind"] = FIFF.FIFFV_EMG_CH + elif chan_4d == ecg_ch or chan_4d.startswith("ECG"): + chan_info["kind"] = FIFF.FIFFV_ECG_CH # Our default is now misc, but if we ever change that, # we'll need this: # elif chan_4d.startswith('X') or chan_4d == 'UACurrent': @@ -1229,17 +1387,19 @@ def _get_bti_info(pdf_fname, config_fname, head_shape_fname, rotation_x, chs.append(chan_info) - info['chs'] = chs + info["chs"] = chs # ### Dig stuff info = _make_bti_digitization( - info, head_shape_fname, convert, use_hpi, bti_dev_t, dev_ctf_t) + info, head_shape_fname, convert, use_hpi, bti_dev_t, dev_ctf_t + ) logger.info( - 'Currently direct inclusion of 4D weight tables is not supported.' - ' For critical use cases please take into account the MNE command' + "Currently direct inclusion of 4D weight tables is not supported." + " For critical use cases please take into account the MNE command" ' "mne_create_comp_data" to include weights as printed out by ' - 'the 4D "print_table" routine.') + 'the 4D "print_table" routine.' + ) # check that the info is complete info._unlocked = False @@ -1249,12 +1409,20 @@ def _get_bti_info(pdf_fname, config_fname, head_shape_fname, rotation_x, @verbose -def read_raw_bti(pdf_fname, config_fname='config', - head_shape_fname='hs_file', rotation_x=0., - translation=(0.0, 0.02, 0.11), convert=True, - rename_channels=True, sort_by_ch_name=True, - ecg_ch='E31', eog_ch=('E63', 'E64'), preload=False, - verbose=None): +def read_raw_bti( + pdf_fname, + config_fname="config", + head_shape_fname="hs_file", + rotation_x=0.0, + translation=(0.0, 0.02, 0.11), + convert=True, + rename_channels=True, + sort_by_ch_name=True, + ecg_ch="E31", + eog_ch=("E63", "E64"), + preload=False, + verbose=None, +): """Raw object from 4D Neuroimaging MagnesWH3600 data. .. note:: @@ -1309,9 +1477,17 @@ def read_raw_bti(pdf_fname, config_fname='config', -------- mne.io.Raw : Documentation of attributes and methods of RawBTi. 
""" - return RawBTi(pdf_fname, config_fname=config_fname, - head_shape_fname=head_shape_fname, - rotation_x=rotation_x, translation=translation, - convert=convert, rename_channels=rename_channels, - sort_by_ch_name=sort_by_ch_name, ecg_ch=ecg_ch, - eog_ch=eog_ch, preload=preload, verbose=verbose) + return RawBTi( + pdf_fname, + config_fname=config_fname, + head_shape_fname=head_shape_fname, + rotation_x=rotation_x, + translation=translation, + convert=convert, + rename_channels=rename_channels, + sort_by_ch_name=sort_by_ch_name, + ecg_ch=ecg_ch, + eog_ch=eog_ch, + preload=preload, + verbose=verbose, + ) diff --git a/mne/io/bti/constants.py b/mne/io/bti/constants.py index ca09e449af4..69c27da485c 100644 --- a/mne/io/bti/constants.py +++ b/mne/io/bti/constants.py @@ -6,94 +6,94 @@ BTI = BunchConst() -BTI.ELEC_STATE_NOT_COLLECTED = 0 -BTI.ELEC_STATE_COLLECTED = 1 -BTI.ELEC_STATE_SKIPPED = 2 -BTI.ELEC_STATE_NOT_APPLICABLE = 3 +BTI.ELEC_STATE_NOT_COLLECTED = 0 +BTI.ELEC_STATE_COLLECTED = 1 +BTI.ELEC_STATE_SKIPPED = 2 +BTI.ELEC_STATE_NOT_APPLICABLE = 3 # ## Byte offesets and data sizes for different files # -BTI.FILE_MASK = 2147483647 -BTI.FILE_CURPOS = 8 -BTI.FILE_END = -8 +BTI.FILE_MASK = 2147483647 +BTI.FILE_CURPOS = 8 +BTI.FILE_END = -8 -BTI.FILE_HS_VERSION = 0 -BTI.FILE_HS_TIMESTAMP = 4 -BTI.FILE_HS_CHECKSUM = 8 -BTI.FILE_HS_N_DIGPOINTS = 12 -BTI.FILE_HS_N_INDEXPOINTS = 16 +BTI.FILE_HS_VERSION = 0 +BTI.FILE_HS_TIMESTAMP = 4 +BTI.FILE_HS_CHECKSUM = 8 +BTI.FILE_HS_N_DIGPOINTS = 12 +BTI.FILE_HS_N_INDEXPOINTS = 16 -BTI.FILE_PDF_H_ENTER = 1 -BTI.FILE_PDF_H_FTYPE = 5 -BTI.FILE_PDF_H_XLABEL = 16 -BTI.FILE_PDF_H_NEXT = 2 -BTI.FILE_PDF_H_EXIT = 20 +BTI.FILE_PDF_H_ENTER = 1 +BTI.FILE_PDF_H_FTYPE = 5 +BTI.FILE_PDF_H_XLABEL = 16 +BTI.FILE_PDF_H_NEXT = 2 +BTI.FILE_PDF_H_EXIT = 20 -BTI.FILE_PDF_EPOCH_EXIT = 28 +BTI.FILE_PDF_EPOCH_EXIT = 28 -BTI.FILE_PDF_CH_NEXT = 6 -BTI.FILE_PDF_CH_LABELSIZE = 16 -BTI.FILE_PDF_CH_YLABEL = 16 -BTI.FILE_PDF_CH_OFF_FLAG = 16 -BTI.FILE_PDF_CH_EXIT = 12 +BTI.FILE_PDF_CH_NEXT = 6 +BTI.FILE_PDF_CH_LABELSIZE = 16 +BTI.FILE_PDF_CH_YLABEL = 16 +BTI.FILE_PDF_CH_OFF_FLAG = 16 +BTI.FILE_PDF_CH_EXIT = 12 -BTI.FILE_PDF_EVENT_NAME = 16 -BTI.FILE_PDF_EVENT_EXIT = 32 +BTI.FILE_PDF_EVENT_NAME = 16 +BTI.FILE_PDF_EVENT_EXIT = 32 -BTI.FILE_PDF_PROCESS_BLOCKTYPE = 20 -BTI.FILE_PDF_PROCESS_USER = 32 -BTI.FILE_PDF_PROCESS_FNAME = 256 -BTI.FILE_PDF_PROCESS_EXIT = 32 +BTI.FILE_PDF_PROCESS_BLOCKTYPE = 20 +BTI.FILE_PDF_PROCESS_USER = 32 +BTI.FILE_PDF_PROCESS_FNAME = 256 +BTI.FILE_PDF_PROCESS_EXIT = 32 -BTI.FILE_PDF_ASSOC_NEXT = 32 +BTI.FILE_PDF_ASSOC_NEXT = 32 -BTI.FILE_PDFED_NAME = 17 -BTI.FILE_PDFED_NEXT = 9 -BTI.FILE_PDFED_EXIT = 8 +BTI.FILE_PDFED_NAME = 17 +BTI.FILE_PDFED_NEXT = 9 +BTI.FILE_PDFED_EXIT = 8 # ## General data constants # -BTI.DATA_N_IDX_POINTS = 5 -BTI.DATA_ROT_N_ROW = 3 -BTI.DATA_ROT_N_COL = 3 -BTI.DATA_XFM_N_COL = 4 -BTI.DATA_XFM_N_ROW = 4 -BTI.FIFF_LOGNO = 111 +BTI.DATA_N_IDX_POINTS = 5 +BTI.DATA_ROT_N_ROW = 3 +BTI.DATA_ROT_N_COL = 3 +BTI.DATA_XFM_N_COL = 4 +BTI.DATA_XFM_N_ROW = 4 +BTI.FIFF_LOGNO = 111 # ## Channel Types # -BTI.CHTYPE_MEG = 1 -BTI.CHTYPE_EEG = 2 -BTI.CHTYPE_REFERENCE = 3 -BTI.CHTYPE_EXTERNAL = 4 -BTI.CHTYPE_TRIGGER = 5 -BTI.CHTYPE_UTILITY = 6 -BTI.CHTYPE_DERIVED = 7 -BTI.CHTYPE_SHORTED = 8 +BTI.CHTYPE_MEG = 1 +BTI.CHTYPE_EEG = 2 +BTI.CHTYPE_REFERENCE = 3 +BTI.CHTYPE_EXTERNAL = 4 +BTI.CHTYPE_TRIGGER = 5 +BTI.CHTYPE_UTILITY = 6 +BTI.CHTYPE_DERIVED = 7 +BTI.CHTYPE_SHORTED = 8 # ## Processes # -BTI.PROC_DEFAULTS = 'BTi_defaults' -BTI.PROC_FILTER = 
'b_filt_hp,b_filt_lp,b_filt_notch' -BTI.PROC_BPFILTER = 'b_filt_b_pass,b_filt_b_reject' +BTI.PROC_DEFAULTS = "BTi_defaults" +BTI.PROC_FILTER = "b_filt_hp,b_filt_lp,b_filt_notch" +BTI.PROC_BPFILTER = "b_filt_b_pass,b_filt_b_reject" # ## User blocks # -BTI.UB_B_MAG_INFO = 'B_Mag_Info' -BTI.UB_B_COH_POINTS = 'B_COH_Points' -BTI.UB_B_CCP_XFM_BLOCK = 'b_ccp_xfm_block' -BTI.UB_B_EEG_LOCS = 'b_eeg_elec_locs' -BTI.UB_B_WHC_CHAN_MAP_VER = 'B_WHChanMapVer' -BTI.UB_B_WHC_CHAN_MAP = 'B_WHChanMap' -BTI.UB_B_WHS_SUBSYS_VER = 'B_WHSubsysVer' # B_WHSubsysVer -BTI.UB_B_WHS_SUBSYS = 'B_WHSubsys' -BTI.UB_B_CH_LABELS = 'B_ch_labels' -BTI.UB_B_CALIBRATION = 'B_Calibration' -BTI.UB_B_SYS_CONFIG_TIME = 'B_SysConfigTime' -BTI.UB_B_DELTA_ENABLED = 'B_DELTA_ENABLED' -BTI.UB_B_E_TABLE_USED = 'B_E_table_used' -BTI.UB_B_E_TABLE = 'B_E_TABLE' -BTI.UB_B_WEIGHTS_USED = 'B_weights_used' -BTI.UB_B_TRIG_MASK = 'B_trig_mask' -BTI.UB_B_WEIGHT_TABLE = 'BWT_' +BTI.UB_B_MAG_INFO = "B_Mag_Info" +BTI.UB_B_COH_POINTS = "B_COH_Points" +BTI.UB_B_CCP_XFM_BLOCK = "b_ccp_xfm_block" +BTI.UB_B_EEG_LOCS = "b_eeg_elec_locs" +BTI.UB_B_WHC_CHAN_MAP_VER = "B_WHChanMapVer" +BTI.UB_B_WHC_CHAN_MAP = "B_WHChanMap" +BTI.UB_B_WHS_SUBSYS_VER = "B_WHSubsysVer" # B_WHSubsysVer +BTI.UB_B_WHS_SUBSYS = "B_WHSubsys" +BTI.UB_B_CH_LABELS = "B_ch_labels" +BTI.UB_B_CALIBRATION = "B_Calibration" +BTI.UB_B_SYS_CONFIG_TIME = "B_SysConfigTime" +BTI.UB_B_DELTA_ENABLED = "B_DELTA_ENABLED" +BTI.UB_B_E_TABLE_USED = "B_E_table_used" +BTI.UB_B_E_TABLE = "B_E_TABLE" +BTI.UB_B_WEIGHTS_USED = "B_weights_used" +BTI.UB_B_TRIG_MASK = "B_trig_mask" +BTI.UB_B_WEIGHT_TABLE = "BWT_" diff --git a/mne/io/bti/read.py b/mne/io/bti/read.py index 210ff827992..f3f9e889ecd 100644 --- a/mne/io/bti/read.py +++ b/mne/io/bti/read.py @@ -11,8 +11,7 @@ def _unpack_matrix(fid, rows, cols, dtype, out_dtype): dtype = np.dtype(dtype) string = fid.read(int(dtype.itemsize * rows * cols)) - out = np.frombuffer(string, dtype=dtype).reshape( - rows, cols).astype(out_dtype) + out = np.frombuffer(string, dtype=dtype).reshape(rows, cols).astype(out_dtype) return out @@ -29,80 +28,77 @@ def _unpack_simple(fid, dtype, out_dtype): def read_char(fid, count=1): """Read character from bti file.""" - return _unpack_simple(fid, '>S%s' % count, 'S') + return _unpack_simple(fid, ">S%s" % count, "S") def read_bool(fid): """Read bool value from bti file.""" - return _unpack_simple(fid, '>?', bool) + return _unpack_simple(fid, ">?", bool) def read_uint8(fid): """Read unsigned 8bit integer from bti file.""" - return _unpack_simple(fid, '>u1', np.uint8) + return _unpack_simple(fid, ">u1", np.uint8) def read_int8(fid): """Read 8bit integer from bti file.""" - return _unpack_simple(fid, '>i1', np.int8) + return _unpack_simple(fid, ">i1", np.int8) def read_uint16(fid): """Read unsigned 16bit integer from bti file.""" - return _unpack_simple(fid, '>u2', np.uint16) + return _unpack_simple(fid, ">u2", np.uint16) def read_int16(fid): """Read 16bit integer from bti file.""" - return _unpack_simple(fid, '>i2', np.int16) + return _unpack_simple(fid, ">i2", np.int16) def read_uint32(fid): """Read unsigned 32bit integer from bti file.""" - return _unpack_simple(fid, '>u4', np.uint32) + return _unpack_simple(fid, ">u4", np.uint32) def read_int32(fid): """Read 32bit integer from bti file.""" - return _unpack_simple(fid, '>i4', np.int32) + return _unpack_simple(fid, ">i4", np.int32) def read_uint64(fid): """Read unsigned 64bit integer from bti file.""" - return _unpack_simple(fid, '>u8', np.uint64) + return _unpack_simple(fid, ">u8", 
np.uint64) def read_int64(fid): """Read 64bit integer from bti file.""" - return _unpack_simple(fid, '>u8', np.int64) + return _unpack_simple(fid, ">u8", np.int64) def read_float(fid): """Read 32bit float from bti file.""" - return _unpack_simple(fid, '>f4', np.float32) + return _unpack_simple(fid, ">f4", np.float32) def read_double(fid): """Read 64bit float from bti file.""" - return _unpack_simple(fid, '>f8', np.float64) + return _unpack_simple(fid, ">f8", np.float64) def read_int16_matrix(fid, rows, cols): """Read 16bit integer matrix from bti file.""" - return _unpack_matrix(fid, rows, cols, dtype='>i2', - out_dtype=np.int16) + return _unpack_matrix(fid, rows, cols, dtype=">i2", out_dtype=np.int16) def read_float_matrix(fid, rows, cols): """Read 32bit float matrix from bti file.""" - return _unpack_matrix(fid, rows, cols, dtype='>f4', - out_dtype=np.float32) + return _unpack_matrix(fid, rows, cols, dtype=">f4", out_dtype=np.float32) def read_double_matrix(fid, rows, cols): """Read 64bit float matrix from bti file.""" - return _unpack_matrix(fid, rows, cols, dtype='>f8', - out_dtype=np.float64) + return _unpack_matrix(fid, rows, cols, dtype=">f8", out_dtype=np.float64) def read_transform(fid): @@ -112,5 +108,4 @@ def read_transform(fid): def read_dev_header(x): """Create a dev header.""" - return dict(size=read_int32(x), checksum=read_int32(x), - reserved=read_str(x, 32)) + return dict(size=read_int32(x), checksum=read_int32(x), reserved=read_str(x, 32)) diff --git a/mne/io/bti/tests/test_bti.py b/mne/io/bti/tests/test_bti.py index 0df08917bc6..fef5594cf67 100644 --- a/mne/io/bti/tests/test_bti.py +++ b/mne/io/bti/tests/test_bti.py @@ -9,19 +9,30 @@ from pathlib import Path import numpy as np -from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_allclose, assert_equal) +from numpy.testing import ( + assert_array_almost_equal, + assert_array_equal, + assert_allclose, + assert_equal, +) import pytest import mne from mne.datasets import testing from mne.io import read_raw_fif, read_raw_bti from mne.io._digitization import _make_bti_dig_points -from mne.io.bti.bti import (_read_config, _read_head_shape, - _read_bti_header, _get_bti_dev_t, - _correct_trans, _get_bti_info, - _loc_to_coil_trans, _convert_coil_trans, - _check_nan_dev_head_t, _rename_channels) +from mne.io.bti.bti import ( + _read_config, + _read_head_shape, + _read_bti_header, + _get_bti_dev_t, + _correct_trans, + _get_bti_info, + _loc_to_coil_trans, + _convert_coil_trans, + _check_nan_dev_head_t, + _rename_channels, +) from mne.io.tests.test_raw import _test_raw_reader from mne.io.pick import pick_info from mne.io.constants import FIFF @@ -55,21 +66,25 @@ def test_read_2500(): def test_no_loc_none(monkeypatch): """Test that we don't set loc to None when no trans is found.""" - ch_name = 'MLzA' + ch_name = "MLzA" def _read_config_bad(*args, **kwargs): cfg = _read_config(*args, **kwargs) - idx = [ch['name'] for ch in cfg['chs']].index(ch_name) - del cfg['chs'][idx]['dev']['transform'] + idx = [ch["name"] for ch in cfg["chs"]].index(ch_name) + del cfg["chs"][idx]["dev"]["transform"] return cfg - monkeypatch.setattr(mne.io.bti.bti, '_read_config', _read_config_bad) - kwargs = dict(pdf_fname=pdf_fnames[0], config_fname=config_fnames[0], - head_shape_fname=hs_fnames[0], rename_channels=False, - sort_by_ch_name=False) + monkeypatch.setattr(mne.io.bti.bti, "_read_config", _read_config_bad) + kwargs = dict( + pdf_fname=pdf_fnames[0], + config_fname=config_fnames[0], + head_shape_fname=hs_fnames[0], + 
rename_channels=False, + sort_by_ch_name=False, + ) raw = read_raw_bti(**kwargs) idx = raw.ch_names.index(ch_name) - assert_allclose(raw.info['chs'][idx]['loc'], np.full(12, np.nan)) + assert_allclose(raw.info["chs"][idx]["loc"], np.full(12, np.nan)) def test_read_config(): @@ -77,101 +92,117 @@ def test_read_config(): # for config in config_fname, config_solaris_fname: for config in config_fnames: cfg = _read_config(config) - assert all('unknown' not in block.lower() and block != '' - for block in cfg['user_blocks']) + assert all( + "unknown" not in block.lower() and block != "" + for block in cfg["user_blocks"] + ) def test_crop_append(): """Test crop and append raw.""" raw = _test_raw_reader( - read_raw_bti, pdf_fname=pdf_fnames[0], - config_fname=config_fnames[0], head_shape_fname=hs_fnames[0]) + read_raw_bti, + pdf_fname=pdf_fnames[0], + config_fname=config_fnames[0], + head_shape_fname=hs_fnames[0], + ) y, t = raw[:] t0, t1 = 0.25 * t[-1], 0.75 * t[-1] mask = (t0 <= t) * (t <= t1) raw_ = raw.copy().crop(t0, t1) y_, _ = raw_[:] - assert (y_.shape[1] == mask.sum()) - assert (y_.shape[0] == y.shape[0]) + assert y_.shape[1] == mask.sum() + assert y_.shape[0] == y.shape[0] def test_transforms(): """Test transformations.""" bti_trans = (0.0, 0.02, 0.11) - bti_dev_t = Transform('ctf_meg', 'meg', _get_bti_dev_t(0.0, bti_trans)) - for pdf, config, hs, in zip(pdf_fnames, config_fnames, hs_fnames): + bti_dev_t = Transform("ctf_meg", "meg", _get_bti_dev_t(0.0, bti_trans)) + for ( + pdf, + config, + hs, + ) in zip(pdf_fnames, config_fnames, hs_fnames): raw = read_raw_bti(pdf, config, hs, preload=False) - dev_ctf_t = raw.info['dev_ctf_t'] - dev_head_t_old = raw.info['dev_head_t'] - ctf_head_t = raw.info['ctf_head_t'] + dev_ctf_t = raw.info["dev_ctf_t"] + dev_head_t_old = raw.info["dev_head_t"] + ctf_head_t = raw.info["ctf_head_t"] # 1) get BTI->Neuromag - bti_dev_t = Transform('ctf_meg', 'meg', _get_bti_dev_t(0.0, bti_trans)) + bti_dev_t = Transform("ctf_meg", "meg", _get_bti_dev_t(0.0, bti_trans)) # 2) get Neuromag->BTI head - t = combine_transforms(invert_transform(bti_dev_t), dev_ctf_t, - 'meg', 'ctf_head') + t = combine_transforms( + invert_transform(bti_dev_t), dev_ctf_t, "meg", "ctf_head" + ) # 3) get Neuromag->head - dev_head_t_new = combine_transforms(t, ctf_head_t, 'meg', 'head') + dev_head_t_new = combine_transforms(t, ctf_head_t, "meg", "head") - assert_array_equal(dev_head_t_new['trans'], dev_head_t_old['trans']) + assert_array_equal(dev_head_t_new["trans"], dev_head_t_old["trans"]) @pytest.mark.slowtest def test_raw(): """Test bti conversion to Raw object.""" - for pdf, config, hs, exported in zip(pdf_fnames, config_fnames, hs_fnames, - exported_fnames): + for pdf, config, hs, exported in zip( + pdf_fnames, config_fnames, hs_fnames, exported_fnames + ): # rx = 2 if 'linux' in pdf else 0 - pytest.raises(ValueError, read_raw_bti, pdf, 'eggs', preload=False) - pytest.raises(ValueError, read_raw_bti, pdf, config, 'spam', - preload=False) + pytest.raises(ValueError, read_raw_bti, pdf, "eggs", preload=False) + pytest.raises(ValueError, read_raw_bti, pdf, config, "spam", preload=False) if tmp_raw_fname.exists(): os.remove(tmp_raw_fname) ex = read_raw_fif(exported, preload=True) ra = read_raw_bti(pdf, config, hs, preload=False) - assert ('RawBTi' in repr(ra)) + assert "RawBTi" in repr(ra) assert_equal(ex.ch_names[:NCH], ra.ch_names[:NCH]) - assert_array_almost_equal(ex.info['dev_head_t']['trans'], - ra.info['dev_head_t']['trans'], 7) - assert len(ex.info['dig']) in (3563, 5154) + 
assert_array_almost_equal(
+            ex.info["dev_head_t"]["trans"], ra.info["dev_head_t"]["trans"], 7
+        )
+        assert len(ex.info["dig"]) in (3563, 5154)
         assert_dig_allclose(ex.info, ra.info, limit=100)
-        coil1, coil2 = [np.concatenate([d['loc'].flatten()
-                                        for d in r_.info['chs'][:NCH]])
-                        for r_ in (ra, ex)]
+        coil1, coil2 = [
+            np.concatenate([d["loc"].flatten() for d in r_.info["chs"][:NCH]])
+            for r_ in (ra, ex)
+        ]
         assert_array_almost_equal(coil1, coil2, 7)
-        loc1, loc2 = [np.concatenate([d['loc'].flatten()
-                                      for d in r_.info['chs'][:NCH]])
-                      for r_ in (ra, ex)]
+        loc1, loc2 = [
+            np.concatenate([d["loc"].flatten() for d in r_.info["chs"][:NCH]])
+            for r_ in (ra, ex)
+        ]
         assert_allclose(loc1, loc2)
         assert_allclose(ra[:NCH][0], ex[:NCH][0])
-        assert_array_equal([c['range'] for c in ra.info['chs'][:NCH]],
-                           [c['range'] for c in ex.info['chs'][:NCH]])
-        assert_array_equal([c['cal'] for c in ra.info['chs'][:NCH]],
-                           [c['cal'] for c in ex.info['chs'][:NCH]])
+        assert_array_equal(
+            [c["range"] for c in ra.info["chs"][:NCH]],
+            [c["range"] for c in ex.info["chs"][:NCH]],
+        )
+        assert_array_equal(
+            [c["cal"] for c in ra.info["chs"][:NCH]],
+            [c["cal"] for c in ex.info["chs"][:NCH]],
+        )
         assert_array_equal(ra._cals[:NCH], ex._cals[:NCH])

         # check our transforms
-        for key in ('dev_head_t', 'dev_ctf_t', 'ctf_head_t'):
+        for key in ("dev_head_t", "dev_ctf_t", "ctf_head_t"):
             if ex.info[key] is None:
                 pass
             else:
-                assert (ra.info[key] is not None)
-                for ent in ('to', 'from', 'trans'):
-                    assert_allclose(ex.info[key][ent],
-                                    ra.info[key][ent])
+                assert ra.info[key] is not None
+                for ent in ("to", "from", "trans"):
+                    assert_allclose(ex.info[key][ent], ra.info[key][ent])

         ra.save(tmp_raw_fname)
         re = read_raw_fif(tmp_raw_fname)
         print(re)
-        for key in ('dev_head_t', 'dev_ctf_t', 'ctf_head_t'):
-            assert (isinstance(re.info[key], dict))
-            this_t = re.info[key]['trans']
+        for key in ("dev_head_t", "dev_ctf_t", "ctf_head_t"):
+            assert isinstance(re.info[key], dict)
+            this_t = re.info[key]["trans"]
             assert_equal(this_t.shape, (4, 4))
             # check that the matrix is not the identity
-            assert (not np.allclose(this_t, np.eye(4)))
+            assert not np.allclose(this_t, np.eye(4))
         os.remove(tmp_raw_fname)
@@ -179,56 +210,80 @@ def test_info_no_rename_no_reorder_no_pdf():
     """Test private renaming, reordering and partial construction option."""
     for pdf, config, hs in zip(pdf_fnames, config_fnames, hs_fnames):
         info, bti_info = _get_bti_info(
-            pdf_fname=pdf, config_fname=config, head_shape_fname=hs,
-            rotation_x=0.0, translation=(0.0, 0.02, 0.11), convert=False,
-            ecg_ch='E31', eog_ch=('E63', 'E64'),
-            rename_channels=False, sort_by_ch_name=False)
+            pdf_fname=pdf,
+            config_fname=config,
+            head_shape_fname=hs,
+            rotation_x=0.0,
+            translation=(0.0, 0.02, 0.11),
+            convert=False,
+            ecg_ch="E31",
+            eog_ch=("E63", "E64"),
+            rename_channels=False,
+            sort_by_ch_name=False,
+        )
         info2, bti_info = _get_bti_info(
-            pdf_fname=None, config_fname=config, head_shape_fname=hs,
-            rotation_x=0.0, translation=(0.0, 0.02, 0.11), convert=False,
-            ecg_ch='E31', eog_ch=('E63', 'E64'),
-            rename_channels=False, sort_by_ch_name=False)
-
-        assert_equal(info['ch_names'],
-                     [ch['ch_name'] for ch in info['chs']])
-        assert_equal([n for n in info['ch_names'] if n.startswith('A')][:5],
-                     ['A22', 'A2', 'A104', 'A241', 'A138'])
-        assert_equal([n for n in info['ch_names'] if n.startswith('A')][-5:],
-                     ['A133', 'A158', 'A44', 'A134', 'A216'])
-
-        info = pick_info(info, pick_types(info, meg=True, stim=True,
-                                          resp=True))
-        info2 = pick_info(info2, pick_types(info2, meg=True, 
stim=True, - resp=True)) - - assert (info['sfreq'] is not None) - assert (info['lowpass'] is not None) - assert (info['highpass'] is not None) - assert (info['meas_date'] is not None) - - assert_equal(info2['sfreq'], None) - assert_equal(info2['lowpass'], None) - assert_equal(info2['highpass'], None) - assert_equal(info2['meas_date'], None) - - assert_equal(info['ch_names'], info2['ch_names']) - assert_equal(info['ch_names'], info2['ch_names']) - for key in ['dev_ctf_t', 'dev_head_t', 'ctf_head_t']: - assert_array_equal(info[key]['trans'], info2[key]['trans']) + pdf_fname=None, + config_fname=config, + head_shape_fname=hs, + rotation_x=0.0, + translation=(0.0, 0.02, 0.11), + convert=False, + ecg_ch="E31", + eog_ch=("E63", "E64"), + rename_channels=False, + sort_by_ch_name=False, + ) + + assert_equal(info["ch_names"], [ch["ch_name"] for ch in info["chs"]]) + assert_equal( + [n for n in info["ch_names"] if n.startswith("A")][:5], + ["A22", "A2", "A104", "A241", "A138"], + ) + assert_equal( + [n for n in info["ch_names"] if n.startswith("A")][-5:], + ["A133", "A158", "A44", "A134", "A216"], + ) + + info = pick_info(info, pick_types(info, meg=True, stim=True, resp=True)) + info2 = pick_info(info2, pick_types(info2, meg=True, stim=True, resp=True)) + + assert info["sfreq"] is not None + assert info["lowpass"] is not None + assert info["highpass"] is not None + assert info["meas_date"] is not None + + assert_equal(info2["sfreq"], None) + assert_equal(info2["lowpass"], None) + assert_equal(info2["highpass"], None) + assert_equal(info2["meas_date"], None) + + assert_equal(info["ch_names"], info2["ch_names"]) + assert_equal(info["ch_names"], info2["ch_names"]) + for key in ["dev_ctf_t", "dev_head_t", "ctf_head_t"]: + assert_array_equal(info[key]["trans"], info2[key]["trans"]) assert_array_equal( - np.array([ch['loc'] for ch in info['chs']]), - np.array([ch['loc'] for ch in info2['chs']])) + np.array([ch["loc"] for ch in info["chs"]]), + np.array([ch["loc"] for ch in info2["chs"]]), + ) # just check reading data | corner case raw1 = read_raw_bti( - pdf_fname=pdf, config_fname=config, head_shape_fname=None, - sort_by_ch_name=False, preload=True) + pdf_fname=pdf, + config_fname=config, + head_shape_fname=None, + sort_by_ch_name=False, + preload=True, + ) # just check reading data | corner case raw2 = read_raw_bti( - pdf_fname=pdf, config_fname=config, head_shape_fname=None, + pdf_fname=pdf, + config_fname=config, + head_shape_fname=None, rename_channels=False, - sort_by_ch_name=True, preload=True) + sort_by_ch_name=True, + preload=True, + ) sort_idx = [raw1.bti_ch_labels.index(ch) for ch in raw2.bti_ch_labels] raw1._data = raw1._data[sort_idx] @@ -240,60 +295,70 @@ def test_no_conversion(): """Test bti no-conversion option.""" get_info = partial( _get_bti_info, - rotation_x=0.0, translation=(0.0, 0.02, 0.11), convert=False, - ecg_ch='E31', eog_ch=('E63', 'E64'), - rename_channels=False, sort_by_ch_name=False) + rotation_x=0.0, + translation=(0.0, 0.02, 0.11), + convert=False, + ecg_ch="E31", + eog_ch=("E63", "E64"), + rename_channels=False, + sort_by_ch_name=False, + ) for pdf, config, hs in zip(pdf_fnames, config_fnames, hs_fnames): raw_info, _ = get_info(pdf, config, hs, convert=False) raw_info_con = read_raw_bti( - pdf_fname=pdf, config_fname=config, head_shape_fname=hs, - convert=True, preload=False).info - - pick_info(raw_info_con, - pick_types(raw_info_con, meg=True, ref_meg=True), - copy=False) - pick_info(raw_info, - pick_types(raw_info, meg=True, ref_meg=True), copy=False) + pdf_fname=pdf, + 
config_fname=config, + head_shape_fname=hs, + convert=True, + preload=False, + ).info + + pick_info( + raw_info_con, pick_types(raw_info_con, meg=True, ref_meg=True), copy=False + ) + pick_info(raw_info, pick_types(raw_info, meg=True, ref_meg=True), copy=False) bti_info = _read_bti_header(pdf, config) - dev_ctf_t = _correct_trans(bti_info['bti_transform'][0]) - assert_array_equal(dev_ctf_t, raw_info['dev_ctf_t']['trans']) - assert_array_equal(raw_info['dev_head_t']['trans'], np.eye(4)) - assert_array_equal(raw_info['ctf_head_t']['trans'], np.eye(4)) + dev_ctf_t = _correct_trans(bti_info["bti_transform"][0]) + assert_array_equal(dev_ctf_t, raw_info["dev_ctf_t"]["trans"]) + assert_array_equal(raw_info["dev_head_t"]["trans"], np.eye(4)) + assert_array_equal(raw_info["ctf_head_t"]["trans"], np.eye(4)) nasion, lpa, rpa, hpi, dig_points = _read_head_shape(hs) - dig, t, _ = _make_bti_dig_points(nasion, lpa, rpa, hpi, dig_points, - convert=False, use_hpi=False) + dig, t, _ = _make_bti_dig_points( + nasion, lpa, rpa, hpi, dig_points, convert=False, use_hpi=False + ) - assert_array_equal(t['trans'], np.eye(4)) + assert_array_equal(t["trans"], np.eye(4)) - for ii, (old, new, con) in enumerate(zip( - dig, raw_info['dig'], raw_info_con['dig'])): - assert_equal(old['ident'], new['ident']) - assert_array_equal(old['r'], new['r']) - assert (not np.allclose(old['r'], con['r'])) + for ii, (old, new, con) in enumerate( + zip(dig, raw_info["dig"], raw_info_con["dig"]) + ): + assert_equal(old["ident"], new["ident"]) + assert_array_equal(old["r"], new["r"]) + assert not np.allclose(old["r"], con["r"]) if ii > 10: break - ch_map = {ch['chan_label']: ch['loc'] for ch in bti_info['chs']} + ch_map = {ch["chan_label"]: ch["loc"] for ch in bti_info["chs"]} - for ii, ch_label in enumerate(raw_info['ch_names']): - if not ch_label.startswith('A'): + for ii, ch_label in enumerate(raw_info["ch_names"]): + if not ch_label.startswith("A"): continue t1 = ch_map[ch_label] # correction already performed in bti_info - t2 = raw_info['chs'][ii]['loc'] - t3 = raw_info_con['chs'][ii]['loc'] + t2 = raw_info["chs"][ii]["loc"] + t3 = raw_info_con["chs"][ii]["loc"] assert_allclose(t1, t2, atol=1e-15) - assert (not np.allclose(t1, t3)) - idx_a = raw_info_con['ch_names'].index('MEG 001') - idx_b = raw_info['ch_names'].index('A22') + assert not np.allclose(t1, t3) + idx_a = raw_info_con["ch_names"].index("MEG 001") + idx_b = raw_info["ch_names"].index("A22") assert_equal( - raw_info_con['chs'][idx_a]['coord_frame'], - FIFF.FIFFV_COORD_DEVICE) + raw_info_con["chs"][idx_a]["coord_frame"], FIFF.FIFFV_COORD_DEVICE + ) assert_equal( - raw_info['chs'][idx_b]['coord_frame'], - FIFF.FIFFV_MNE_COORD_4D_HEAD) + raw_info["chs"][idx_b]["coord_frame"], FIFF.FIFFV_MNE_COORD_4D_HEAD + ) def test_bytes_io(): @@ -301,11 +366,11 @@ def test_bytes_io(): for pdf, config, hs in zip(pdf_fnames, config_fnames, hs_fnames): raw = read_raw_bti(pdf, config, hs, convert=True, preload=False) - with open(pdf, 'rb') as fid: + with open(pdf, "rb") as fid: pdf = BytesIO(fid.read()) - with open(config, 'rb') as fid: + with open(config, "rb") as fid: config = BytesIO(fid.read()) - with open(hs, 'rb') as fid: + with open(hs, "rb") as fid: hs = BytesIO(fid.read()) raw2 = read_raw_bti(pdf, config, hs, convert=True, preload=False) @@ -319,56 +384,54 @@ def test_setup_headshape(): nasion, lpa, rpa, hpi, dig_points = _read_head_shape(hs) dig, t, _ = _make_bti_dig_points(nasion, lpa, rpa, hpi, dig_points) - expected = {'kind', 'ident', 'r'} - found = set(reduce(lambda x, y: 
list(x) + list(y), - [d.keys() for d in dig])) - assert (not expected - found) + expected = {"kind", "ident", "r"} + found = set(reduce(lambda x, y: list(x) + list(y), [d.keys() for d in dig])) + assert not expected - found def test_nan_trans(): """Test unlikely case that the device to head transform is empty.""" for ii, pdf_fname in enumerate(pdf_fnames): - bti_info = _read_bti_header( - pdf_fname, config_fnames[ii], sort_by_ch_name=True) + bti_info = _read_bti_header(pdf_fname, config_fnames[ii], sort_by_ch_name=True) - dev_ctf_t = Transform('ctf_meg', 'ctf_head', - _correct_trans(bti_info['bti_transform'][0])) + dev_ctf_t = Transform( + "ctf_meg", "ctf_head", _correct_trans(bti_info["bti_transform"][0]) + ) # reading params convert = True - rotation_x = 0. + rotation_x = 0.0 translation = (0.0, 0.02, 0.11) bti_dev_t = _get_bti_dev_t(rotation_x, translation) - bti_dev_t = Transform('ctf_meg', 'meg', bti_dev_t) - ecg_ch = 'E31' - eog_ch = ('E63', 'E64') + bti_dev_t = Transform("ctf_meg", "meg", bti_dev_t) + ecg_ch = "E31" + eog_ch = ("E63", "E64") # read parts of info to get trans bti_ch_names = list() - for ch in bti_info['chs']: - ch_name = ch['name'] - if not ch_name.startswith('A'): - ch_name = ch.get('chan_label', ch_name) + for ch in bti_info["chs"]: + ch_name = ch["name"] + if not ch_name.startswith("A"): + ch_name = ch.get("chan_label", ch_name) bti_ch_names.append(ch_name) - neuromag_ch_names = _rename_channels( - bti_ch_names, ecg_ch=ecg_ch, eog_ch=eog_ch) + neuromag_ch_names = _rename_channels(bti_ch_names, ecg_ch=ecg_ch, eog_ch=eog_ch) ch_mapping = zip(bti_ch_names, neuromag_ch_names) # add some nan in some locations! - dev_ctf_t['trans'][:, 3] = np.nan + dev_ctf_t["trans"][:, 3] = np.nan _check_nan_dev_head_t(dev_ctf_t) for idx, (chan_4d, chan_neuromag) in enumerate(ch_mapping): - loc = bti_info['chs'][idx]['loc'] + loc = bti_info["chs"][idx]["loc"] if loc is not None: if convert: - t = _loc_to_coil_trans(bti_info['chs'][idx]['loc']) + t = _loc_to_coil_trans(bti_info["chs"][idx]["loc"]) t = _convert_coil_trans(t, dev_ctf_t, bti_dev_t) @testing.requires_testing_data -@pytest.mark.parametrize('fname', (fname_sim, fname_sim_filt)) -@pytest.mark.parametrize('preload', (True, False)) +@pytest.mark.parametrize("fname", (fname_sim, fname_sim_filt)) +@pytest.mark.parametrize("preload", (True, False)) def test_bti_ch_data(fname, preload): """Test for gh-6048.""" read_raw_bti(fname, preload=preload) # used to fail with ascii decode err @@ -377,9 +440,9 @@ def test_bti_ch_data(fname, preload): @testing.requires_testing_data def test_bti_set_eog(): """Check that EOG channels can be set (gh-10092).""" - raw = read_raw_bti(fname_sim, - preload=False, - eog_ch=('X65', 'X67', 'X69', 'X66', 'X68')) + raw = read_raw_bti( + fname_sim, preload=False, eog_ch=("X65", "X67", "X69", "X66", "X68") + ) assert_equal(len(pick_types(raw.info, eog=True)), 5) @@ -402,17 +465,17 @@ def test_bti_ecg_eog_emg(monkeypatch): # already exist got_map = dict(zip(raw.ch_names, ch_types)) kind_map = dict( - stim=['TRIGGER', 'RESPONSE'], - misc=['UACurrent'], + stim=["TRIGGER", "RESPONSE"], + misc=["UACurrent"], ) for kind, ch_names in kind_map.items(): for ch_name in ch_names: assert got_map[ch_name] == kind kind_map = dict( - misc=['SA1', 'SA2', 'SA3'], - ecg=['ECG+', 'ECG-'], - eog=['VEOG+', 'HEOG+', 'VEOG-', 'HEOG-'], - emg=['EMG_LF', 'EMG_LH', 'EMG_RF', 'EMG_RH'], + misc=["SA1", "SA2", "SA3"], + ecg=["ECG+", "ECG-"], + eog=["VEOG+", "HEOG+", "VEOG-", "HEOG-"], + emg=["EMG_LF", "EMG_LH", "EMG_RF", "EMG_RH"], ) 
new_names = sum(kind_map.values(), list()) assert len(new_names) == 13 @@ -420,11 +483,11 @@ def test_bti_ecg_eog_emg(monkeypatch): def _read_bti_header_2(*args, **kwargs): bti_info = _read_bti_header(*args, **kwargs) - for ch_name, ch in zip(new_names, bti_info['chs'][::-1]): - ch['chan_label'] = ch_name + for ch_name, ch in zip(new_names, bti_info["chs"][::-1]): + ch["chan_label"] = ch_name return bti_info - monkeypatch.setattr(mne.io.bti.bti, '_read_bti_header', _read_bti_header_2) + monkeypatch.setattr(mne.io.bti.bti, "_read_bti_header", _read_bti_header_2) raw = read_raw_bti(fname_2500, **kwargs) got_map = dict(zip(raw.ch_names, raw.get_channel_types())) got = Counter(got_map.values()) @@ -437,5 +500,5 @@ def _read_bti_header_2(*args, **kwargs): for kind, ch_names in kind_map.items(): for ch_name in ch_names: assert ch_name in raw.ch_names - err_msg = f'{ch_name} type {got_map[ch_name]} !+ {kind}' + err_msg = f"{ch_name} type {got_map[ch_name]} !+ {kind}" assert got_map[ch_name] == kind, err_msg diff --git a/mne/io/cnt/_utils.py b/mne/io/cnt/_utils.py index 19e459ac506..dd13e688b8f 100644 --- a/mne/io/cnt/_utils.py +++ b/mne/io/cnt/_utils.py @@ -25,15 +25,14 @@ def _read_teeg(f, teeg_offset): } TEEG; """ # we use a more descriptive names based on TEEG doc comments - Teeg = namedtuple('Teeg', 'event_type total_length offset') - teeg_parser = Struct('3 range 0-15 bit coded response pad */ @@ -60,23 +63,27 @@ def _read_teeg(f, teeg_offset): # needed for backward compat: EVENT type 3 has the same structure as type 2 -CNTEventType3 = namedtuple('CNTEventType3', - ('StimType KeyBoard KeyPad_Accept Offset Type ' - 'Code Latency EpochEvent Accept2 Accuracy')) +CNTEventType3 = namedtuple( + "CNTEventType3", + ( + "StimType KeyBoard KeyPad_Accept Offset Type " + "Code Latency EpochEvent Accept2 Accuracy" + ), +) def _get_event_parser(event_type): if event_type == 1: event_maker = CNTEventType1 - struct_pattern = ' 0 else '' - first_name = patient_name[-1] if len(patient_name) > 0 else '' + last_name = patient_name[0] if len(patient_name) > 0 else "" + first_name = patient_name[-1] if len(patient_name) > 0 else "" fid.seek(2, 1) sex = read_str(fid, 1) - if sex == 'M': + if sex == "M": sex = FIFF.FIFFV_SUBJ_SEX_MALE - elif sex == 'F': + elif sex == "F": sex = FIFF.FIFFV_SUBJ_SEX_FEMALE else: # can be 'U' sex = FIFF.FIFFV_SUBJ_SEX_UNKNOWN hand = read_str(fid, 1) - if hand == 'R': + if hand == "R": hand = FIFF.FIFFV_SUBJ_HAND_RIGHT - elif hand == 'L': + elif hand == "L": hand = FIFF.FIFFV_SUBJ_HAND_LEFT else: # can be 'M' for mixed or 'U' hand = None fid.seek(205) session_label = read_str(fid, 20) - session_date = ('%s %s' % (read_str(fid, 10), read_str(fid, 12))) + session_date = "%s %s" % (read_str(fid, 10), read_str(fid, 12)) meas_date = _session_date_2_meas_date(session_date, date_format) fid.seek(370) - n_channels = np.fromfile(fid, dtype='= 0] + eog = [idx for idx in np.fromfile(fid, dtype="i2", count=2) if idx >= 0] fid.seek(438) - lowpass_toggle = np.fromfile(fid, 'i1', count=1).item() - highpass_toggle = np.fromfile(fid, 'i1', count=1).item() + lowpass_toggle = np.fromfile(fid, "i1", count=1).item() + highpass_toggle = np.fromfile(fid, "i1", count=1).item() # Header has a field for number of samples, but it does not seem to be # too reliable. That's why we have option for setting n_bytes manually. 
fid.seek(864) - n_samples = np.fromfile(fid, dtype=' 1: - cnt_info['channel_offset'] //= n_bytes + cnt_info["channel_offset"] = np.fromfile(fid, dtype=" 1: + cnt_info["channel_offset"] //= n_bytes else: - cnt_info['channel_offset'] = 1 + cnt_info["channel_offset"] = 1 - ch_names, cals, baselines, chs, pos = ( - list(), list(), list(), list(), list() - ) + ch_names, cals, baselines, chs, pos = (list(), list(), list(), list(), list()) bads = list() for ch_idx in range(n_channels): # ELECTLOC fields @@ -325,57 +347,67 @@ def _get_cnt_info(input_fname, eog, ecg, emg, misc, data_format, date_format): ch_name = read_str(fid, 10) ch_names.append(ch_name) fid.seek(data_offset + 75 * ch_idx + 4) - if np.fromfile(fid, dtype='u1', count=1).item(): + if np.fromfile(fid, dtype="u1", count=1).item(): bads.append(ch_name) fid.seek(data_offset + 75 * ch_idx + 19) - xy = np.fromfile(fid, dtype='f4', count=2) + xy = np.fromfile(fid, dtype="f4", count=2) xy[1] *= -1 # invert y-axis pos.append(xy) fid.seek(data_offset + 75 * ch_idx + 47) # Baselines are subtracted before scaling the data. - baselines.append(np.fromfile(fid, dtype='i2', count=1).item()) + baselines.append(np.fromfile(fid, dtype="i2", count=1).item()) fid.seek(data_offset + 75 * ch_idx + 59) - sensitivity = np.fromfile(fid, dtype='f4', count=1).item() + sensitivity = np.fromfile(fid, dtype="f4", count=1).item() fid.seek(data_offset + 75 * ch_idx + 71) - cal = np.fromfile(fid, dtype='f4', count=1).item() + cal = np.fromfile(fid, dtype="f4", count=1).item() cals.append(cal * sensitivity * 1e-6 / 204.8) info = _empty_info(sfreq) if lowpass_toggle == 1: - info['lowpass'] = highcutoff + info["lowpass"] = highcutoff if highpass_toggle == 1: - info['highpass'] = lowcutoff - subject_info = {'hand': hand, 'id': patient_id, 'sex': sex, - 'first_name': first_name, 'last_name': last_name} - - if eog == 'auto': - eog = _find_channels(ch_names, 'EOG') - if ecg == 'auto': - ecg = _find_channels(ch_names, 'ECG') - if emg == 'auto': - emg = _find_channels(ch_names, 'EMG') - - chs = _create_chs(ch_names, cals, FIFF.FIFFV_COIL_EEG, - FIFF.FIFFV_EEG_CH, eog, ecg, emg, misc) - eegs = [idx for idx, ch in enumerate(chs) if - ch['coil_type'] == FIFF.FIFFV_COIL_EEG] + info["highpass"] = lowcutoff + subject_info = { + "hand": hand, + "id": patient_id, + "sex": sex, + "first_name": first_name, + "last_name": last_name, + } + + if eog == "auto": + eog = _find_channels(ch_names, "EOG") + if ecg == "auto": + ecg = _find_channels(ch_names, "ECG") + if emg == "auto": + emg = _find_channels(ch_names, "EMG") + + chs = _create_chs( + ch_names, cals, FIFF.FIFFV_COIL_EEG, FIFF.FIFFV_EEG_CH, eog, ecg, emg, misc + ) + eegs = [idx for idx, ch in enumerate(chs) if ch["coil_type"] == FIFF.FIFFV_COIL_EEG] coords = _topo_to_sphere(pos, eegs) locs = np.full((len(chs), 12), np.nan) locs[:, :3] = coords dig = _make_dig_points( dig_ch_pos=dict(zip(ch_names, coords)), - coord_frame="head", add_missing_fiducials=True, + coord_frame="head", + add_missing_fiducials=True, ) for ch, loc in zip(chs, locs): ch.update(loc=loc) - cnt_info.update(baselines=np.array(baselines), n_samples=n_samples, - n_bytes=n_bytes) + cnt_info.update(baselines=np.array(baselines), n_samples=n_samples, n_bytes=n_bytes) - session_label = None if str(session_label) == '' else str(session_label) - info.update(meas_date=meas_date, dig=dig, - description=session_label, bads=bads, - subject_info=subject_info, chs=chs) + session_label = None if str(session_label) == "" else str(session_label) + info.update( + 
meas_date=meas_date, + dig=dig, + description=session_label, + bads=bads, + subject_info=subject_info, + chs=chs, + ) info._unlocked = False info._update_redundant() return info, cnt_info @@ -439,42 +471,58 @@ class RawCNT(BaseRaw): mne.io.Raw : Documentation of attributes and methods. """ - def __init__(self, input_fname, eog=(), misc=(), - ecg=(), emg=(), data_format='auto', date_format='mm/dd/yy', - preload=False, verbose=None): # noqa: D102 - - _check_option('date_format', date_format, ['mm/dd/yy', 'dd/mm/yy']) - if date_format == 'dd/mm/yy': - _date_format = '%d/%m/%y %H:%M:%S' + def __init__( + self, + input_fname, + eog=(), + misc=(), + ecg=(), + emg=(), + data_format="auto", + date_format="mm/dd/yy", + preload=False, + verbose=None, + ): # noqa: D102 + _check_option("date_format", date_format, ["mm/dd/yy", "dd/mm/yy"]) + if date_format == "dd/mm/yy": + _date_format = "%d/%m/%y %H:%M:%S" else: - _date_format = '%m/%d/%y %H:%M:%S' + _date_format = "%m/%d/%y %H:%M:%S" input_fname = path.abspath(input_fname) - info, cnt_info = _get_cnt_info(input_fname, eog, ecg, emg, misc, - data_format, _date_format) - last_samps = [cnt_info['n_samples'] - 1] + info, cnt_info = _get_cnt_info( + input_fname, eog, ecg, emg, misc, data_format, _date_format + ) + last_samps = [cnt_info["n_samples"] - 1] super(RawCNT, self).__init__( - info, preload, filenames=[input_fname], raw_extras=[cnt_info], - last_samps=last_samps, orig_format='int', verbose=verbose) + info, + preload, + filenames=[input_fname], + raw_extras=[cnt_info], + last_samps=last_samps, + orig_format="int", + verbose=verbose, + ) - data_format = 'int32' if cnt_info['n_bytes'] == 4 else 'int16' + data_format = "int32" if cnt_info["n_bytes"] == 4 else "int16" self.set_annotations( - _read_annotations_cnt(input_fname, data_format=data_format)) + _read_annotations_cnt(input_fname, data_format=data_format) + ) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Take a chunk of raw data, multiply by mult or cals, and store.""" - n_channels = self._raw_extras[fi]['orig_nchan'] - if 'stim_channel' in self._raw_extras[fi]: + n_channels = self._raw_extras[fi]["orig_nchan"] + if "stim_channel" in self._raw_extras[fi]: f_channels = n_channels - 1 # Stim channel already read. - stim_ch = self._raw_extras[fi]['stim_channel'] + stim_ch = self._raw_extras[fi]["stim_channel"] else: f_channels = n_channels stim_ch = None - channel_offset = self._raw_extras[fi]['channel_offset'] - baselines = self._raw_extras[fi]['baselines'] - n_bytes = self._raw_extras[fi]['n_bytes'] - dtype = '= (channel_offset / 2): # Extend at the end. extra_samps += chunk_size count = n_samps // channel_offset * chunk_size + extra_samps n_chunks = count // chunk_size samps = np.fromfile(fid, dtype=dtype, count=count) - samps = samps.reshape((n_chunks, f_channels, channel_offset), - order='C') + samps = samps.reshape((n_chunks, f_channels, channel_offset), order="C") # Intermediate shaping to chunk sizes. block = np.zeros((n_channels, channel_offset * n_chunks)) for set_idx, row in enumerate(samps): # Final shape. 
-            block_slice = slice(set_idx * channel_offset,
-                                    (set_idx + 1) * channel_offset)
+                block_slice = slice(
+                    set_idx * channel_offset, (set_idx + 1) * channel_offset
+                )
                 block[:f_channels, block_slice] = row

-            if 'stim_channel' in self._raw_extras[fi]:
+            if "stim_channel" in self._raw_extras[fi]:
                 _data_start = start + sample_start
                 _data_stop = start + sample_stop
                 block[-1] = stim_ch[_data_start:_data_stop]
-            one[idx] = block[idx, s_offset:n_samps + s_offset]
+            one[idx] = block[idx, s_offset : n_samps + s_offset]

             one[idx] -= baselines[idx][:, None]
-        _mult_cal_one(data[:, sample_start:sample_stop], one, idx,
-                      cals, mult)
+        _mult_cal_one(data[:, sample_start:sample_stop], one, idx, cals, mult)
diff --git a/mne/io/cnt/tests/test_cnt.py b/mne/io/cnt/tests/test_cnt.py
index f4af393f06d..76cd5c0acc1 100644
--- a/mne/io/cnt/tests/test_cnt.py
+++ b/mne/io/cnt/tests/test_cnt.py
@@ -21,41 +21,40 @@
 @testing.requires_testing_data
 def test_data():
     """Test reading raw cnt files."""
-    with pytest.warns(RuntimeWarning, match='number of bytes'):
-        raw = _test_raw_reader(read_raw_cnt, input_fname=fname,
-                               eog='auto', misc=['NA1', 'LEFT_EAR'])
+    with pytest.warns(RuntimeWarning, match="number of bytes"):
+        raw = _test_raw_reader(
+            read_raw_cnt, input_fname=fname, eog="auto", misc=["NA1", "LEFT_EAR"]
+        )

     # make sure we use annotations even if we synthesized stim
     assert len(raw.annotations) == 6

     eog_chs = pick_types(raw.info, eog=True, exclude=[])
     assert len(eog_chs) == 2 # test eog='auto'
-    assert raw.info['bads'] == ['LEFT_EAR', 'VEOGR'] # test bads
+    assert raw.info["bads"] == ["LEFT_EAR", "VEOGR"] # test bads

     # the data has "05/10/200 17:35:31" so it is set to None
-    assert raw.info['meas_date'] is None
+    assert raw.info["meas_date"] is None


 @testing.requires_testing_data
 def test_compare_events_and_annotations():
     """Test comparing annotations and events."""
-    with pytest.warns(RuntimeWarning, match='Could not parse meas date'):
+    with pytest.warns(RuntimeWarning, match="Could not parse meas date"):
         raw = read_raw_cnt(fname)
-    events = np.array([[333, 0, 7],
-                       [1010, 0, 7],
-                       [1664, 0, 109],
-                       [2324, 0, 7],
-                       [2984, 0, 109]])
+    events = np.array(
+        [[333, 0, 7], [1010, 0, 7], [1664, 0, 109], [2324, 0, 7], [2984, 0, 109]]
+    )

     annot = read_annotations(fname)
     assert len(annot) == 6
-    assert_array_equal(annot.onset[:-1], events[:, 0] / raw.info['sfreq'])
-    assert 'STI 014' not in raw.info['ch_names']
+    assert_array_equal(annot.onset[:-1], events[:, 0] / raw.info["sfreq"])
+    assert "STI 014" not in raw.info["ch_names"]


 @testing.requires_testing_data
 def test_bad_spans():
     """Test reading raw cnt files with bad spans."""
     annot = read_annotations(fname_bad_spans)
-    temp = '\t'.join(annot.description)
-    assert 'BAD' in temp
+    temp = "\t".join(annot.description)
+    assert "BAD" in temp
diff --git a/mne/io/compensator.py b/mne/io/compensator.py
index 220de1f8259..2a5334e7138 100644
--- a/mne/io/compensator.py
+++ b/mne/io/compensator.py
@@ -8,59 +8,60 @@ def get_current_comp(info):
     """Get the current compensation in effect in the data."""
     comp = None
     first_comp = -1
-    for k, chan in enumerate(info['chs']):
-        if chan['kind'] == FIFF.FIFFV_MEG_CH:
-            comp = int(chan['coil_type']) >> 16
+    for k, chan in enumerate(info["chs"]):
+        if chan["kind"] == FIFF.FIFFV_MEG_CH:
+            comp = int(chan["coil_type"]) >> 16
             if first_comp < 0:
                 first_comp = comp
             elif comp != first_comp:
-                raise ValueError('Compensation is not set equally on '
-                                 'all MEG channels')
+                raise ValueError(
+                    "Compensation is not set equally on " "all MEG channels" 
+ ) return comp def set_current_comp(info, comp): """Set the current compensation in effect in the data.""" comp_now = get_current_comp(info) - for k, chan in enumerate(info['chs']): - if chan['kind'] == FIFF.FIFFV_MEG_CH: - rem = chan['coil_type'] - (comp_now << 16) - chan['coil_type'] = int(rem + (comp << 16)) + for k, chan in enumerate(info["chs"]): + if chan["kind"] == FIFF.FIFFV_MEG_CH: + rem = chan["coil_type"] - (comp_now << 16) + chan["coil_type"] = int(rem + (comp << 16)) def _make_compensator(info, grade): """Auxiliary function for make_compensator.""" - for k in range(len(info['comps'])): - if info['comps'][k]['kind'] == grade: - this_data = info['comps'][k]['data'] + for k in range(len(info["comps"])): + if info["comps"][k]["kind"] == grade: + this_data = info["comps"][k]["data"] # Create the preselector - presel = np.zeros((this_data['ncol'], info['nchan'])) - for col, col_name in enumerate(this_data['col_names']): - ind = [k for k, ch in enumerate(info['ch_names']) - if ch == col_name] + presel = np.zeros((this_data["ncol"], info["nchan"])) + for col, col_name in enumerate(this_data["col_names"]): + ind = [k for k, ch in enumerate(info["ch_names"]) if ch == col_name] if len(ind) == 0: - raise ValueError('Channel %s is not available in ' - 'data' % col_name) + raise ValueError( + "Channel %s is not available in " "data" % col_name + ) elif len(ind) > 1: - raise ValueError('Ambiguous channel %s' % col_name) + raise ValueError("Ambiguous channel %s" % col_name) presel[col, ind[0]] = 1.0 # Create the postselector (zero entries for channels not found) - postsel = np.zeros((info['nchan'], this_data['nrow'])) - for c, ch_name in enumerate(info['ch_names']): - ind = [k for k, ch in enumerate(this_data['row_names']) - if ch == ch_name] + postsel = np.zeros((info["nchan"], this_data["nrow"])) + for c, ch_name in enumerate(info["ch_names"]): + ind = [ + k for k, ch in enumerate(this_data["row_names"]) if ch == ch_name + ] if len(ind) > 1: - raise ValueError('Ambiguous channel %s' % ch_name) + raise ValueError("Ambiguous channel %s" % ch_name) elif len(ind) == 1: postsel[c, ind[0]] = 1.0 # else, don't use it at all (postsel[c, ?] 
= 0.0) by allocation - this_comp = np.dot(postsel, np.dot(this_data['data'], presel)) + this_comp = np.dot(postsel, np.dot(this_data["data"], presel)) return this_comp - raise ValueError('Desired compensation matrix (grade = %d) not' - ' found' % grade) + raise ValueError("Desired compensation matrix (grade = %d) not" " found" % grade) @fill_doc @@ -94,10 +95,10 @@ def make_compensator(info, from_, to, exclude_comp_chs=False): # s_to = (I - C2)*(I + C1)*s_from = (I + C1 - C2 - C2*C1)*s_from if from_ != 0: C1 = _make_compensator(info, from_) - comp_from_0 = np.linalg.inv(np.eye(info['nchan']) - C1) + comp_from_0 = np.linalg.inv(np.eye(info["nchan"]) - C1) if to != 0: C2 = _make_compensator(info, to) - comp_0_to = np.eye(info['nchan']) - C2 + comp_0_to = np.eye(info["nchan"]) - C2 if from_ != 0: if to != 0: # This is mathematically equivalent, but has higher numerical @@ -111,12 +112,14 @@ def make_compensator(info, from_, to, exclude_comp_chs=False): comp = comp_0_to if exclude_comp_chs: - pick = [k for k, c in enumerate(info['chs']) - if c['kind'] != FIFF.FIFFV_REF_MEG_CH] + pick = [ + k for k, c in enumerate(info["chs"]) if c["kind"] != FIFF.FIFFV_REF_MEG_CH + ] if len(pick) == 0: - raise ValueError('Nothing remains after excluding the ' - 'compensation channels') + raise ValueError( + "Nothing remains after excluding the " "compensation channels" + ) comp = comp[pick, :] diff --git a/mne/io/constants.py b/mne/io/constants.py index f2847644f07..be5c5e57044 100644 --- a/mne/io/constants.py +++ b/mne/io/constants.py @@ -17,313 +17,331 @@ # # Blocks # -FIFF.FIFFB_ROOT = 999 -FIFF.FIFFB_MEAS = 100 -FIFF.FIFFB_MEAS_INFO = 101 -FIFF.FIFFB_RAW_DATA = 102 -FIFF.FIFFB_PROCESSED_DATA = 103 -FIFF.FIFFB_EVOKED = 104 -FIFF.FIFFB_ASPECT = 105 -FIFF.FIFFB_SUBJECT = 106 -FIFF.FIFFB_ISOTRAK = 107 -FIFF.FIFFB_HPI_MEAS = 108 # HPI measurement -FIFF.FIFFB_HPI_RESULT = 109 # Result of a HPI fitting procedure -FIFF.FIFFB_HPI_COIL = 110 # Data acquired from one HPI coil -FIFF.FIFFB_PROJECT = 111 -FIFF.FIFFB_CONTINUOUS_DATA = 112 -FIFF.FIFFB_CH_INFO = 113 # Extra channel information -FIFF.FIFFB_VOID = 114 -FIFF.FIFFB_EVENTS = 115 -FIFF.FIFFB_INDEX = 116 -FIFF.FIFFB_DACQ_PARS = 117 -FIFF.FIFFB_REF = 118 -FIFF.FIFFB_IAS_RAW_DATA = 119 -FIFF.FIFFB_IAS_ASPECT = 120 -FIFF.FIFFB_HPI_SUBSYSTEM = 121 +FIFF.FIFFB_ROOT = 999 +FIFF.FIFFB_MEAS = 100 +FIFF.FIFFB_MEAS_INFO = 101 +FIFF.FIFFB_RAW_DATA = 102 +FIFF.FIFFB_PROCESSED_DATA = 103 +FIFF.FIFFB_EVOKED = 104 +FIFF.FIFFB_ASPECT = 105 +FIFF.FIFFB_SUBJECT = 106 +FIFF.FIFFB_ISOTRAK = 107 +FIFF.FIFFB_HPI_MEAS = 108 # HPI measurement +FIFF.FIFFB_HPI_RESULT = 109 # Result of a HPI fitting procedure +FIFF.FIFFB_HPI_COIL = 110 # Data acquired from one HPI coil +FIFF.FIFFB_PROJECT = 111 +FIFF.FIFFB_CONTINUOUS_DATA = 112 +FIFF.FIFFB_CH_INFO = 113 # Extra channel information +FIFF.FIFFB_VOID = 114 +FIFF.FIFFB_EVENTS = 115 +FIFF.FIFFB_INDEX = 116 +FIFF.FIFFB_DACQ_PARS = 117 +FIFF.FIFFB_REF = 118 +FIFF.FIFFB_IAS_RAW_DATA = 119 +FIFF.FIFFB_IAS_ASPECT = 120 +FIFF.FIFFB_HPI_SUBSYSTEM = 121 # FIFF.FIFFB_PHANTOM_SUBSYSTEM = 122 # FIFF.FIFFB_STATUS_SUBSYSTEM = 123 -FIFF.FIFFB_DEVICE = 124 -FIFF.FIFFB_HELIUM = 125 -FIFF.FIFFB_CHANNEL_INFO = 126 - -FIFF.FIFFB_SPHERE = 300 # Concentric sphere model related -FIFF.FIFFB_BEM = 310 # Boundary-element method -FIFF.FIFFB_BEM_SURF = 311 # Boundary-element method surfaces -FIFF.FIFFB_CONDUCTOR_MODEL = 312 # One conductor model definition -FIFF.FIFFB_PROJ = 313 -FIFF.FIFFB_PROJ_ITEM = 314 -FIFF.FIFFB_MRI = 200 -FIFF.FIFFB_MRI_SET = 201 
-FIFF.FIFFB_MRI_SLICE = 202 -FIFF.FIFFB_MRI_SCENERY = 203 # These are for writing unrelated 'slices' -FIFF.FIFFB_MRI_SCENE = 204 # Which are actually 3D scenes... -FIFF.FIFFB_MRI_SEG = 205 # MRI segmentation data -FIFF.FIFFB_MRI_SEG_REGION = 206 # One MRI segmentation region +FIFF.FIFFB_DEVICE = 124 +FIFF.FIFFB_HELIUM = 125 +FIFF.FIFFB_CHANNEL_INFO = 126 + +FIFF.FIFFB_SPHERE = 300 # Concentric sphere model related +FIFF.FIFFB_BEM = 310 # Boundary-element method +FIFF.FIFFB_BEM_SURF = 311 # Boundary-element method surfaces +FIFF.FIFFB_CONDUCTOR_MODEL = 312 # One conductor model definition +FIFF.FIFFB_PROJ = 313 +FIFF.FIFFB_PROJ_ITEM = 314 +FIFF.FIFFB_MRI = 200 +FIFF.FIFFB_MRI_SET = 201 +FIFF.FIFFB_MRI_SLICE = 202 +FIFF.FIFFB_MRI_SCENERY = 203 # These are for writing unrelated 'slices' +FIFF.FIFFB_MRI_SCENE = 204 # Which are actually 3D scenes... +FIFF.FIFFB_MRI_SEG = 205 # MRI segmentation data +FIFF.FIFFB_MRI_SEG_REGION = 206 # One MRI segmentation region FIFF.FIFFB_PROCESSING_HISTORY = 900 -FIFF.FIFFB_PROCESSING_RECORD = 901 +FIFF.FIFFB_PROCESSING_RECORD = 901 -FIFF.FIFFB_DATA_CORRECTION = 500 -FIFF.FIFFB_CHANNEL_DECOUPLER = 501 -FIFF.FIFFB_SSS_INFO = 502 -FIFF.FIFFB_SSS_CAL = 503 -FIFF.FIFFB_SSS_ST_INFO = 504 -FIFF.FIFFB_SSS_BASES = 505 -FIFF.FIFFB_IAS = 510 +FIFF.FIFFB_DATA_CORRECTION = 500 +FIFF.FIFFB_CHANNEL_DECOUPLER = 501 +FIFF.FIFFB_SSS_INFO = 502 +FIFF.FIFFB_SSS_CAL = 503 +FIFF.FIFFB_SSS_ST_INFO = 504 +FIFF.FIFFB_SSS_BASES = 505 +FIFF.FIFFB_IAS = 510 # # Of general interest # -FIFF.FIFF_FILE_ID = 100 -FIFF.FIFF_DIR_POINTER = 101 -FIFF.FIFF_BLOCK_ID = 103 -FIFF.FIFF_BLOCK_START = 104 -FIFF.FIFF_BLOCK_END = 105 -FIFF.FIFF_FREE_LIST = 106 -FIFF.FIFF_FREE_BLOCK = 107 -FIFF.FIFF_NOP = 108 -FIFF.FIFF_PARENT_FILE_ID = 109 +FIFF.FIFF_FILE_ID = 100 +FIFF.FIFF_DIR_POINTER = 101 +FIFF.FIFF_BLOCK_ID = 103 +FIFF.FIFF_BLOCK_START = 104 +FIFF.FIFF_BLOCK_END = 105 +FIFF.FIFF_FREE_LIST = 106 +FIFF.FIFF_FREE_BLOCK = 107 +FIFF.FIFF_NOP = 108 +FIFF.FIFF_PARENT_FILE_ID = 109 FIFF.FIFF_PARENT_BLOCK_ID = 110 -FIFF.FIFF_BLOCK_NAME = 111 -FIFF.FIFF_BLOCK_VERSION = 112 -FIFF.FIFF_CREATOR = 113 # Program that created the file (string) -FIFF.FIFF_MODIFIER = 114 # Program that modified the file (string) -FIFF.FIFF_REF_ROLE = 115 -FIFF.FIFF_REF_FILE_ID = 116 -FIFF.FIFF_REF_FILE_NUM = 117 -FIFF.FIFF_REF_FILE_NAME = 118 +FIFF.FIFF_BLOCK_NAME = 111 +FIFF.FIFF_BLOCK_VERSION = 112 +FIFF.FIFF_CREATOR = 113 # Program that created the file (string) +FIFF.FIFF_MODIFIER = 114 # Program that modified the file (string) +FIFF.FIFF_REF_ROLE = 115 +FIFF.FIFF_REF_FILE_ID = 116 +FIFF.FIFF_REF_FILE_NUM = 117 +FIFF.FIFF_REF_FILE_NAME = 118 # # Megacq saves the parameters in these tags # -FIFF.FIFF_DACQ_PARS = 150 -FIFF.FIFF_DACQ_STIM = 151 +FIFF.FIFF_DACQ_PARS = 150 +FIFF.FIFF_DACQ_STIM = 151 -FIFF.FIFF_DEVICE_TYPE = 152 -FIFF.FIFF_DEVICE_MODEL = 153 -FIFF.FIFF_DEVICE_SERIAL = 154 -FIFF.FIFF_DEVICE_SITE = 155 +FIFF.FIFF_DEVICE_TYPE = 152 +FIFF.FIFF_DEVICE_MODEL = 153 +FIFF.FIFF_DEVICE_SERIAL = 154 +FIFF.FIFF_DEVICE_SITE = 155 -FIFF.FIFF_HE_LEVEL_RAW = 156 -FIFF.FIFF_HELIUM_LEVEL = 157 +FIFF.FIFF_HE_LEVEL_RAW = 156 +FIFF.FIFF_HELIUM_LEVEL = 157 FIFF.FIFF_ORIG_FILE_GUID = 158 -FIFF.FIFF_UTC_OFFSET = 159 - -FIFF.FIFF_NCHAN = 200 -FIFF.FIFF_SFREQ = 201 -FIFF.FIFF_DATA_PACK = 202 -FIFF.FIFF_CH_INFO = 203 -FIFF.FIFF_MEAS_DATE = 204 -FIFF.FIFF_SUBJECT = 205 -FIFF.FIFF_COMMENT = 206 -FIFF.FIFF_NAVE = 207 -FIFF.FIFF_FIRST_SAMPLE = 208 # The first sample of an epoch -FIFF.FIFF_LAST_SAMPLE = 209 # The last sample of an epoch 
-FIFF.FIFF_ASPECT_KIND = 210 -FIFF.FIFF_REF_EVENT = 211 +FIFF.FIFF_UTC_OFFSET = 159 + +FIFF.FIFF_NCHAN = 200 +FIFF.FIFF_SFREQ = 201 +FIFF.FIFF_DATA_PACK = 202 +FIFF.FIFF_CH_INFO = 203 +FIFF.FIFF_MEAS_DATE = 204 +FIFF.FIFF_SUBJECT = 205 +FIFF.FIFF_COMMENT = 206 +FIFF.FIFF_NAVE = 207 +FIFF.FIFF_FIRST_SAMPLE = 208 # The first sample of an epoch +FIFF.FIFF_LAST_SAMPLE = 209 # The last sample of an epoch +FIFF.FIFF_ASPECT_KIND = 210 +FIFF.FIFF_REF_EVENT = 211 FIFF.FIFF_EXPERIMENTER = 212 -FIFF.FIFF_DIG_POINT = 213 -FIFF.FIFF_CH_POS = 214 -FIFF.FIFF_HPI_SLOPES = 215 # HPI data -FIFF.FIFF_HPI_NCOIL = 216 -FIFF.FIFF_REQ_EVENT = 217 -FIFF.FIFF_REQ_LIMIT = 218 -FIFF.FIFF_LOWPASS = 219 -FIFF.FIFF_BAD_CHS = 220 +FIFF.FIFF_DIG_POINT = 213 +FIFF.FIFF_CH_POS = 214 +FIFF.FIFF_HPI_SLOPES = 215 # HPI data +FIFF.FIFF_HPI_NCOIL = 216 +FIFF.FIFF_REQ_EVENT = 217 +FIFF.FIFF_REQ_LIMIT = 218 +FIFF.FIFF_LOWPASS = 219 +FIFF.FIFF_BAD_CHS = 220 FIFF.FIFF_ARTEF_REMOVAL = 221 FIFF.FIFF_COORD_TRANS = 222 -FIFF.FIFF_HIGHPASS = 223 -FIFF.FIFF_CH_CALS = 224 # This will not occur in new files -FIFF.FIFF_HPI_BAD_CHS = 225 # List of channels considered to be bad in hpi -FIFF.FIFF_HPI_CORR_COEFF = 226 # HPI curve fit correlations -FIFF.FIFF_EVENT_COMMENT = 227 # Comment about the events used in averaging -FIFF.FIFF_NO_SAMPLES = 228 # Number of samples in an epoch -FIFF.FIFF_FIRST_TIME = 229 # Time scale minimum - -FIFF.FIFF_SUBAVE_SIZE = 230 # Size of a subaverage -FIFF.FIFF_SUBAVE_FIRST = 231 # The first epoch # contained in the subaverage -FIFF.FIFF_NAME = 233 # Intended to be a short name. -FIFF.FIFF_DESCRIPTION = FIFF.FIFF_COMMENT # (Textual) Description of an object -FIFF.FIFF_DIG_STRING = 234 # String of digitized points -FIFF.FIFF_LINE_FREQ = 235 # Line frequency -FIFF.FIFF_GANTRY_ANGLE = 282 # Tilt angle of the gantry in degrees. +FIFF.FIFF_HIGHPASS = 223 +FIFF.FIFF_CH_CALS = 224 # This will not occur in new files +FIFF.FIFF_HPI_BAD_CHS = 225 # List of channels considered to be bad in hpi +FIFF.FIFF_HPI_CORR_COEFF = 226 # HPI curve fit correlations +FIFF.FIFF_EVENT_COMMENT = 227 # Comment about the events used in averaging +FIFF.FIFF_NO_SAMPLES = 228 # Number of samples in an epoch +FIFF.FIFF_FIRST_TIME = 229 # Time scale minimum + +FIFF.FIFF_SUBAVE_SIZE = 230 # Size of a subaverage +FIFF.FIFF_SUBAVE_FIRST = 231 # The first epoch # contained in the subaverage +FIFF.FIFF_NAME = 233 # Intended to be a short name. +FIFF.FIFF_DESCRIPTION = FIFF.FIFF_COMMENT # (Textual) Description of an object +FIFF.FIFF_DIG_STRING = 234 # String of digitized points +FIFF.FIFF_LINE_FREQ = 235 # Line frequency +FIFF.FIFF_GANTRY_ANGLE = 282 # Tilt angle of the gantry in degrees. 
# # HPI fitting program tags # -FIFF.FIFF_HPI_COIL_FREQ = 236 # HPI coil excitation frequency -FIFF.FIFF_HPI_COIL_MOMENTS = 240 # Estimated moment vectors for the HPI coil magnetic dipoles -FIFF.FIFF_HPI_FIT_GOODNESS = 241 # Three floats indicating the goodness of fit -FIFF.FIFF_HPI_FIT_ACCEPT = 242 # Bitmask indicating acceptance (see below) -FIFF.FIFF_HPI_FIT_GOOD_LIMIT = 243 # Limit for the goodness-of-fit -FIFF.FIFF_HPI_FIT_DIST_LIMIT = 244 # Limit for the coil distance difference -FIFF.FIFF_HPI_COIL_NO = 245 # Coil number listed by HPI measurement -FIFF.FIFF_HPI_COILS_USED = 246 # List of coils finally used when the transformation was computed -FIFF.FIFF_HPI_DIGITIZATION_ORDER = 247 # Which Isotrak digitization point corresponds to each of the coils energized +FIFF.FIFF_HPI_COIL_FREQ = 236 # HPI coil excitation frequency +FIFF.FIFF_HPI_COIL_MOMENTS = ( + 240 # Estimated moment vectors for the HPI coil magnetic dipoles +) +FIFF.FIFF_HPI_FIT_GOODNESS = 241 # Three floats indicating the goodness of fit +FIFF.FIFF_HPI_FIT_ACCEPT = 242 # Bitmask indicating acceptance (see below) +FIFF.FIFF_HPI_FIT_GOOD_LIMIT = 243 # Limit for the goodness-of-fit +FIFF.FIFF_HPI_FIT_DIST_LIMIT = 244 # Limit for the coil distance difference +FIFF.FIFF_HPI_COIL_NO = 245 # Coil number listed by HPI measurement +FIFF.FIFF_HPI_COILS_USED = ( + 246 # List of coils finally used when the transformation was computed +) +FIFF.FIFF_HPI_DIGITIZATION_ORDER = ( + 247 # Which Isotrak digitization point corresponds to each of the coils energized +) # # Tags used for storing channel info # -FIFF.FIFF_CH_SCAN_NO = 250 # Channel scan number. Corresponds to fiffChInfoRec.scanNo field -FIFF.FIFF_CH_LOGICAL_NO = 251 # Channel logical number. Corresponds to fiffChInfoRec.logNo field -FIFF.FIFF_CH_KIND = 252 # Channel type. Corresponds to fiffChInfoRec.kind field" -FIFF.FIFF_CH_RANGE = 253 # Conversion from recorded number to (possibly virtual) voltage at the output" -FIFF.FIFF_CH_CAL = 254 # Calibration coefficient from output voltage to some real units -FIFF.FIFF_CH_LOC = 255 # Channel loc -FIFF.FIFF_CH_UNIT = 256 # Unit of the data -FIFF.FIFF_CH_UNIT_MUL = 257 # Unit multiplier exponent -FIFF.FIFF_CH_DACQ_NAME = 258 # Name of the channel in the data acquisition system. Corresponds to fiffChInfoRec.name. -FIFF.FIFF_CH_COIL_TYPE = 350 # Coil type in coil_def.dat -FIFF.FIFF_CH_COORD_FRAME = 351 # Coordinate frame (integer) +FIFF.FIFF_CH_SCAN_NO = ( + 250 # Channel scan number. Corresponds to fiffChInfoRec.scanNo field +) +FIFF.FIFF_CH_LOGICAL_NO = ( + 251 # Channel logical number. Corresponds to fiffChInfoRec.logNo field +) +FIFF.FIFF_CH_KIND = 252 # Channel type. Corresponds to fiffChInfoRec.kind field" +FIFF.FIFF_CH_RANGE = ( + 253 # Conversion from recorded number to (possibly virtual) voltage at the output" +) +FIFF.FIFF_CH_CAL = 254 # Calibration coefficient from output voltage to some real units +FIFF.FIFF_CH_LOC = 255 # Channel loc +FIFF.FIFF_CH_UNIT = 256 # Unit of the data +FIFF.FIFF_CH_UNIT_MUL = 257 # Unit multiplier exponent +FIFF.FIFF_CH_DACQ_NAME = 258 # Name of the channel in the data acquisition system. Corresponds to fiffChInfoRec.name. 
+FIFF.FIFF_CH_COIL_TYPE = 350 # Coil type in coil_def.dat +FIFF.FIFF_CH_COORD_FRAME = 351 # Coordinate frame (integer) # # Pointers # -FIFF.FIFFV_NEXT_SEQ = 0 -FIFF.FIFFV_NEXT_NONE = -1 +FIFF.FIFFV_NEXT_SEQ = 0 +FIFF.FIFFV_NEXT_NONE = -1 # # Channel types # -FIFF.FIFFV_BIO_CH = 102 -FIFF.FIFFV_MEG_CH = 1 -FIFF.FIFFV_REF_MEG_CH = 301 -FIFF.FIFFV_EEG_CH = 2 -FIFF.FIFFV_MCG_CH = 201 -FIFF.FIFFV_STIM_CH = 3 -FIFF.FIFFV_EOG_CH = 202 -FIFF.FIFFV_EMG_CH = 302 -FIFF.FIFFV_ECG_CH = 402 -FIFF.FIFFV_MISC_CH = 502 -FIFF.FIFFV_RESP_CH = 602 # Respiration monitoring -FIFF.FIFFV_SEEG_CH = 802 # stereotactic EEG -FIFF.FIFFV_DBS_CH = 803 # deep brain stimulation -FIFF.FIFFV_SYST_CH = 900 # some system status information (on Triux systems only) -FIFF.FIFFV_ECOG_CH = 902 -FIFF.FIFFV_IAS_CH = 910 # Internal Active Shielding data (maybe on Triux only) -FIFF.FIFFV_EXCI_CH = 920 # flux excitation channel used to be a stimulus channel -FIFF.FIFFV_DIPOLE_WAVE = 1000 # Dipole time curve (xplotter/xfit) +FIFF.FIFFV_BIO_CH = 102 +FIFF.FIFFV_MEG_CH = 1 +FIFF.FIFFV_REF_MEG_CH = 301 +FIFF.FIFFV_EEG_CH = 2 +FIFF.FIFFV_MCG_CH = 201 +FIFF.FIFFV_STIM_CH = 3 +FIFF.FIFFV_EOG_CH = 202 +FIFF.FIFFV_EMG_CH = 302 +FIFF.FIFFV_ECG_CH = 402 +FIFF.FIFFV_MISC_CH = 502 +FIFF.FIFFV_RESP_CH = 602 # Respiration monitoring +FIFF.FIFFV_SEEG_CH = 802 # stereotactic EEG +FIFF.FIFFV_DBS_CH = 803 # deep brain stimulation +FIFF.FIFFV_SYST_CH = 900 # some system status information (on Triux systems only) +FIFF.FIFFV_ECOG_CH = 902 +FIFF.FIFFV_IAS_CH = 910 # Internal Active Shielding data (maybe on Triux only) +FIFF.FIFFV_EXCI_CH = 920 # flux excitation channel used to be a stimulus channel +FIFF.FIFFV_DIPOLE_WAVE = 1000 # Dipole time curve (xplotter/xfit) FIFF.FIFFV_GOODNESS_FIT = 1001 # Goodness of fit (xplotter/xfit) -FIFF.FIFFV_FNIRS_CH = 1100 # Functional near-infrared spectroscopy +FIFF.FIFFV_FNIRS_CH = 1100 # Functional near-infrared spectroscopy FIFF.FIFFV_TEMPERATURE_CH = 1200 # Functional near-infrared spectroscopy -FIFF.FIFFV_GALVANIC_CH = 1300 # Galvanic skin response -FIFF.FIFFV_EYETRACK_CH = 1400 # Eye-tracking - -_ch_kind_named = {key: key for key in ( - FIFF.FIFFV_BIO_CH, - FIFF.FIFFV_MEG_CH, - FIFF.FIFFV_REF_MEG_CH, - FIFF.FIFFV_EEG_CH, - FIFF.FIFFV_MCG_CH, - FIFF.FIFFV_STIM_CH, - FIFF.FIFFV_EOG_CH, - FIFF.FIFFV_EMG_CH, - FIFF.FIFFV_ECG_CH, - FIFF.FIFFV_MISC_CH, - FIFF.FIFFV_RESP_CH, - FIFF.FIFFV_SEEG_CH, - FIFF.FIFFV_DBS_CH, - FIFF.FIFFV_SYST_CH, - FIFF.FIFFV_ECOG_CH, - FIFF.FIFFV_IAS_CH, - FIFF.FIFFV_EXCI_CH, - FIFF.FIFFV_DIPOLE_WAVE, - FIFF.FIFFV_GOODNESS_FIT, - FIFF.FIFFV_FNIRS_CH, - FIFF.FIFFV_GALVANIC_CH, - FIFF.FIFFV_TEMPERATURE_CH, - FIFF.FIFFV_EYETRACK_CH -)} +FIFF.FIFFV_GALVANIC_CH = 1300 # Galvanic skin response +FIFF.FIFFV_EYETRACK_CH = 1400 # Eye-tracking + +_ch_kind_named = { + key: key + for key in ( + FIFF.FIFFV_BIO_CH, + FIFF.FIFFV_MEG_CH, + FIFF.FIFFV_REF_MEG_CH, + FIFF.FIFFV_EEG_CH, + FIFF.FIFFV_MCG_CH, + FIFF.FIFFV_STIM_CH, + FIFF.FIFFV_EOG_CH, + FIFF.FIFFV_EMG_CH, + FIFF.FIFFV_ECG_CH, + FIFF.FIFFV_MISC_CH, + FIFF.FIFFV_RESP_CH, + FIFF.FIFFV_SEEG_CH, + FIFF.FIFFV_DBS_CH, + FIFF.FIFFV_SYST_CH, + FIFF.FIFFV_ECOG_CH, + FIFF.FIFFV_IAS_CH, + FIFF.FIFFV_EXCI_CH, + FIFF.FIFFV_DIPOLE_WAVE, + FIFF.FIFFV_GOODNESS_FIT, + FIFF.FIFFV_FNIRS_CH, + FIFF.FIFFV_GALVANIC_CH, + FIFF.FIFFV_TEMPERATURE_CH, + FIFF.FIFFV_EYETRACK_CH, + ) +} # # Quaternion channels for head position monitoring # -FIFF.FIFFV_QUAT_0 = 700 # Quaternion param q0 obsolete for unit quaternion -FIFF.FIFFV_QUAT_1 = 701 # Quaternion param q1 rotation 
-FIFF.FIFFV_QUAT_2 = 702 # Quaternion param q2 rotation -FIFF.FIFFV_QUAT_3 = 703 # Quaternion param q3 rotation -FIFF.FIFFV_QUAT_4 = 704 # Quaternion param q4 translation -FIFF.FIFFV_QUAT_5 = 705 # Quaternion param q5 translation -FIFF.FIFFV_QUAT_6 = 706 # Quaternion param q6 translation -FIFF.FIFFV_HPI_G = 707 # Goodness-of-fit in continuous hpi -FIFF.FIFFV_HPI_ERR = 708 # Estimation error in continuous hpi -FIFF.FIFFV_HPI_MOV = 709 # Estimated head movement speed in continuous hpi +FIFF.FIFFV_QUAT_0 = 700 # Quaternion param q0 obsolete for unit quaternion +FIFF.FIFFV_QUAT_1 = 701 # Quaternion param q1 rotation +FIFF.FIFFV_QUAT_2 = 702 # Quaternion param q2 rotation +FIFF.FIFFV_QUAT_3 = 703 # Quaternion param q3 rotation +FIFF.FIFFV_QUAT_4 = 704 # Quaternion param q4 translation +FIFF.FIFFV_QUAT_5 = 705 # Quaternion param q5 translation +FIFF.FIFFV_QUAT_6 = 706 # Quaternion param q6 translation +FIFF.FIFFV_HPI_G = 707 # Goodness-of-fit in continuous hpi +FIFF.FIFFV_HPI_ERR = 708 # Estimation error in continuous hpi +FIFF.FIFFV_HPI_MOV = 709 # Estimated head movement speed in continuous hpi # # Coordinate frames # -FIFF.FIFFV_COORD_UNKNOWN = 0 -FIFF.FIFFV_COORD_DEVICE = 1 -FIFF.FIFFV_COORD_ISOTRAK = 2 -FIFF.FIFFV_COORD_HPI = 3 -FIFF.FIFFV_COORD_HEAD = 4 -FIFF.FIFFV_COORD_MRI = 5 -FIFF.FIFFV_COORD_MRI_SLICE = 6 -FIFF.FIFFV_COORD_MRI_DISPLAY = 7 -FIFF.FIFFV_COORD_DICOM_DEVICE = 8 +FIFF.FIFFV_COORD_UNKNOWN = 0 +FIFF.FIFFV_COORD_DEVICE = 1 +FIFF.FIFFV_COORD_ISOTRAK = 2 +FIFF.FIFFV_COORD_HPI = 3 +FIFF.FIFFV_COORD_HEAD = 4 +FIFF.FIFFV_COORD_MRI = 5 +FIFF.FIFFV_COORD_MRI_SLICE = 6 +FIFF.FIFFV_COORD_MRI_DISPLAY = 7 +FIFF.FIFFV_COORD_DICOM_DEVICE = 8 FIFF.FIFFV_COORD_IMAGING_DEVICE = 9 -_coord_frame_named = {key: key for key in ( - FIFF.FIFFV_COORD_UNKNOWN, - FIFF.FIFFV_COORD_DEVICE, - FIFF.FIFFV_COORD_ISOTRAK, - FIFF.FIFFV_COORD_HPI, - FIFF.FIFFV_COORD_HEAD, - FIFF.FIFFV_COORD_MRI, - FIFF.FIFFV_COORD_MRI_SLICE, - FIFF.FIFFV_COORD_MRI_DISPLAY, - FIFF.FIFFV_COORD_DICOM_DEVICE, - FIFF.FIFFV_COORD_IMAGING_DEVICE, -)} +_coord_frame_named = { + key: key + for key in ( + FIFF.FIFFV_COORD_UNKNOWN, + FIFF.FIFFV_COORD_DEVICE, + FIFF.FIFFV_COORD_ISOTRAK, + FIFF.FIFFV_COORD_HPI, + FIFF.FIFFV_COORD_HEAD, + FIFF.FIFFV_COORD_MRI, + FIFF.FIFFV_COORD_MRI_SLICE, + FIFF.FIFFV_COORD_MRI_DISPLAY, + FIFF.FIFFV_COORD_DICOM_DEVICE, + FIFF.FIFFV_COORD_IMAGING_DEVICE, + ) +} # # Needed for raw and evoked-response data # -FIFF.FIFF_DATA_BUFFER = 300 # Buffer containing measurement data -FIFF.FIFF_DATA_SKIP = 301 # Data skip in buffers -FIFF.FIFF_EPOCH = 302 # Buffer containing one epoch and channel -FIFF.FIFF_DATA_SKIP_SAMP = 303 # Data skip in samples +FIFF.FIFF_DATA_BUFFER = 300 # Buffer containing measurement data +FIFF.FIFF_DATA_SKIP = 301 # Data skip in buffers +FIFF.FIFF_EPOCH = 302 # Buffer containing one epoch and channel +FIFF.FIFF_DATA_SKIP_SAMP = 303 # Data skip in samples # # Info on subject # -FIFF.FIFF_SUBJ_ID = 400 # Subject ID -FIFF.FIFF_SUBJ_FIRST_NAME = 401 # First name of the subject -FIFF.FIFF_SUBJ_MIDDLE_NAME = 402 # Middle name of the subject -FIFF.FIFF_SUBJ_LAST_NAME = 403 # Last name of the subject -FIFF.FIFF_SUBJ_BIRTH_DAY = 404 # Birthday of the subject -FIFF.FIFF_SUBJ_SEX = 405 # Sex of the subject -FIFF.FIFF_SUBJ_HAND = 406 # Handedness of the subject -FIFF.FIFF_SUBJ_WEIGHT = 407 # Weight of the subject in kg -FIFF.FIFF_SUBJ_HEIGHT = 408 # Height of the subject in m -FIFF.FIFF_SUBJ_COMMENT = 409 # Comment about the subject -FIFF.FIFF_SUBJ_HIS_ID = 410 # ID used in the Hospital Information System 
-
-FIFF.FIFFV_SUBJ_HAND_RIGHT = 1  # Righthanded
-FIFF.FIFFV_SUBJ_HAND_LEFT = 2  # Lefthanded
-FIFF.FIFFV_SUBJ_HAND_AMBI = 3  # Ambidextrous
-
-FIFF.FIFFV_SUBJ_SEX_UNKNOWN = 0  # Unknown gender
-FIFF.FIFFV_SUBJ_SEX_MALE = 1  # Male
-FIFF.FIFFV_SUBJ_SEX_FEMALE = 2  # Female
-
-FIFF.FIFF_PROJ_ID = 500
-FIFF.FIFF_PROJ_NAME = 501
-FIFF.FIFF_PROJ_AIM = 502
-FIFF.FIFF_PROJ_PERSONS = 503
-FIFF.FIFF_PROJ_COMMENT = 504
-
-FIFF.FIFF_EVENT_CHANNELS = 600  # Event channel numbers
-FIFF.FIFF_EVENT_LIST = 601  # List of events (integers:
-FIFF.FIFF_EVENT_CHANNEL = 602  # Event channel
-FIFF.FIFF_EVENT_BITS = 603  # Event bits array
+FIFF.FIFF_SUBJ_ID = 400  # Subject ID
+FIFF.FIFF_SUBJ_FIRST_NAME = 401  # First name of the subject
+FIFF.FIFF_SUBJ_MIDDLE_NAME = 402  # Middle name of the subject
+FIFF.FIFF_SUBJ_LAST_NAME = 403  # Last name of the subject
+FIFF.FIFF_SUBJ_BIRTH_DAY = 404  # Birthday of the subject
+FIFF.FIFF_SUBJ_SEX = 405  # Sex of the subject
+FIFF.FIFF_SUBJ_HAND = 406  # Handedness of the subject
+FIFF.FIFF_SUBJ_WEIGHT = 407  # Weight of the subject in kg
+FIFF.FIFF_SUBJ_HEIGHT = 408  # Height of the subject in m
+FIFF.FIFF_SUBJ_COMMENT = 409  # Comment about the subject
+FIFF.FIFF_SUBJ_HIS_ID = 410  # ID used in the Hospital Information System
+
+FIFF.FIFFV_SUBJ_HAND_RIGHT = 1  # Righthanded
+FIFF.FIFFV_SUBJ_HAND_LEFT = 2  # Lefthanded
+FIFF.FIFFV_SUBJ_HAND_AMBI = 3  # Ambidextrous
+
+FIFF.FIFFV_SUBJ_SEX_UNKNOWN = 0  # Unknown gender
+FIFF.FIFFV_SUBJ_SEX_MALE = 1  # Male
+FIFF.FIFFV_SUBJ_SEX_FEMALE = 2  # Female
+
+FIFF.FIFF_PROJ_ID = 500
+FIFF.FIFF_PROJ_NAME = 501
+FIFF.FIFF_PROJ_AIM = 502
+FIFF.FIFF_PROJ_PERSONS = 503
+FIFF.FIFF_PROJ_COMMENT = 504
+
+FIFF.FIFF_EVENT_CHANNELS = 600  # Event channel numbers
+FIFF.FIFF_EVENT_LIST = 601  # List of events (integers:
+FIFF.FIFF_EVENT_CHANNEL = 602  # Event channel
+FIFF.FIFF_EVENT_BITS = 603  # Event bits array
 #
 # Tags used in saving SQUID characteristics etc.
 #
-FIFF.FIFF_SQUID_BIAS = 701
-FIFF.FIFF_SQUID_OFFSET = 702
-FIFF.FIFF_SQUID_GATE = 703
+FIFF.FIFF_SQUID_BIAS = 701
+FIFF.FIFF_SQUID_OFFSET = 702
+FIFF.FIFF_SQUID_GATE = 703
 #
 # Aspect values used to save characteristic curves of SQUIDs. (mjk)
 #
-FIFF.FIFFV_ASPECT_IFII_LOW = 1100
+FIFF.FIFFV_ASPECT_IFII_LOW = 1100
 FIFF.FIFFV_ASPECT_IFII_HIGH = 1101
-FIFF.FIFFV_ASPECT_GATE = 1102
+FIFF.FIFFV_ASPECT_GATE = 1102
 #
 # Values for file references
@@ -334,71 +352,74 @@
 #
 # References
 #
-FIFF.FIFF_REF_PATH = 1101
+FIFF.FIFF_REF_PATH = 1101
 #
 # Different aspects of data
 #
-FIFF.FIFFV_ASPECT_AVERAGE = 100  # Normal average of epochs
-FIFF.FIFFV_ASPECT_STD_ERR = 101  # Std. error of mean
-FIFF.FIFFV_ASPECT_SINGLE = 102  # Single epoch cut out from the continuous data
-FIFF.FIFFV_ASPECT_SUBAVERAGE = 103  # Partial average (subaverage)
-FIFF.FIFFV_ASPECT_ALTAVERAGE = 104  # Alternating subaverage
-FIFF.FIFFV_ASPECT_SAMPLE = 105  # A sample cut out by graph
+FIFF.FIFFV_ASPECT_AVERAGE = 100  # Normal average of epochs
+FIFF.FIFFV_ASPECT_STD_ERR = 101  # Std. error of mean
+FIFF.FIFFV_ASPECT_SINGLE = 102  # Single epoch cut out from the continuous data
+FIFF.FIFFV_ASPECT_SUBAVERAGE = 103  # Partial average (subaverage)
+FIFF.FIFFV_ASPECT_ALTAVERAGE = 104  # Alternating subaverage
+FIFF.FIFFV_ASPECT_SAMPLE = 105  # A sample cut out by graph
 FIFF.FIFFV_ASPECT_POWER_DENSITY = 106  # Power density spectrum
-FIFF.FIFFV_ASPECT_DIPOLE_WAVE = 200  # Dipole amplitude curve
+FIFF.FIFFV_ASPECT_DIPOLE_WAVE = 200  # Dipole amplitude curve
 #
 # BEM surface IDs
 #
-FIFF.FIFFV_BEM_SURF_ID_UNKNOWN = -1
-FIFF.FIFFV_BEM_SURF_ID_NOT_KNOWN = 0
-FIFF.FIFFV_BEM_SURF_ID_BRAIN = 1
-FIFF.FIFFV_BEM_SURF_ID_CSF = 2
-FIFF.FIFFV_BEM_SURF_ID_SKULL = 3
-FIFF.FIFFV_BEM_SURF_ID_HEAD = 4
+FIFF.FIFFV_BEM_SURF_ID_UNKNOWN = -1
+FIFF.FIFFV_BEM_SURF_ID_NOT_KNOWN = 0
+FIFF.FIFFV_BEM_SURF_ID_BRAIN = 1
+FIFF.FIFFV_BEM_SURF_ID_CSF = 2
+FIFF.FIFFV_BEM_SURF_ID_SKULL = 3
+FIFF.FIFFV_BEM_SURF_ID_HEAD = 4
 
-FIFF.FIFF_SPHERE_ORIGIN = 3001
-FIFF.FIFF_SPHERE_RADIUS = 3002
+FIFF.FIFF_SPHERE_ORIGIN = 3001
+FIFF.FIFF_SPHERE_RADIUS = 3002
 
-FIFF.FIFF_BEM_SURF_ID = 3101  # int surface number
-FIFF.FIFF_BEM_SURF_NAME = 3102  # string surface name
-FIFF.FIFF_BEM_SURF_NNODE = 3103  # int number of nodes on a surface
-FIFF.FIFF_BEM_SURF_NTRI = 3104  # int number of triangles on a surface
-FIFF.FIFF_BEM_SURF_NODES = 3105  # float surface nodes (nnode,3)
-FIFF.FIFF_BEM_SURF_TRIANGLES = 3106  # int surface triangles (ntri,3)
-FIFF.FIFF_BEM_SURF_NORMALS = 3107  # float surface node normal unit vectors
+FIFF.FIFF_BEM_SURF_ID = 3101  # int surface number
+FIFF.FIFF_BEM_SURF_NAME = 3102  # string surface name
+FIFF.FIFF_BEM_SURF_NNODE = 3103  # int number of nodes on a surface
+FIFF.FIFF_BEM_SURF_NTRI = 3104  # int number of triangles on a surface
+FIFF.FIFF_BEM_SURF_NODES = 3105  # float surface nodes (nnode,3)
+FIFF.FIFF_BEM_SURF_TRIANGLES = 3106  # int surface triangles (ntri,3)
+FIFF.FIFF_BEM_SURF_NORMALS = 3107  # float surface node normal unit vectors
 
-FIFF.FIFF_BEM_POT_SOLUTION = 3110  # float ** The solution matrix
-FIFF.FIFF_BEM_APPROX = 3111  # int approximation method, see below
-FIFF.FIFF_BEM_COORD_FRAME = 3112  # The coordinate frame of the model
-FIFF.FIFF_BEM_SIGMA = 3113  # Conductivity of a compartment
-FIFF.FIFFV_BEM_APPROX_CONST = 1  # The constant potential approach
-FIFF.FIFFV_BEM_APPROX_LINEAR = 2  # The linear potential approach
+FIFF.FIFF_BEM_POT_SOLUTION = 3110  # float ** The solution matrix
+FIFF.FIFF_BEM_APPROX = 3111  # int approximation method, see below
+FIFF.FIFF_BEM_COORD_FRAME = 3112  # The coordinate frame of the model
+FIFF.FIFF_BEM_SIGMA = 3113  # Conductivity of a compartment
+FIFF.FIFFV_BEM_APPROX_CONST = 1  # The constant potential approach
+FIFF.FIFFV_BEM_APPROX_LINEAR = 2  # The linear potential approach
 #
 # More of those defined in MNE
 #
-FIFF.FIFFV_MNE_SURF_UNKNOWN = -1
-FIFF.FIFFV_MNE_SURF_LEFT_HEMI = 101
-FIFF.FIFFV_MNE_SURF_RIGHT_HEMI = 102
-FIFF.FIFFV_MNE_SURF_MEG_HELMET = 201  # Use this irrespective of the system
+FIFF.FIFFV_MNE_SURF_UNKNOWN = -1
+FIFF.FIFFV_MNE_SURF_LEFT_HEMI = 101
+FIFF.FIFFV_MNE_SURF_RIGHT_HEMI = 102
+FIFF.FIFFV_MNE_SURF_MEG_HELMET = 201  # Use this irrespective of the system
 #
 # These relate to the Isotrak data (enum(point))
 #
 FIFF.FIFFV_POINT_CARDINAL = 1
-FIFF.FIFFV_POINT_HPI = 2
-FIFF.FIFFV_POINT_EEG = 3
-FIFF.FIFFV_POINT_ECG = FIFF.FIFFV_POINT_EEG
-FIFF.FIFFV_POINT_EXTRA = 4
-FIFF.FIFFV_POINT_HEAD = 5  # Point on the surface of the head
-_dig_kind_named = {key: key for key in(
-    FIFF.FIFFV_POINT_CARDINAL,
-    FIFF.FIFFV_POINT_HPI,
-    FIFF.FIFFV_POINT_EEG,
-    FIFF.FIFFV_POINT_EXTRA,
-    FIFF.FIFFV_POINT_HEAD,
-)}
+FIFF.FIFFV_POINT_HPI = 2
+FIFF.FIFFV_POINT_EEG = 3
+FIFF.FIFFV_POINT_ECG = FIFF.FIFFV_POINT_EEG
+FIFF.FIFFV_POINT_EXTRA = 4
+FIFF.FIFFV_POINT_HEAD = 5  # Point on the surface of the head
+_dig_kind_named = {
+    key: key
+    for key in (
+        FIFF.FIFFV_POINT_CARDINAL,
+        FIFF.FIFFV_POINT_HPI,
+        FIFF.FIFFV_POINT_EEG,
+        FIFF.FIFFV_POINT_EXTRA,
+        FIFF.FIFFV_POINT_HEAD,
+    )
+}
 #
 # Cardinal point types (enum(cardinal_point))
 #
@@ -406,155 +427,178 @@
 FIFF.FIFFV_POINT_NASION = 2
 FIFF.FIFFV_POINT_RPA = 3
 FIFF.FIFFV_POINT_INION = 4
-_dig_cardinal_named = {key: key for key in (
-    FIFF.FIFFV_POINT_LPA,
-    FIFF.FIFFV_POINT_NASION,
-    FIFF.FIFFV_POINT_RPA,
-    FIFF.FIFFV_POINT_INION,
-)}
+_dig_cardinal_named = {
+    key: key
+    for key in (
+        FIFF.FIFFV_POINT_LPA,
+        FIFF.FIFFV_POINT_NASION,
+        FIFF.FIFFV_POINT_RPA,
+        FIFF.FIFFV_POINT_INION,
+    )
+}
 #
 # SSP
 #
-FIFF.FIFF_PROJ_ITEM_KIND = 3411
-FIFF.FIFF_PROJ_ITEM_TIME = 3412
-FIFF.FIFF_PROJ_ITEM_NVEC = 3414
-FIFF.FIFF_PROJ_ITEM_VECTORS = 3415
-FIFF.FIFF_PROJ_ITEM_DEFINITION = 3416
+FIFF.FIFF_PROJ_ITEM_KIND = 3411
+FIFF.FIFF_PROJ_ITEM_TIME = 3412
+FIFF.FIFF_PROJ_ITEM_NVEC = 3414
+FIFF.FIFF_PROJ_ITEM_VECTORS = 3415
+FIFF.FIFF_PROJ_ITEM_DEFINITION = 3416
 FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST = 3417
 # XPlotter
-FIFF.FIFF_XPLOTTER_LAYOUT = 3501  # string - "Xplotter layout tag"
+FIFF.FIFF_XPLOTTER_LAYOUT = 3501  # string - "Xplotter layout tag"
 #
 # MRIs
 #
-FIFF.FIFF_MRI_SOURCE_PATH = FIFF.FIFF_REF_PATH
-FIFF.FIFF_MRI_SOURCE_FORMAT = 2002
-FIFF.FIFF_MRI_PIXEL_ENCODING = 2003
+FIFF.FIFF_MRI_SOURCE_PATH = FIFF.FIFF_REF_PATH
+FIFF.FIFF_MRI_SOURCE_FORMAT = 2002
+FIFF.FIFF_MRI_PIXEL_ENCODING = 2003
 FIFF.FIFF_MRI_PIXEL_DATA_OFFSET = 2004
-FIFF.FIFF_MRI_PIXEL_SCALE = 2005
-FIFF.FIFF_MRI_PIXEL_DATA = 2006
+FIFF.FIFF_MRI_PIXEL_SCALE = 2005
+FIFF.FIFF_MRI_PIXEL_DATA = 2006
 FIFF.FIFF_MRI_PIXEL_OVERLAY_ENCODING = 2007
-FIFF.FIFF_MRI_PIXEL_OVERLAY_DATA = 2008
-FIFF.FIFF_MRI_BOUNDING_BOX = 2009
-FIFF.FIFF_MRI_WIDTH = 2010
-FIFF.FIFF_MRI_WIDTH_M = 2011
-FIFF.FIFF_MRI_HEIGHT = 2012
-FIFF.FIFF_MRI_HEIGHT_M = 2013
-FIFF.FIFF_MRI_DEPTH = 2014
-FIFF.FIFF_MRI_DEPTH_M = 2015
-FIFF.FIFF_MRI_THICKNESS = 2016
-FIFF.FIFF_MRI_SCENE_AIM = 2017
-FIFF.FIFF_MRI_ORIG_SOURCE_PATH = 2020
-FIFF.FIFF_MRI_ORIG_SOURCE_FORMAT = 2021
-FIFF.FIFF_MRI_ORIG_PIXEL_ENCODING = 2022
+FIFF.FIFF_MRI_PIXEL_OVERLAY_DATA = 2008
+FIFF.FIFF_MRI_BOUNDING_BOX = 2009
+FIFF.FIFF_MRI_WIDTH = 2010
+FIFF.FIFF_MRI_WIDTH_M = 2011
+FIFF.FIFF_MRI_HEIGHT = 2012
+FIFF.FIFF_MRI_HEIGHT_M = 2013
+FIFF.FIFF_MRI_DEPTH = 2014
+FIFF.FIFF_MRI_DEPTH_M = 2015
+FIFF.FIFF_MRI_THICKNESS = 2016
+FIFF.FIFF_MRI_SCENE_AIM = 2017
+FIFF.FIFF_MRI_ORIG_SOURCE_PATH = 2020
+FIFF.FIFF_MRI_ORIG_SOURCE_FORMAT = 2021
+FIFF.FIFF_MRI_ORIG_PIXEL_ENCODING = 2022
 FIFF.FIFF_MRI_ORIG_PIXEL_DATA_OFFSET = 2023
-FIFF.FIFF_MRI_VOXEL_DATA = 2030
-FIFF.FIFF_MRI_VOXEL_ENCODING = 2031
-FIFF.FIFF_MRI_MRILAB_SETUP = 2100
-FIFF.FIFF_MRI_SEG_REGION_ID = 2200
-#
-FIFF.FIFFV_MRI_PIXEL_UNKNOWN = 0
-FIFF.FIFFV_MRI_PIXEL_BYTE = 1
-FIFF.FIFFV_MRI_PIXEL_WORD = 2
-FIFF.FIFFV_MRI_PIXEL_SWAP_WORD = 3
-FIFF.FIFFV_MRI_PIXEL_FLOAT = 4
+FIFF.FIFF_MRI_VOXEL_DATA = 2030
+FIFF.FIFF_MRI_VOXEL_ENCODING = 2031
+FIFF.FIFF_MRI_MRILAB_SETUP = 2100
+FIFF.FIFF_MRI_SEG_REGION_ID = 2200
+#
+FIFF.FIFFV_MRI_PIXEL_UNKNOWN = 0
+FIFF.FIFFV_MRI_PIXEL_BYTE = 1
+FIFF.FIFFV_MRI_PIXEL_WORD = 2
+FIFF.FIFFV_MRI_PIXEL_SWAP_WORD = 3
+FIFF.FIFFV_MRI_PIXEL_FLOAT = 4
 FIFF.FIFFV_MRI_PIXEL_BYTE_INDEXED_COLOR = 5
-FIFF.FIFFV_MRI_PIXEL_BYTE_RGB_COLOR = 6
+FIFF.FIFFV_MRI_PIXEL_BYTE_RGB_COLOR = 6
 FIFF.FIFFV_MRI_PIXEL_BYTE_RLE_RGB_COLOR = 7
-FIFF.FIFFV_MRI_PIXEL_BIT_RLE = 8
+FIFF.FIFFV_MRI_PIXEL_BIT_RLE = 8
 #
 # These are the MNE fiff definitions (range 350-390 reserved for MNE)
 #
-FIFF.FIFFB_MNE = 350
-FIFF.FIFFB_MNE_SOURCE_SPACE = 351
-FIFF.FIFFB_MNE_FORWARD_SOLUTION = 352
-FIFF.FIFFB_MNE_PARENT_MRI_FILE = 353
-FIFF.FIFFB_MNE_PARENT_MEAS_FILE = 354
-FIFF.FIFFB_MNE_COV = 355
-FIFF.FIFFB_MNE_INVERSE_SOLUTION = 356
-FIFF.FIFFB_MNE_NAMED_MATRIX = 357
-FIFF.FIFFB_MNE_ENV = 358
-FIFF.FIFFB_MNE_BAD_CHANNELS = 359
-FIFF.FIFFB_MNE_VERTEX_MAP = 360
-FIFF.FIFFB_MNE_EVENTS = 361
-FIFF.FIFFB_MNE_MORPH_MAP = 362
-FIFF.FIFFB_MNE_SURFACE_MAP = 363
-FIFF.FIFFB_MNE_SURFACE_MAP_GROUP = 364
+FIFF.FIFFB_MNE = 350
+FIFF.FIFFB_MNE_SOURCE_SPACE = 351
+FIFF.FIFFB_MNE_FORWARD_SOLUTION = 352
+FIFF.FIFFB_MNE_PARENT_MRI_FILE = 353
+FIFF.FIFFB_MNE_PARENT_MEAS_FILE = 354
+FIFF.FIFFB_MNE_COV = 355
+FIFF.FIFFB_MNE_INVERSE_SOLUTION = 356
+FIFF.FIFFB_MNE_NAMED_MATRIX = 357
+FIFF.FIFFB_MNE_ENV = 358
+FIFF.FIFFB_MNE_BAD_CHANNELS = 359
+FIFF.FIFFB_MNE_VERTEX_MAP = 360
+FIFF.FIFFB_MNE_EVENTS = 361
+FIFF.FIFFB_MNE_MORPH_MAP = 362
+FIFF.FIFFB_MNE_SURFACE_MAP = 363
+FIFF.FIFFB_MNE_SURFACE_MAP_GROUP = 364
 #
 # CTF compensation data
 #
-FIFF.FIFFB_MNE_CTF_COMP = 370
-FIFF.FIFFB_MNE_CTF_COMP_DATA = 371
-FIFF.FIFFB_MNE_DERIVATIONS = 372
+FIFF.FIFFB_MNE_CTF_COMP = 370
+FIFF.FIFFB_MNE_CTF_COMP_DATA = 371
+FIFF.FIFFB_MNE_DERIVATIONS = 372
 
-FIFF.FIFFB_MNE_EPOCHS = 373
-FIFF.FIFFB_MNE_ICA = 374
+FIFF.FIFFB_MNE_EPOCHS = 373
+FIFF.FIFFB_MNE_ICA = 374
 #
 # Fiff tags associated with MNE computations (3500...)
 #
 #
 # 3500... Bookkeeping
 #
-FIFF.FIFF_MNE_ROW_NAMES = 3502
-FIFF.FIFF_MNE_COL_NAMES = 3503
-FIFF.FIFF_MNE_NROW = 3504
-FIFF.FIFF_MNE_NCOL = 3505
-FIFF.FIFF_MNE_COORD_FRAME = 3506  # Coordinate frame employed. Defaults:
-                                  # FIFFB_MNE_SOURCE_SPACE FIFFV_COORD_MRI
-                                  # FIFFB_MNE_FORWARD_SOLUTION FIFFV_COORD_HEAD
-                                  # FIFFB_MNE_INVERSE_SOLUTION FIFFV_COORD_HEAD
-FIFF.FIFF_MNE_CH_NAME_LIST = 3507
-FIFF.FIFF_MNE_FILE_NAME = 3508  # This removes the collision with fiff_file.h (used to be 3501)
+FIFF.FIFF_MNE_ROW_NAMES = 3502
+FIFF.FIFF_MNE_COL_NAMES = 3503
+FIFF.FIFF_MNE_NROW = 3504
+FIFF.FIFF_MNE_NCOL = 3505
+FIFF.FIFF_MNE_COORD_FRAME = 3506  # Coordinate frame employed. Defaults:
+# FIFFB_MNE_SOURCE_SPACE FIFFV_COORD_MRI
+# FIFFB_MNE_FORWARD_SOLUTION FIFFV_COORD_HEAD
+# FIFFB_MNE_INVERSE_SOLUTION FIFFV_COORD_HEAD
+FIFF.FIFF_MNE_CH_NAME_LIST = 3507
+FIFF.FIFF_MNE_FILE_NAME = (
+    3508  # This removes the collision with fiff_file.h (used to be 3501)
+)
 #
 # 3510... 3590... Source space or surface
 #
-FIFF.FIFF_MNE_SOURCE_SPACE_POINTS = 3510  # The vertices
-FIFF.FIFF_MNE_SOURCE_SPACE_NORMALS = 3511  # The vertex normals
-FIFF.FIFF_MNE_SOURCE_SPACE_NPOINTS = 3512  # How many vertices
-FIFF.FIFF_MNE_SOURCE_SPACE_SELECTION = 3513  # Which are selected to the source space
-FIFF.FIFF_MNE_SOURCE_SPACE_NUSE = 3514  # How many are in use
-FIFF.FIFF_MNE_SOURCE_SPACE_NEAREST = 3515  # Nearest source space vertex for all vertices
-FIFF.FIFF_MNE_SOURCE_SPACE_NEAREST_DIST = 3516  # Distance to the Nearest source space vertex for all vertices
-FIFF.FIFF_MNE_SOURCE_SPACE_ID = 3517  # Identifier
-FIFF.FIFF_MNE_SOURCE_SPACE_TYPE = 3518  # Surface or volume
-FIFF.FIFF_MNE_SOURCE_SPACE_VERTICES = 3519  # List of vertices (zero based)
-
-FIFF.FIFF_MNE_SOURCE_SPACE_VOXEL_DIMS = 3596  # Voxel space dimensions in a volume source space
-FIFF.FIFF_MNE_SOURCE_SPACE_INTERPOLATOR = 3597  # Matrix to interpolate a volume source space into a mri volume
-FIFF.FIFF_MNE_SOURCE_SPACE_MRI_FILE = 3598  # MRI file used in the interpolation
-
-FIFF.FIFF_MNE_SOURCE_SPACE_NTRI = 3590  # Number of triangles
-FIFF.FIFF_MNE_SOURCE_SPACE_TRIANGLES = 3591  # The triangulation
-FIFF.FIFF_MNE_SOURCE_SPACE_NUSE_TRI = 3592  # Number of triangles corresponding to the number of vertices in use
-FIFF.FIFF_MNE_SOURCE_SPACE_USE_TRIANGLES = 3593  # The triangulation of the used vertices in the source space
-FIFF.FIFF_MNE_SOURCE_SPACE_NNEIGHBORS = 3594  # Number of neighbors for each source space point (used for volume source spaces)
-FIFF.FIFF_MNE_SOURCE_SPACE_NEIGHBORS = 3595  # Neighbors for each source space point (used for volume source spaces)
-
-FIFF.FIFF_MNE_SOURCE_SPACE_DIST = 3599  # Distances between vertices in use (along the surface)
-FIFF.FIFF_MNE_SOURCE_SPACE_DIST_LIMIT = 3600  # If distance is above this limit (in the volume) it has not been calculated
-
-FIFF.FIFF_MNE_SURFACE_MAP_DATA = 3610  # Surface map data
-FIFF.FIFF_MNE_SURFACE_MAP_KIND = 3611  # Type of map
+FIFF.FIFF_MNE_SOURCE_SPACE_POINTS = 3510  # The vertices
+FIFF.FIFF_MNE_SOURCE_SPACE_NORMALS = 3511  # The vertex normals
+FIFF.FIFF_MNE_SOURCE_SPACE_NPOINTS = 3512  # How many vertices
+FIFF.FIFF_MNE_SOURCE_SPACE_SELECTION = 3513  # Which are selected to the source space
+FIFF.FIFF_MNE_SOURCE_SPACE_NUSE = 3514  # How many are in use
+FIFF.FIFF_MNE_SOURCE_SPACE_NEAREST = (
+    3515  # Nearest source space vertex for all vertices
+)
+FIFF.FIFF_MNE_SOURCE_SPACE_NEAREST_DIST = (
+    3516  # Distance to the Nearest source space vertex for all vertices
+)
+FIFF.FIFF_MNE_SOURCE_SPACE_ID = 3517  # Identifier
+FIFF.FIFF_MNE_SOURCE_SPACE_TYPE = 3518  # Surface or volume
+FIFF.FIFF_MNE_SOURCE_SPACE_VERTICES = 3519  # List of vertices (zero based)
+
+FIFF.FIFF_MNE_SOURCE_SPACE_VOXEL_DIMS = (
+    3596  # Voxel space dimensions in a volume source space
+)
+FIFF.FIFF_MNE_SOURCE_SPACE_INTERPOLATOR = (
+    3597  # Matrix to interpolate a volume source space into a mri volume
+)
+FIFF.FIFF_MNE_SOURCE_SPACE_MRI_FILE = 3598  # MRI file used in the interpolation
+
+FIFF.FIFF_MNE_SOURCE_SPACE_NTRI = 3590  # Number of triangles
+FIFF.FIFF_MNE_SOURCE_SPACE_TRIANGLES = 3591  # The triangulation
+FIFF.FIFF_MNE_SOURCE_SPACE_NUSE_TRI = (
+    3592  # Number of triangles corresponding to the number of vertices in use
+)
+FIFF.FIFF_MNE_SOURCE_SPACE_USE_TRIANGLES = (
+    3593  # The triangulation of the used vertices in the source space
+)
+FIFF.FIFF_MNE_SOURCE_SPACE_NNEIGHBORS = 3594  # Number of neighbors for each source space point (used for volume source spaces)
+FIFF.FIFF_MNE_SOURCE_SPACE_NEIGHBORS = (
+    3595  # Neighbors for each source space point (used for volume source spaces)
+)
+
+FIFF.FIFF_MNE_SOURCE_SPACE_DIST = (
+    3599  # Distances between vertices in use (along the surface)
+)
+FIFF.FIFF_MNE_SOURCE_SPACE_DIST_LIMIT = (
+    3600  # If distance is above this limit (in the volume) it has not been calculated
+)
+
+FIFF.FIFF_MNE_SURFACE_MAP_DATA = 3610  # Surface map data
+FIFF.FIFF_MNE_SURFACE_MAP_KIND = 3611  # Type of map
 #
 # 3520... Forward solution
 #
-FIFF.FIFF_MNE_FORWARD_SOLUTION = 3520
-FIFF.FIFF_MNE_SOURCE_ORIENTATION = 3521  # Fixed or free
-FIFF.FIFF_MNE_INCLUDED_METHODS = 3522
-FIFF.FIFF_MNE_FORWARD_SOLUTION_GRAD = 3523
+FIFF.FIFF_MNE_FORWARD_SOLUTION = 3520
+FIFF.FIFF_MNE_SOURCE_ORIENTATION = 3521  # Fixed or free
+FIFF.FIFF_MNE_INCLUDED_METHODS = 3522
+FIFF.FIFF_MNE_FORWARD_SOLUTION_GRAD = 3523
 #
 # 3530... Covariance matrix
 #
-FIFF.FIFF_MNE_COV_KIND = 3530  # What kind of a covariance matrix
-FIFF.FIFF_MNE_COV_DIM = 3531  # Matrix dimension
-FIFF.FIFF_MNE_COV = 3532  # Full matrix in packed representation (lower triangle)
-FIFF.FIFF_MNE_COV_DIAG = 3533  # Diagonal matrix
-FIFF.FIFF_MNE_COV_EIGENVALUES = 3534  # Eigenvalues and eigenvectors of the above
-FIFF.FIFF_MNE_COV_EIGENVECTORS = 3535
-FIFF.FIFF_MNE_COV_NFREE = 3536  # Number of degrees of freedom
-FIFF.FIFF_MNE_COV_METHOD = 3537  # The estimator used
-FIFF.FIFF_MNE_COV_SCORE = 3538  # Negative log-likelihood
+FIFF.FIFF_MNE_COV_KIND = 3530  # What kind of a covariance matrix
+FIFF.FIFF_MNE_COV_DIM = 3531  # Matrix dimension
+FIFF.FIFF_MNE_COV = 3532  # Full matrix in packed representation (lower triangle)
+FIFF.FIFF_MNE_COV_DIAG = 3533  # Diagonal matrix
+FIFF.FIFF_MNE_COV_EIGENVALUES = 3534  # Eigenvalues and eigenvectors of the above
+FIFF.FIFF_MNE_COV_EIGENVECTORS = 3535
+FIFF.FIFF_MNE_COV_NFREE = 3536  # Number of degrees of freedom
+FIFF.FIFF_MNE_COV_METHOD = 3537  # The estimator used
+FIFF.FIFF_MNE_COV_SCORE = 3538  # Negative log-likelihood
 #
 # 3540... Inverse operator
@@ -562,196 +606,218 @@
 # We store the inverse operator as the eigenleads, eigenfields,
 # and weights
 #
-FIFF.FIFF_MNE_INVERSE_LEADS = 3540  # The eigenleads
-FIFF.FIFF_MNE_INVERSE_LEADS_WEIGHTED = 3546  # The eigenleads (already weighted with R^0.5)
-FIFF.FIFF_MNE_INVERSE_FIELDS = 3541  # The eigenfields
-FIFF.FIFF_MNE_INVERSE_SING = 3542  # The singular values
-FIFF.FIFF_MNE_PRIORS_USED = 3543  # Which kind of priors have been used for the source covariance matrix
-FIFF.FIFF_MNE_INVERSE_FULL = 3544  # Inverse operator as one matrix
-                                   # This matrix includes the whitening operator as well
-                                   # The regularization is applied
-FIFF.FIFF_MNE_INVERSE_SOURCE_ORIENTATIONS = 3545  # Contains the orientation of one source per row
-                                   # The source orientations must be expressed in the coordinate system
-                                   # given by FIFF_MNE_COORD_FRAME
-FIFF.FIFF_MNE_INVERSE_SOURCE_UNIT = 3547  # Are the sources given in Am or Am/m^2 ?
+FIFF.FIFF_MNE_INVERSE_LEADS = 3540  # The eigenleads
+FIFF.FIFF_MNE_INVERSE_LEADS_WEIGHTED = (
+    3546  # The eigenleads (already weighted with R^0.5)
+)
+FIFF.FIFF_MNE_INVERSE_FIELDS = 3541  # The eigenfields
+FIFF.FIFF_MNE_INVERSE_SING = 3542  # The singular values
+FIFF.FIFF_MNE_PRIORS_USED = (
+    3543  # Which kind of priors have been used for the source covariance matrix
+)
+FIFF.FIFF_MNE_INVERSE_FULL = 3544  # Inverse operator as one matrix
+# This matrix includes the whitening operator as well
+# The regularization is applied
+FIFF.FIFF_MNE_INVERSE_SOURCE_ORIENTATIONS = (
+    3545  # Contains the orientation of one source per row
+)
+# The source orientations must be expressed in the coordinate system
+# given by FIFF_MNE_COORD_FRAME
+FIFF.FIFF_MNE_INVERSE_SOURCE_UNIT = 3547  # Are the sources given in Am or Am/m^2 ?
 #
 # 3550... Saved environment info
 #
-FIFF.FIFF_MNE_ENV_WORKING_DIR = 3550  # Working directory where the file was created
-FIFF.FIFF_MNE_ENV_COMMAND_LINE = 3551  # The command used to create the file
-FIFF.FIFF_MNE_EXTERNAL_BIG_ENDIAN = 3552  # Reference to an external binary file (big-endian) */
-FIFF.FIFF_MNE_EXTERNAL_LITTLE_ENDIAN = 3553  # Reference to an external binary file (little-endian) */
+FIFF.FIFF_MNE_ENV_WORKING_DIR = 3550  # Working directory where the file was created
+FIFF.FIFF_MNE_ENV_COMMAND_LINE = 3551  # The command used to create the file
+FIFF.FIFF_MNE_EXTERNAL_BIG_ENDIAN = (
+    3552  # Reference to an external binary file (big-endian) */
+)
+FIFF.FIFF_MNE_EXTERNAL_LITTLE_ENDIAN = (
+    3553  # Reference to an external binary file (little-endian) */
+)
 #
 # 3560... Miscellaneous
 #
-FIFF.FIFF_MNE_PROJ_ITEM_ACTIVE = 3560  # Is this projection item active?
-FIFF.FIFF_MNE_EVENT_LIST = 3561  # An event list (for STI101 / STI 014)
-FIFF.FIFF_MNE_HEMI = 3562  # Hemisphere association for general purposes
-FIFF.FIFF_MNE_DATA_SKIP_NOP = 3563  # A data skip turned off in the raw data
-FIFF.FIFF_MNE_ORIG_CH_INFO = 3564  # Channel information before any changes
-FIFF.FIFF_MNE_EVENT_TRIGGER_MASK = 3565  # Mask applied to the trigger channel values
-FIFF.FIFF_MNE_EVENT_COMMENTS = 3566  # Event comments merged into one long string
-FIFF.FIFF_MNE_CUSTOM_REF = 3567  # Whether a custom reference was applied to the data
-FIFF.FIFF_MNE_BASELINE_MIN = 3568  # Time of baseline beginning
-FIFF.FIFF_MNE_BASELINE_MAX = 3569  # Time of baseline end
+FIFF.FIFF_MNE_PROJ_ITEM_ACTIVE = 3560  # Is this projection item active?
+FIFF.FIFF_MNE_EVENT_LIST = 3561  # An event list (for STI101 / STI 014)
+FIFF.FIFF_MNE_HEMI = 3562  # Hemisphere association for general purposes
+FIFF.FIFF_MNE_DATA_SKIP_NOP = 3563  # A data skip turned off in the raw data
+FIFF.FIFF_MNE_ORIG_CH_INFO = 3564  # Channel information before any changes
+FIFF.FIFF_MNE_EVENT_TRIGGER_MASK = 3565  # Mask applied to the trigger channel values
+FIFF.FIFF_MNE_EVENT_COMMENTS = 3566  # Event comments merged into one long string
+FIFF.FIFF_MNE_CUSTOM_REF = 3567  # Whether a custom reference was applied to the data
+FIFF.FIFF_MNE_BASELINE_MIN = 3568  # Time of baseline beginning
+FIFF.FIFF_MNE_BASELINE_MAX = 3569  # Time of baseline end
 #
 # 3570... Morphing maps
 #
-FIFF.FIFF_MNE_MORPH_MAP = 3570  # Mapping of closest vertices on the sphere
-FIFF.FIFF_MNE_MORPH_MAP_FROM = 3571  # Which subject is this map from
-FIFF.FIFF_MNE_MORPH_MAP_TO = 3572  # Which subject is this map to
+FIFF.FIFF_MNE_MORPH_MAP = 3570  # Mapping of closest vertices on the sphere
+FIFF.FIFF_MNE_MORPH_MAP_FROM = 3571  # Which subject is this map from
+FIFF.FIFF_MNE_MORPH_MAP_TO = 3572  # Which subject is this map to
 #
 # 3580... CTF compensation data
 #
-FIFF.FIFF_MNE_CTF_COMP_KIND = 3580  # What kind of compensation
-FIFF.FIFF_MNE_CTF_COMP_DATA = 3581  # The compensation data itself
-FIFF.FIFF_MNE_CTF_COMP_CALIBRATED = 3582  # Are the coefficients calibrated?
+FIFF.FIFF_MNE_CTF_COMP_KIND = 3580  # What kind of compensation
+FIFF.FIFF_MNE_CTF_COMP_DATA = 3581  # The compensation data itself
+FIFF.FIFF_MNE_CTF_COMP_CALIBRATED = 3582  # Are the coefficients calibrated?
 
-FIFF.FIFF_MNE_DERIVATION_DATA = 3585  # Used to store information about EEG and other derivations
+FIFF.FIFF_MNE_DERIVATION_DATA = (
+    3585  # Used to store information about EEG and other derivations
+)
 #
 # 3601... values associated with ICA decomposition
 #
-FIFF.FIFF_MNE_ICA_INTERFACE_PARAMS = 3601  # ICA interface parameters
-FIFF.FIFF_MNE_ICA_CHANNEL_NAMES = 3602  # ICA channel names
-FIFF.FIFF_MNE_ICA_WHITENER = 3603  # ICA whitener
-FIFF.FIFF_MNE_ICA_PCA_COMPONENTS = 3604  # PCA components
-FIFF.FIFF_MNE_ICA_PCA_EXPLAINED_VAR = 3605  # PCA explained variance
-FIFF.FIFF_MNE_ICA_PCA_MEAN = 3606  # PCA mean
-FIFF.FIFF_MNE_ICA_MATRIX = 3607  # ICA unmixing matrix
-FIFF.FIFF_MNE_ICA_BADS = 3608  # ICA bad sources
-FIFF.FIFF_MNE_ICA_MISC_PARAMS = 3609  # ICA misc params
+FIFF.FIFF_MNE_ICA_INTERFACE_PARAMS = 3601  # ICA interface parameters
+FIFF.FIFF_MNE_ICA_CHANNEL_NAMES = 3602  # ICA channel names
+FIFF.FIFF_MNE_ICA_WHITENER = 3603  # ICA whitener
+FIFF.FIFF_MNE_ICA_PCA_COMPONENTS = 3604  # PCA components
+FIFF.FIFF_MNE_ICA_PCA_EXPLAINED_VAR = 3605  # PCA explained variance
+FIFF.FIFF_MNE_ICA_PCA_MEAN = 3606  # PCA mean
+FIFF.FIFF_MNE_ICA_MATRIX = 3607  # ICA unmixing matrix
+FIFF.FIFF_MNE_ICA_BADS = 3608  # ICA bad sources
+FIFF.FIFF_MNE_ICA_MISC_PARAMS = 3609  # ICA misc params
 #
 # Miscellaneous
 #
-FIFF.FIFF_MNE_KIT_SYSTEM_ID = 3612  # Unique ID assigned to KIT systems
+FIFF.FIFF_MNE_KIT_SYSTEM_ID = 3612  # Unique ID assigned to KIT systems
 #
 # Maxfilter tags
 #
-FIFF.FIFF_SSS_FRAME = 263
-FIFF.FIFF_SSS_JOB = 264
-FIFF.FIFF_SSS_ORIGIN = 265
-FIFF.FIFF_SSS_ORD_IN = 266
-FIFF.FIFF_SSS_ORD_OUT = 267
-FIFF.FIFF_SSS_NMAG = 268
-FIFF.FIFF_SSS_COMPONENTS = 269
-FIFF.FIFF_SSS_CAL_CHANS = 270
-FIFF.FIFF_SSS_CAL_CORRS = 271
-FIFF.FIFF_SSS_ST_CORR = 272
-FIFF.FIFF_SSS_NFREE = 278
-FIFF.FIFF_SSS_ST_LENGTH = 279
-FIFF.FIFF_DECOUPLER_MATRIX = 800
+FIFF.FIFF_SSS_FRAME = 263
+FIFF.FIFF_SSS_JOB = 264
+FIFF.FIFF_SSS_ORIGIN = 265
+FIFF.FIFF_SSS_ORD_IN = 266
+FIFF.FIFF_SSS_ORD_OUT = 267
+FIFF.FIFF_SSS_NMAG = 268
+FIFF.FIFF_SSS_COMPONENTS = 269
+FIFF.FIFF_SSS_CAL_CHANS = 270
+FIFF.FIFF_SSS_CAL_CORRS = 271
+FIFF.FIFF_SSS_ST_CORR = 272
+FIFF.FIFF_SSS_NFREE = 278
+FIFF.FIFF_SSS_ST_LENGTH = 279
+FIFF.FIFF_DECOUPLER_MATRIX = 800
 #
 # Fiff values associated with MNE computations
 #
-FIFF.FIFFV_MNE_UNKNOWN_ORI = 0
-FIFF.FIFFV_MNE_FIXED_ORI = 1
-FIFF.FIFFV_MNE_FREE_ORI = 2
+FIFF.FIFFV_MNE_UNKNOWN_ORI = 0
+FIFF.FIFFV_MNE_FIXED_ORI = 1
+FIFF.FIFFV_MNE_FREE_ORI = 2
 
-FIFF.FIFFV_MNE_MEG = 1
-FIFF.FIFFV_MNE_EEG = 2
-FIFF.FIFFV_MNE_MEG_EEG = 3
+FIFF.FIFFV_MNE_MEG = 1
+FIFF.FIFFV_MNE_EEG = 2
+FIFF.FIFFV_MNE_MEG_EEG = 3
 
-FIFF.FIFFV_MNE_PRIORS_NONE = 0
-FIFF.FIFFV_MNE_PRIORS_DEPTH = 1
-FIFF.FIFFV_MNE_PRIORS_LORETA = 2
-FIFF.FIFFV_MNE_PRIORS_SULCI = 3
+FIFF.FIFFV_MNE_PRIORS_NONE = 0
+FIFF.FIFFV_MNE_PRIORS_DEPTH = 1
+FIFF.FIFFV_MNE_PRIORS_LORETA = 2
+FIFF.FIFFV_MNE_PRIORS_SULCI = 3
 
-FIFF.FIFFV_MNE_UNKNOWN_COV = 0
-FIFF.FIFFV_MNE_SENSOR_COV = 1
-FIFF.FIFFV_MNE_NOISE_COV = 1  # This is what it should have been called
-FIFF.FIFFV_MNE_SOURCE_COV = 2
-FIFF.FIFFV_MNE_FMRI_PRIOR_COV = 3
-FIFF.FIFFV_MNE_SIGNAL_COV = 4  # This will be potentially employed in beamformers
-FIFF.FIFFV_MNE_DEPTH_PRIOR_COV = 5  # The depth weighting prior
-FIFF.FIFFV_MNE_ORIENT_PRIOR_COV = 6  # The orientation prior
+FIFF.FIFFV_MNE_UNKNOWN_COV = 0
+FIFF.FIFFV_MNE_SENSOR_COV = 1
+FIFF.FIFFV_MNE_NOISE_COV = 1  # This is what it should have been called
+FIFF.FIFFV_MNE_SOURCE_COV = 2
+FIFF.FIFFV_MNE_FMRI_PRIOR_COV = 3
+FIFF.FIFFV_MNE_SIGNAL_COV = 4  # This will be potentially employed in beamformers
+FIFF.FIFFV_MNE_DEPTH_PRIOR_COV = 5  # The depth weighting prior
+FIFF.FIFFV_MNE_ORIENT_PRIOR_COV = 6  # The orientation prior
 #
 # Output map types
 #
-FIFF.FIFFV_MNE_MAP_UNKNOWN = -1  # Unspecified
-FIFF.FIFFV_MNE_MAP_SCALAR_CURRENT = 1  # Scalar current value
-FIFF.FIFFV_MNE_MAP_SCALAR_CURRENT_SIZE = 2  # Absolute value of the above
-FIFF.FIFFV_MNE_MAP_VECTOR_CURRENT = 3  # Current vector components
-FIFF.FIFFV_MNE_MAP_VECTOR_CURRENT_SIZE = 4  # Vector current size
-FIFF.FIFFV_MNE_MAP_T_STAT = 5  # Student's t statistic
-FIFF.FIFFV_MNE_MAP_F_STAT = 6  # F statistic
-FIFF.FIFFV_MNE_MAP_F_STAT_SQRT = 7  # Square root of the F statistic
-FIFF.FIFFV_MNE_MAP_CHI2_STAT = 8  # (Approximate) chi^2 statistic
-FIFF.FIFFV_MNE_MAP_CHI2_STAT_SQRT = 9  # Square root of the (approximate) chi^2 statistic
-FIFF.FIFFV_MNE_MAP_SCALAR_CURRENT_NOISE = 10  # Current noise approximation (scalar)
-FIFF.FIFFV_MNE_MAP_VECTOR_CURRENT_NOISE = 11  # Current noise approximation (vector)
+FIFF.FIFFV_MNE_MAP_UNKNOWN = -1  # Unspecified
+FIFF.FIFFV_MNE_MAP_SCALAR_CURRENT = 1  # Scalar current value
+FIFF.FIFFV_MNE_MAP_SCALAR_CURRENT_SIZE = 2  # Absolute value of the above
+FIFF.FIFFV_MNE_MAP_VECTOR_CURRENT = 3  # Current vector components
+FIFF.FIFFV_MNE_MAP_VECTOR_CURRENT_SIZE = 4  # Vector current size
+FIFF.FIFFV_MNE_MAP_T_STAT = 5  # Student's t statistic
+FIFF.FIFFV_MNE_MAP_F_STAT = 6  # F statistic
+FIFF.FIFFV_MNE_MAP_F_STAT_SQRT = 7  # Square root of the F statistic
+FIFF.FIFFV_MNE_MAP_CHI2_STAT = 8  # (Approximate) chi^2 statistic
+FIFF.FIFFV_MNE_MAP_CHI2_STAT_SQRT = (
+    9  # Square root of the (approximate) chi^2 statistic
+)
+FIFF.FIFFV_MNE_MAP_SCALAR_CURRENT_NOISE = 10  # Current noise approximation (scalar)
+FIFF.FIFFV_MNE_MAP_VECTOR_CURRENT_NOISE = 11  # Current noise approximation (vector)
 #
 # Source space types (values of FIFF_MNE_SOURCE_SPACE_TYPE)
 #
-FIFF.FIFFV_MNE_SPACE_UNKNOWN = -1
-FIFF.FIFFV_MNE_SPACE_SURFACE = 1
-FIFF.FIFFV_MNE_SPACE_VOLUME = 2
+FIFF.FIFFV_MNE_SPACE_UNKNOWN = -1
+FIFF.FIFFV_MNE_SPACE_SURFACE = 1
+FIFF.FIFFV_MNE_SPACE_VOLUME = 2
 FIFF.FIFFV_MNE_SPACE_DISCRETE = 3
 #
 # Covariance matrix channel classification
 #
-FIFF.FIFFV_MNE_COV_CH_UNKNOWN = -1  # No idea
-FIFF.FIFFV_MNE_COV_CH_MEG_MAG = 0  # Axial gradiometer or magnetometer [T]
-FIFF.FIFFV_MNE_COV_CH_MEG_GRAD = 1  # Planar gradiometer [T/m]
-FIFF.FIFFV_MNE_COV_CH_EEG = 2  # EEG [V]
+FIFF.FIFFV_MNE_COV_CH_UNKNOWN = -1  # No idea
+FIFF.FIFFV_MNE_COV_CH_MEG_MAG = 0  # Axial gradiometer or magnetometer [T]
+FIFF.FIFFV_MNE_COV_CH_MEG_GRAD = 1  # Planar gradiometer [T/m]
+FIFF.FIFFV_MNE_COV_CH_EEG = 2  # EEG [V]
 #
 # Projection item kinds
 #
-FIFF.FIFFV_PROJ_ITEM_NONE = 0
-FIFF.FIFFV_PROJ_ITEM_FIELD = 1
-FIFF.FIFFV_PROJ_ITEM_DIP_FIX = 2
-FIFF.FIFFV_PROJ_ITEM_DIP_ROT = 3
-FIFF.FIFFV_PROJ_ITEM_HOMOG_GRAD = 4
-FIFF.FIFFV_PROJ_ITEM_HOMOG_FIELD = 5
-FIFF.FIFFV_PROJ_ITEM_EEG_AVREF = 10  # Linear projection related to EEG average reference
-FIFF.FIFFV_MNE_PROJ_ITEM_EEG_AVREF = FIFF.FIFFV_PROJ_ITEM_EEG_AVREF  # backward compat alias
+FIFF.FIFFV_PROJ_ITEM_NONE = 0
+FIFF.FIFFV_PROJ_ITEM_FIELD = 1
+FIFF.FIFFV_PROJ_ITEM_DIP_FIX = 2
+FIFF.FIFFV_PROJ_ITEM_DIP_ROT = 3
+FIFF.FIFFV_PROJ_ITEM_HOMOG_GRAD = 4
+FIFF.FIFFV_PROJ_ITEM_HOMOG_FIELD = 5
+FIFF.FIFFV_PROJ_ITEM_EEG_AVREF = (
+    10  # Linear projection related to EEG average reference
+)
+FIFF.FIFFV_MNE_PROJ_ITEM_EEG_AVREF = (
+    FIFF.FIFFV_PROJ_ITEM_EEG_AVREF
+)  # backward compat alias
 #
 # Custom EEG references
 #
-FIFF.FIFFV_MNE_CUSTOM_REF_OFF = 0
-FIFF.FIFFV_MNE_CUSTOM_REF_ON = 1
-FIFF.FIFFV_MNE_CUSTOM_REF_CSD = 2
+FIFF.FIFFV_MNE_CUSTOM_REF_OFF = 0
+FIFF.FIFFV_MNE_CUSTOM_REF_ON = 1
+FIFF.FIFFV_MNE_CUSTOM_REF_CSD = 2
 #
 # SSS job options
 #
-FIFF.FIFFV_SSS_JOB_NOTHING = 0  # No SSS, just copy input to output
-FIFF.FIFFV_SSS_JOB_CTC = 1  # No SSS, only cross-talk correction
-FIFF.FIFFV_SSS_JOB_FILTER = 2  # Spatial maxwell filtering
-FIFF.FIFFV_SSS_JOB_VIRT = 3  # Transform data to another sensor array
-FIFF.FIFFV_SSS_JOB_HEAD_POS = 4  # Estimate head positions, no SSS
-FIFF.FIFFV_SSS_JOB_MOVEC_FIT = 5  # Estimate and compensate head movement
-FIFF.FIFFV_SSS_JOB_MOVEC_QUA = 6  # Compensate head movement from previously estimated head positions
-FIFF.FIFFV_SSS_JOB_REC_ALL = 7  # Reconstruct inside and outside signals
-FIFF.FIFFV_SSS_JOB_REC_IN = 8  # Reconstruct inside signals
-FIFF.FIFFV_SSS_JOB_REC_OUT = 9  # Reconstruct outside signals
-FIFF.FIFFV_SSS_JOB_ST = 10  # Spatio-temporal maxwell filtering
-FIFF.FIFFV_SSS_JOB_TPROJ = 11  # Temporal projection, no SSS
-FIFF.FIFFV_SSS_JOB_XSSS = 12  # Cross-validation SSS
-FIFF.FIFFV_SSS_JOB_XSUB = 13  # Cross-validation subtraction, no SSS
-FIFF.FIFFV_SSS_JOB_XWAV = 14  # Cross-validation noise waveforms
-FIFF.FIFFV_SSS_JOB_NCOV = 15  # Noise covariance estimation
-FIFF.FIFFV_SSS_JOB_SCOV = 16  # SSS sample covariance estimation
-#}
+FIFF.FIFFV_SSS_JOB_NOTHING = 0  # No SSS, just copy input to output
+FIFF.FIFFV_SSS_JOB_CTC = 1  # No SSS, only cross-talk correction
+FIFF.FIFFV_SSS_JOB_FILTER = 2  # Spatial maxwell filtering
+FIFF.FIFFV_SSS_JOB_VIRT = 3  # Transform data to another sensor array
+FIFF.FIFFV_SSS_JOB_HEAD_POS = 4  # Estimate head positions, no SSS
+FIFF.FIFFV_SSS_JOB_MOVEC_FIT = 5  # Estimate and compensate head movement
+FIFF.FIFFV_SSS_JOB_MOVEC_QUA = (
+    6  # Compensate head movement from previously estimated head positions
+)
+FIFF.FIFFV_SSS_JOB_REC_ALL = 7  # Reconstruct inside and outside signals
+FIFF.FIFFV_SSS_JOB_REC_IN = 8  # Reconstruct inside signals
+FIFF.FIFFV_SSS_JOB_REC_OUT = 9  # Reconstruct outside signals
+FIFF.FIFFV_SSS_JOB_ST = 10  # Spatio-temporal maxwell filtering
+FIFF.FIFFV_SSS_JOB_TPROJ = 11  # Temporal projection, no SSS
+FIFF.FIFFV_SSS_JOB_XSSS = 12  # Cross-validation SSS
+FIFF.FIFFV_SSS_JOB_XSUB = 13  # Cross-validation subtraction, no SSS
+FIFF.FIFFV_SSS_JOB_XWAV = 14  # Cross-validation noise waveforms
+FIFF.FIFFV_SSS_JOB_NCOV = 15  # Noise covariance estimation
+FIFF.FIFFV_SSS_JOB_SCOV = 16  # SSS sample covariance estimation
+# }
 #
 # Additional coordinate frames
 #
-FIFF.FIFFV_MNE_COORD_TUFTS_EEG = 300  # For Tufts EEG data
-FIFF.FIFFV_MNE_COORD_CTF_DEVICE = 1001  # CTF device coordinates
-FIFF.FIFFV_MNE_COORD_CTF_HEAD = 1004  # CTF head coordinates
-FIFF.FIFFV_MNE_COORD_DIGITIZER = FIFF.FIFFV_COORD_ISOTRAK  # Original (Polhemus) digitizer coordinates
-FIFF.FIFFV_MNE_COORD_SURFACE_RAS = FIFF.FIFFV_COORD_MRI  # The surface RAS coordinates
-FIFF.FIFFV_MNE_COORD_MRI_VOXEL = 2001  # The MRI voxel coordinates
-FIFF.FIFFV_MNE_COORD_RAS = 2002  # Surface RAS coordinates with non-zero origin
-FIFF.FIFFV_MNE_COORD_MNI_TAL = 2003  # MNI Talairach coordinates
-FIFF.FIFFV_MNE_COORD_FS_TAL_GTZ = 2004  # FreeSurfer Talairach coordinates (MNI z > 0)
-FIFF.FIFFV_MNE_COORD_FS_TAL_LTZ = 2005  # FreeSurfer Talairach coordinates (MNI z < 0)
-FIFF.FIFFV_MNE_COORD_FS_TAL = 2006  # FreeSurfer Talairach coordinates
+FIFF.FIFFV_MNE_COORD_TUFTS_EEG = 300  # For Tufts EEG data
+FIFF.FIFFV_MNE_COORD_CTF_DEVICE = 1001  # CTF device coordinates
+FIFF.FIFFV_MNE_COORD_CTF_HEAD = 1004  # CTF head coordinates
+FIFF.FIFFV_MNE_COORD_DIGITIZER = (
+    FIFF.FIFFV_COORD_ISOTRAK
+)  # Original (Polhemus) digitizer coordinates
+FIFF.FIFFV_MNE_COORD_SURFACE_RAS = FIFF.FIFFV_COORD_MRI  # The surface RAS coordinates
+FIFF.FIFFV_MNE_COORD_MRI_VOXEL = 2001  # The MRI voxel coordinates
+FIFF.FIFFV_MNE_COORD_RAS = 2002  # Surface RAS coordinates with non-zero origin
+FIFF.FIFFV_MNE_COORD_MNI_TAL = 2003  # MNI Talairach coordinates
+FIFF.FIFFV_MNE_COORD_FS_TAL_GTZ = 2004  # FreeSurfer Talairach coordinates (MNI z > 0)
+FIFF.FIFFV_MNE_COORD_FS_TAL_LTZ = 2005  # FreeSurfer Talairach coordinates (MNI z < 0)
+FIFF.FIFFV_MNE_COORD_FS_TAL = 2006  # FreeSurfer Talairach coordinates
 #
 # 4D and KIT use the same head coordinate system definition as CTF
 #
-FIFF.FIFFV_MNE_COORD_4D_HEAD = FIFF.FIFFV_MNE_COORD_CTF_HEAD
-FIFF.FIFFV_MNE_COORD_KIT_HEAD = FIFF.FIFFV_MNE_COORD_CTF_HEAD
+FIFF.FIFFV_MNE_COORD_4D_HEAD = FIFF.FIFFV_MNE_COORD_CTF_HEAD
+FIFF.FIFFV_MNE_COORD_KIT_HEAD = FIFF.FIFFV_MNE_COORD_CTF_HEAD
 
 #
 # FWD Types
@@ -759,52 +825,52 @@
 
 FWD = BunchConstNamed()
 
-FWD.COIL_UNKNOWN = 0
-FWD.COILC_UNKNOWN = 0
-FWD.COILC_EEG = 1000
-FWD.COILC_MAG = 1
-FWD.COILC_AXIAL_GRAD = 2
-FWD.COILC_PLANAR_GRAD = 3
-FWD.COILC_AXIAL_GRAD2 = 4
+FWD.COIL_UNKNOWN = 0
+FWD.COILC_UNKNOWN = 0
+FWD.COILC_EEG = 1000
+FWD.COILC_MAG = 1
+FWD.COILC_AXIAL_GRAD = 2
+FWD.COILC_PLANAR_GRAD = 3
+FWD.COILC_AXIAL_GRAD2 = 4
 
-FWD.COIL_ACCURACY_POINT = 0
-FWD.COIL_ACCURACY_NORMAL = 1
-FWD.COIL_ACCURACY_ACCURATE = 2
+FWD.COIL_ACCURACY_POINT = 0
+FWD.COIL_ACCURACY_NORMAL = 1
+FWD.COIL_ACCURACY_ACCURATE = 2
 
-FWD.BEM_IP_APPROACH_LIMIT = 0.1
+FWD.BEM_IP_APPROACH_LIMIT = 0.1
 
-FWD.BEM_LIN_FIELD_SIMPLE = 1
-FWD.BEM_LIN_FIELD_FERGUSON = 2
-FWD.BEM_LIN_FIELD_URANKAR = 3
+FWD.BEM_LIN_FIELD_SIMPLE = 1
+FWD.BEM_LIN_FIELD_FERGUSON = 2
+FWD.BEM_LIN_FIELD_URANKAR = 3
 #
 # Data types
 #
-FIFF.FIFFT_VOID = 0
-FIFF.FIFFT_BYTE = 1
-FIFF.FIFFT_SHORT = 2
-FIFF.FIFFT_INT = 3
-FIFF.FIFFT_FLOAT = 4
-FIFF.FIFFT_DOUBLE = 5
-FIFF.FIFFT_JULIAN = 6
-FIFF.FIFFT_USHORT = 7
-FIFF.FIFFT_UINT = 8
-FIFF.FIFFT_ULONG = 9
-FIFF.FIFFT_STRING = 10
-FIFF.FIFFT_LONG = 11
-FIFF.FIFFT_DAU_PACK13 = 13
-FIFF.FIFFT_DAU_PACK14 = 14
-FIFF.FIFFT_DAU_PACK16 = 16
-FIFF.FIFFT_COMPLEX_FLOAT = 20
-FIFF.FIFFT_COMPLEX_DOUBLE = 21
-FIFF.FIFFT_OLD_PACK = 23
-FIFF.FIFFT_CH_INFO_STRUCT = 30
-FIFF.FIFFT_ID_STRUCT = 31
-FIFF.FIFFT_DIR_ENTRY_STRUCT = 32
-FIFF.FIFFT_DIG_POINT_STRUCT = 33
-FIFF.FIFFT_CH_POS_STRUCT = 34
-FIFF.FIFFT_COORD_TRANS_STRUCT = 35
-FIFF.FIFFT_DIG_STRING_STRUCT = 36
+FIFF.FIFFT_VOID = 0
+FIFF.FIFFT_BYTE = 1
+FIFF.FIFFT_SHORT = 2
+FIFF.FIFFT_INT = 3
+FIFF.FIFFT_FLOAT = 4
+FIFF.FIFFT_DOUBLE = 5
+FIFF.FIFFT_JULIAN = 6
+FIFF.FIFFT_USHORT = 7
+FIFF.FIFFT_UINT = 8
+FIFF.FIFFT_ULONG = 9
+FIFF.FIFFT_STRING = 10
+FIFF.FIFFT_LONG = 11
+FIFF.FIFFT_DAU_PACK13 = 13
+FIFF.FIFFT_DAU_PACK14 = 14
+FIFF.FIFFT_DAU_PACK16 = 16
+FIFF.FIFFT_COMPLEX_FLOAT = 20
+FIFF.FIFFT_COMPLEX_DOUBLE = 21
+FIFF.FIFFT_OLD_PACK = 23
+FIFF.FIFFT_CH_INFO_STRUCT = 30
+FIFF.FIFFT_ID_STRUCT = 31
+FIFF.FIFFT_DIR_ENTRY_STRUCT = 32
+FIFF.FIFFT_DIG_POINT_STRUCT = 33
+FIFF.FIFFT_CH_POS_STRUCT = 34
+FIFF.FIFFT_COORD_TRANS_STRUCT = 35
+FIFF.FIFFT_DIG_STRING_STRUCT = 36
 FIFF.FIFFT_STREAM_SEGMENT_STRUCT = 37
 #
 # Units of measurement
@@ -814,143 +880,183 @@
 #
 # SI base units
 #
 FIFF.FIFF_UNIT_UNITLESS = 0
-FIFF.FIFF_UNIT_M = 1  # meter
-FIFF.FIFF_UNIT_KG = 2  # kilogram
+FIFF.FIFF_UNIT_M = 1  # meter
+FIFF.FIFF_UNIT_KG = 2  # kilogram
 FIFF.FIFF_UNIT_SEC = 3  # second
-FIFF.FIFF_UNIT_A = 4  # ampere
-FIFF.FIFF_UNIT_K = 5  # Kelvin
+FIFF.FIFF_UNIT_A = 4  # ampere
+FIFF.FIFF_UNIT_K = 5  # Kelvin
 FIFF.FIFF_UNIT_MOL = 6  # mole
 #
 # SI Supplementary units
 #
 FIFF.FIFF_UNIT_RAD = 7  # radian
-FIFF.FIFF_UNIT_SR = 8  # steradian
+FIFF.FIFF_UNIT_SR = 8  # steradian
 #
 # SI base candela
 #
-FIFF.FIFF_UNIT_CD = 9  # candela
+FIFF.FIFF_UNIT_CD = 9  # candela
 #
 # SI derived units
 #
 FIFF.FIFF_UNIT_MOL_M3 = 10  # mol/m^3
-FIFF.FIFF_UNIT_HZ = 101  # hertz
-FIFF.FIFF_UNIT_N = 102  # Newton
-FIFF.FIFF_UNIT_PA = 103  # pascal
-FIFF.FIFF_UNIT_J = 104  # joule
-FIFF.FIFF_UNIT_W = 105  # watt
-FIFF.FIFF_UNIT_C = 106  # coulomb
-FIFF.FIFF_UNIT_V = 107  # volt
-FIFF.FIFF_UNIT_F = 108  # farad
+FIFF.FIFF_UNIT_HZ = 101  # hertz
+FIFF.FIFF_UNIT_N = 102  # Newton
+FIFF.FIFF_UNIT_PA = 103  # pascal
+FIFF.FIFF_UNIT_J = 104  # joule
+FIFF.FIFF_UNIT_W = 105  # watt
+FIFF.FIFF_UNIT_C = 106  # coulomb
+FIFF.FIFF_UNIT_V = 107  # volt
+FIFF.FIFF_UNIT_F = 108  # farad
 FIFF.FIFF_UNIT_OHM = 109  # ohm
-FIFF.FIFF_UNIT_S = 110  # Siemens (same as Moh, what fiff-constants calls it)
-FIFF.FIFF_UNIT_WB = 111  # weber
-FIFF.FIFF_UNIT_T = 112  # tesla
-FIFF.FIFF_UNIT_H = 113  # Henry
+FIFF.FIFF_UNIT_S = 110  # Siemens (same as Moh, what fiff-constants calls it)
+FIFF.FIFF_UNIT_WB = 111  # weber
+FIFF.FIFF_UNIT_T = 112  # tesla
+FIFF.FIFF_UNIT_H = 113  # Henry
 FIFF.FIFF_UNIT_CEL = 114  # celsius
-FIFF.FIFF_UNIT_LM = 115  # lumen
-FIFF.FIFF_UNIT_LX = 116  # lux
+FIFF.FIFF_UNIT_LM = 115  # lumen
+FIFF.FIFF_UNIT_LX = 116  # lux
 FIFF.FIFF_UNIT_V_M2 = 117  # V/m^2
 #
 # Others we need
 #
-FIFF.FIFF_UNIT_T_M = 201  # T/m
-FIFF.FIFF_UNIT_AM = 202  # Am
+FIFF.FIFF_UNIT_T_M = 201  # T/m
+FIFF.FIFF_UNIT_AM = 202  # Am
 FIFF.FIFF_UNIT_AM_M2 = 203  # Am/m^2
 FIFF.FIFF_UNIT_AM_M3 = 204  # Am/m^3
-FIFF.FIFF_UNIT_PX = 210  # Pixel
-_ch_unit_named = {key: key for key in(
-    FIFF.FIFF_UNIT_NONE, FIFF.FIFF_UNIT_UNITLESS, FIFF.FIFF_UNIT_M,
-    FIFF.FIFF_UNIT_KG, FIFF.FIFF_UNIT_SEC, FIFF.FIFF_UNIT_A, FIFF.FIFF_UNIT_K,
-    FIFF.FIFF_UNIT_MOL, FIFF.FIFF_UNIT_RAD, FIFF.FIFF_UNIT_SR,
-    FIFF.FIFF_UNIT_CD, FIFF.FIFF_UNIT_MOL_M3, FIFF.FIFF_UNIT_HZ,
-    FIFF.FIFF_UNIT_N, FIFF.FIFF_UNIT_PA, FIFF.FIFF_UNIT_J, FIFF.FIFF_UNIT_W,
-    FIFF.FIFF_UNIT_C, FIFF.FIFF_UNIT_V, FIFF.FIFF_UNIT_F, FIFF.FIFF_UNIT_OHM,
-    FIFF.FIFF_UNIT_S, FIFF.FIFF_UNIT_WB, FIFF.FIFF_UNIT_T, FIFF.FIFF_UNIT_H,
-    FIFF.FIFF_UNIT_CEL, FIFF.FIFF_UNIT_LM, FIFF.FIFF_UNIT_LX,
-    FIFF.FIFF_UNIT_V_M2, FIFF.FIFF_UNIT_T_M, FIFF.FIFF_UNIT_AM,
-    FIFF.FIFF_UNIT_AM_M2, FIFF.FIFF_UNIT_AM_M3,
-    FIFF.FIFF_UNIT_PX,
-)}
+FIFF.FIFF_UNIT_PX = 210  # Pixel
+_ch_unit_named = {
+    key: key
+    for key in (
+        FIFF.FIFF_UNIT_NONE,
+        FIFF.FIFF_UNIT_UNITLESS,
+        FIFF.FIFF_UNIT_M,
+        FIFF.FIFF_UNIT_KG,
+        FIFF.FIFF_UNIT_SEC,
+        FIFF.FIFF_UNIT_A,
+        FIFF.FIFF_UNIT_K,
+        FIFF.FIFF_UNIT_MOL,
+        FIFF.FIFF_UNIT_RAD,
+        FIFF.FIFF_UNIT_SR,
+        FIFF.FIFF_UNIT_CD,
+        FIFF.FIFF_UNIT_MOL_M3,
+        FIFF.FIFF_UNIT_HZ,
+        FIFF.FIFF_UNIT_N,
+        FIFF.FIFF_UNIT_PA,
+        FIFF.FIFF_UNIT_J,
+        FIFF.FIFF_UNIT_W,
+        FIFF.FIFF_UNIT_C,
+        FIFF.FIFF_UNIT_V,
+        FIFF.FIFF_UNIT_F,
+        FIFF.FIFF_UNIT_OHM,
+        FIFF.FIFF_UNIT_S,
+        FIFF.FIFF_UNIT_WB,
+        FIFF.FIFF_UNIT_T,
+        FIFF.FIFF_UNIT_H,
+        FIFF.FIFF_UNIT_CEL,
+        FIFF.FIFF_UNIT_LM,
+        FIFF.FIFF_UNIT_LX,
+        FIFF.FIFF_UNIT_V_M2,
+        FIFF.FIFF_UNIT_T_M,
+        FIFF.FIFF_UNIT_AM,
+        FIFF.FIFF_UNIT_AM_M2,
+        FIFF.FIFF_UNIT_AM_M3,
+        FIFF.FIFF_UNIT_PX,
+    )
+}
 #
 # Multipliers
 #
-FIFF.FIFF_UNITM_E = 18
-FIFF.FIFF_UNITM_PET = 15
-FIFF.FIFF_UNITM_T = 12
-FIFF.FIFF_UNITM_GIG = 9
-FIFF.FIFF_UNITM_MEG = 6
-FIFF.FIFF_UNITM_K = 3
-FIFF.FIFF_UNITM_H = 2
-FIFF.FIFF_UNITM_DA = 1
+FIFF.FIFF_UNITM_E = 18
+FIFF.FIFF_UNITM_PET = 15
+FIFF.FIFF_UNITM_T = 12
+FIFF.FIFF_UNITM_GIG = 9
+FIFF.FIFF_UNITM_MEG = 6
+FIFF.FIFF_UNITM_K = 3
+FIFF.FIFF_UNITM_H = 2
+FIFF.FIFF_UNITM_DA = 1
 FIFF.FIFF_UNITM_NONE = 0
-FIFF.FIFF_UNITM_D = -1
-FIFF.FIFF_UNITM_C = -2
-FIFF.FIFF_UNITM_M = -3
-FIFF.FIFF_UNITM_MU = -6
-FIFF.FIFF_UNITM_N = -9
-FIFF.FIFF_UNITM_P = -12
-FIFF.FIFF_UNITM_F = -15
-FIFF.FIFF_UNITM_A = -18
-_ch_unit_mul_named = {key: key for key in (
-    FIFF.FIFF_UNITM_E, FIFF.FIFF_UNITM_PET, FIFF.FIFF_UNITM_T,
-    FIFF.FIFF_UNITM_GIG, FIFF.FIFF_UNITM_MEG, FIFF.FIFF_UNITM_K,
-    FIFF.FIFF_UNITM_H, FIFF.FIFF_UNITM_DA, FIFF.FIFF_UNITM_NONE,
-    FIFF.FIFF_UNITM_D, FIFF.FIFF_UNITM_C, FIFF.FIFF_UNITM_M,
-    FIFF.FIFF_UNITM_MU, FIFF.FIFF_UNITM_N, FIFF.FIFF_UNITM_P,
-    FIFF.FIFF_UNITM_F, FIFF.FIFF_UNITM_A,
-)}
+FIFF.FIFF_UNITM_D = -1
+FIFF.FIFF_UNITM_C = -2
+FIFF.FIFF_UNITM_M = -3
+FIFF.FIFF_UNITM_MU = -6
+FIFF.FIFF_UNITM_N = -9
+FIFF.FIFF_UNITM_P = -12
+FIFF.FIFF_UNITM_F = -15
+FIFF.FIFF_UNITM_A = -18
+_ch_unit_mul_named = {
+    key: key
+    for key in (
+        FIFF.FIFF_UNITM_E,
+        FIFF.FIFF_UNITM_PET,
+        FIFF.FIFF_UNITM_T,
+        FIFF.FIFF_UNITM_GIG,
+        FIFF.FIFF_UNITM_MEG,
+        FIFF.FIFF_UNITM_K,
+        FIFF.FIFF_UNITM_H,
+        FIFF.FIFF_UNITM_DA,
+        FIFF.FIFF_UNITM_NONE,
+        FIFF.FIFF_UNITM_D,
+        FIFF.FIFF_UNITM_C,
+        FIFF.FIFF_UNITM_M,
+        FIFF.FIFF_UNITM_MU,
+        FIFF.FIFF_UNITM_N,
+        FIFF.FIFF_UNITM_P,
+        FIFF.FIFF_UNITM_F,
+        FIFF.FIFF_UNITM_A,
+    )
+}
 #
 # Coil types
 #
-FIFF.FIFFV_COIL_NONE = 0  # The location info contains no data
-FIFF.FIFFV_COIL_EEG = 1  # EEG electrode position in r0
-FIFF.FIFFV_COIL_NM_122 = 2  # Neuromag 122 coils
-FIFF.FIFFV_COIL_NM_24 = 3  # Old 24 channel system in HUT
-FIFF.FIFFV_COIL_NM_MCG_AXIAL = 4  # The axial devices in the HUCS MCG system
-FIFF.FIFFV_COIL_EEG_BIPOLAR = 5  # Bipolar EEG lead
-FIFF.FIFFV_COIL_EEG_CSD = 6  # CSD-transformed EEG lead
+FIFF.FIFFV_COIL_NONE = 0  # The location info contains no data
+FIFF.FIFFV_COIL_EEG = 1  # EEG electrode position in r0
+FIFF.FIFFV_COIL_NM_122 = 2  # Neuromag 122 coils
+FIFF.FIFFV_COIL_NM_24 = 3  # Old 24 channel system in HUT
+FIFF.FIFFV_COIL_NM_MCG_AXIAL = 4  # The axial devices in the HUCS MCG system
+FIFF.FIFFV_COIL_EEG_BIPOLAR = 5  # Bipolar EEG lead
+FIFF.FIFFV_COIL_EEG_CSD = 6  # CSD-transformed EEG lead
 
-FIFF.FIFFV_COIL_DIPOLE = 200  # Time-varying dipole definition
+FIFF.FIFFV_COIL_DIPOLE = 200  # Time-varying dipole definition
 # The coil info contains dipole location (r0) and
 # direction (ex)
-FIFF.FIFFV_COIL_FNIRS_HBO = 300  # fNIRS oxyhemoglobin
-FIFF.FIFFV_COIL_FNIRS_HBR = 301  # fNIRS deoxyhemoglobin
-FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE = 302  # fNIRS continuous wave amplitude
-FIFF.FIFFV_COIL_FNIRS_OD = 303  # fNIRS optical density
+FIFF.FIFFV_COIL_FNIRS_HBO = 300  # fNIRS oxyhemoglobin
+FIFF.FIFFV_COIL_FNIRS_HBR = 301  # fNIRS deoxyhemoglobin
+FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE = 302  # fNIRS continuous wave amplitude
+FIFF.FIFFV_COIL_FNIRS_OD = 303  # fNIRS optical density
 FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE = 304  # fNIRS frequency domain AC amplitude
-FIFF.FIFFV_COIL_FNIRS_FD_PHASE = 305  # fNIRS frequency domain phase
+FIFF.FIFFV_COIL_FNIRS_FD_PHASE = 305  # fNIRS frequency domain phase
 FIFF.FIFFV_COIL_FNIRS_RAW = FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE  # old alias
-FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE = 306  # fNIRS time-domain gated amplitude
-FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_AMPLITUDE = 307  # fNIRS time-domain moments amplitude
+FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE = 306  # fNIRS time-domain gated amplitude
+FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_AMPLITUDE = 307  # fNIRS time-domain moments amplitude
 
-FIFF.FIFFV_COIL_EYETRACK_POS = 400  # Eye-tracking gaze position
-FIFF.FIFFV_COIL_EYETRACK_PUPIL = 401  # Eye-tracking pupil size
+FIFF.FIFFV_COIL_EYETRACK_POS = 400  # Eye-tracking gaze position
+FIFF.FIFFV_COIL_EYETRACK_PUPIL = 401  # Eye-tracking pupil size
 
-FIFF.FIFFV_COIL_MCG_42 = 1000  # For testing the MCG software
+FIFF.FIFFV_COIL_MCG_42 = 1000  # For testing the MCG software
 
-FIFF.FIFFV_COIL_POINT_MAGNETOMETER = 2000  # Simple point magnetometer
-FIFF.FIFFV_COIL_AXIAL_GRAD_5CM = 2001  # Generic axial gradiometer
+FIFF.FIFFV_COIL_POINT_MAGNETOMETER = 2000  # Simple point magnetometer
+FIFF.FIFFV_COIL_AXIAL_GRAD_5CM = 2001  # Generic axial gradiometer
 
-FIFF.FIFFV_COIL_VV_PLANAR_W = 3011  # VV prototype wirewound planar sensor
-FIFF.FIFFV_COIL_VV_PLANAR_T1 = 3012  # Vectorview SQ20483N planar gradiometer
-FIFF.FIFFV_COIL_VV_PLANAR_T2 = 3013  # Vectorview SQ20483N-A planar gradiometer
-FIFF.FIFFV_COIL_VV_PLANAR_T3 = 3014  # Vectorview SQ20950N planar gradiometer
-FIFF.FIFFV_COIL_VV_PLANAR_T4 = 3015  # Vectorview planar gradiometer (MEG-MRI)
-FIFF.FIFFV_COIL_VV_MAG_W = 3021  # VV prototype wirewound magnetometer
-FIFF.FIFFV_COIL_VV_MAG_T1 = 3022  # Vectorview SQ20483N magnetometer
-FIFF.FIFFV_COIL_VV_MAG_T2 = 3023  # Vectorview SQ20483-A magnetometer
-FIFF.FIFFV_COIL_VV_MAG_T3 = 3024  # Vectorview SQ20950N magnetometer
-FIFF.FIFFV_COIL_VV_MAG_T4 = 3025  # Vectorview magnetometer (MEG-MRI)
+FIFF.FIFFV_COIL_VV_PLANAR_W = 3011  # VV prototype wirewound planar sensor
+FIFF.FIFFV_COIL_VV_PLANAR_T1 = 3012  # Vectorview SQ20483N planar gradiometer
+FIFF.FIFFV_COIL_VV_PLANAR_T2 = 3013  # Vectorview SQ20483N-A planar gradiometer
+FIFF.FIFFV_COIL_VV_PLANAR_T3 = 3014  # Vectorview SQ20950N planar gradiometer
+FIFF.FIFFV_COIL_VV_PLANAR_T4 = 3015  # Vectorview planar gradiometer (MEG-MRI)
+FIFF.FIFFV_COIL_VV_MAG_W = 3021  # VV prototype wirewound magnetometer
+FIFF.FIFFV_COIL_VV_MAG_T1 = 3022  # Vectorview SQ20483N magnetometer
+FIFF.FIFFV_COIL_VV_MAG_T2 = 3023  # Vectorview SQ20483-A magnetometer
+FIFF.FIFFV_COIL_VV_MAG_T3 = 3024  # Vectorview SQ20950N magnetometer
+FIFF.FIFFV_COIL_VV_MAG_T4 = 3025  # Vectorview magnetometer (MEG-MRI)
 
-FIFF.FIFFV_COIL_MAGNES_MAG = 4001  # Magnes WH magnetometer
-FIFF.FIFFV_COIL_MAGNES_GRAD = 4002  # Magnes WH gradiometer
+FIFF.FIFFV_COIL_MAGNES_MAG = 4001  # Magnes WH magnetometer
+FIFF.FIFFV_COIL_MAGNES_GRAD = 4002  # Magnes WH gradiometer
 #
 # Magnes reference sensors
 #
-FIFF.FIFFV_COIL_MAGNES_REF_MAG = 4003
-FIFF.FIFFV_COIL_MAGNES_REF_GRAD = 4004
+FIFF.FIFFV_COIL_MAGNES_REF_MAG = 4003
+FIFF.FIFFV_COIL_MAGNES_REF_GRAD = 4004
 FIFF.FIFFV_COIL_MAGNES_OFFDIAG_REF_GRAD = 4005
 FIFF.FIFFV_COIL_MAGNES_R_MAG = FIFF.FIFFV_COIL_MAGNES_REF_MAG
 FIFF.FIFFV_COIL_MAGNES_R_GRAD = FIFF.FIFFV_COIL_MAGNES_REF_GRAD
@@ -959,121 +1065,151 @@
 #
 # CTF coil and channel types
 #
-FIFF.FIFFV_COIL_CTF_GRAD = 5001
-FIFF.FIFFV_COIL_CTF_REF_MAG = 5002
-FIFF.FIFFV_COIL_CTF_REF_GRAD = 5003
+FIFF.FIFFV_COIL_CTF_GRAD = 5001
+FIFF.FIFFV_COIL_CTF_REF_MAG = 5002
+FIFF.FIFFV_COIL_CTF_REF_GRAD = 5003
 FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD = 5004
 #
 # KIT system coil types
 #
-FIFF.FIFFV_COIL_KIT_GRAD = 6001
-FIFF.FIFFV_COIL_KIT_REF_MAG = 6002
+FIFF.FIFFV_COIL_KIT_GRAD = 6001
+FIFF.FIFFV_COIL_KIT_REF_MAG = 6002
 #
 # BabySQUID sensors
 #
-FIFF.FIFFV_COIL_BABY_GRAD = 7001
+FIFF.FIFFV_COIL_BABY_GRAD = 7001
 #
 # BabyMEG sensors
 #
-FIFF.FIFFV_COIL_BABY_MAG = 7002
-FIFF.FIFFV_COIL_BABY_REF_MAG = 7003
-FIFF.FIFFV_COIL_BABY_REF_MAG2 = 7004
+FIFF.FIFFV_COIL_BABY_MAG = 7002
+FIFF.FIFFV_COIL_BABY_REF_MAG = 7003
+FIFF.FIFFV_COIL_BABY_REF_MAG2 = 7004
 #
 # Artemis123 sensors
 #
-FIFF.FIFFV_COIL_ARTEMIS123_GRAD = 7501
-FIFF.FIFFV_COIL_ARTEMIS123_REF_MAG = 7502
-FIFF.FIFFV_COIL_ARTEMIS123_REF_GRAD = 7503
+FIFF.FIFFV_COIL_ARTEMIS123_GRAD = 7501
+FIFF.FIFFV_COIL_ARTEMIS123_REF_MAG = 7502
+FIFF.FIFFV_COIL_ARTEMIS123_REF_GRAD = 7503
 #
 # QuSpin sensors
 #
-FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG = 8001
-FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG2 = 8002
+FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG = 8001
+FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG2 = 8002
 #
 # FieldLine sensors
 #
-FIFF.FIFFV_COIL_FIELDLINE_OPM_MAG_GEN1 = 8101
+FIFF.FIFFV_COIL_FIELDLINE_OPM_MAG_GEN1 = 8101
 #
 # Kernel sensors
 #
-FIFF.FIFFV_COIL_KERNEL_OPM_MAG_GEN1 = 8201
+FIFF.FIFFV_COIL_KERNEL_OPM_MAG_GEN1 = 8201
 #
 # KRISS sensors
 #
-FIFF.FIFFV_COIL_KRISS_GRAD = 9001
+FIFF.FIFFV_COIL_KRISS_GRAD = 9001
 #
 # Compumedics adult/pediatric gradiometer
 #
-FIFF.FIFFV_COIL_COMPUMEDICS_ADULT_GRAD = 9101
-FIFF.FIFFV_COIL_COMPUMEDICS_PEDIATRIC_GRAD = 9102
-_ch_coil_type_named = {key: key for key in (
-    FIFF.FIFFV_COIL_NONE, FIFF.FIFFV_COIL_EEG, FIFF.FIFFV_COIL_NM_122,
-    FIFF.FIFFV_COIL_NM_24, FIFF.FIFFV_COIL_NM_MCG_AXIAL,
-    FIFF.FIFFV_COIL_EEG_BIPOLAR, FIFF.FIFFV_COIL_EEG_CSD,
-    FIFF.FIFFV_COIL_DIPOLE, FIFF.FIFFV_COIL_FNIRS_HBO,
-    FIFF.FIFFV_COIL_FNIRS_HBR, FIFF.FIFFV_COIL_FNIRS_RAW,
-    FIFF.FIFFV_COIL_FNIRS_OD, FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE,
-    FIFF.FIFFV_COIL_FNIRS_FD_PHASE, FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE,
-    FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_AMPLITUDE, FIFF.FIFFV_COIL_MCG_42,
-    FIFF.FIFFV_COIL_EYETRACK_POS, FIFF.FIFFV_COIL_EYETRACK_PUPIL,
-    FIFF.FIFFV_COIL_POINT_MAGNETOMETER, FIFF.FIFFV_COIL_AXIAL_GRAD_5CM,
-    FIFF.FIFFV_COIL_VV_PLANAR_W, FIFF.FIFFV_COIL_VV_PLANAR_T1,
-    FIFF.FIFFV_COIL_VV_PLANAR_T2, FIFF.FIFFV_COIL_VV_PLANAR_T3,
-    FIFF.FIFFV_COIL_VV_PLANAR_T4, FIFF.FIFFV_COIL_VV_MAG_W,
-    FIFF.FIFFV_COIL_VV_MAG_T1, FIFF.FIFFV_COIL_VV_MAG_T2,
-    FIFF.FIFFV_COIL_VV_MAG_T3, FIFF.FIFFV_COIL_VV_MAG_T4,
-    FIFF.FIFFV_COIL_MAGNES_MAG, FIFF.FIFFV_COIL_MAGNES_GRAD,
-    FIFF.FIFFV_COIL_MAGNES_REF_MAG, FIFF.FIFFV_COIL_MAGNES_REF_GRAD,
-    FIFF.FIFFV_COIL_MAGNES_OFFDIAG_REF_GRAD, FIFF.FIFFV_COIL_CTF_GRAD,
-    FIFF.FIFFV_COIL_CTF_REF_MAG, FIFF.FIFFV_COIL_CTF_REF_GRAD,
-    FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD, FIFF.FIFFV_COIL_KIT_GRAD,
-    FIFF.FIFFV_COIL_KIT_REF_MAG, FIFF.FIFFV_COIL_BABY_GRAD,
-    FIFF.FIFFV_COIL_BABY_MAG, FIFF.FIFFV_COIL_BABY_REF_MAG,
-    FIFF.FIFFV_COIL_BABY_REF_MAG2, FIFF.FIFFV_COIL_ARTEMIS123_GRAD,
-    FIFF.FIFFV_COIL_ARTEMIS123_REF_MAG, FIFF.FIFFV_COIL_ARTEMIS123_REF_GRAD,
-    FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG, FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG2,
-    FIFF.FIFFV_COIL_FIELDLINE_OPM_MAG_GEN1,
-    FIFF.FIFFV_COIL_KERNEL_OPM_MAG_GEN1,
-    FIFF.FIFFV_COIL_KRISS_GRAD, FIFF.FIFFV_COIL_COMPUMEDICS_ADULT_GRAD,
-    FIFF.FIFFV_COIL_COMPUMEDICS_PEDIATRIC_GRAD,
-)}
+FIFF.FIFFV_COIL_COMPUMEDICS_ADULT_GRAD = 9101
+FIFF.FIFFV_COIL_COMPUMEDICS_PEDIATRIC_GRAD = 9102
+_ch_coil_type_named = {
+    key: key
+    for key in (
+        FIFF.FIFFV_COIL_NONE,
+        FIFF.FIFFV_COIL_EEG,
+        FIFF.FIFFV_COIL_NM_122,
+        FIFF.FIFFV_COIL_NM_24,
+        FIFF.FIFFV_COIL_NM_MCG_AXIAL,
+        FIFF.FIFFV_COIL_EEG_BIPOLAR,
+        FIFF.FIFFV_COIL_EEG_CSD,
+        FIFF.FIFFV_COIL_DIPOLE,
+        FIFF.FIFFV_COIL_FNIRS_HBO,
+        FIFF.FIFFV_COIL_FNIRS_HBR,
+        FIFF.FIFFV_COIL_FNIRS_RAW,
+        FIFF.FIFFV_COIL_FNIRS_OD,
+        FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE,
+        FIFF.FIFFV_COIL_FNIRS_FD_PHASE,
+        FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE,
+        FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_AMPLITUDE,
+        FIFF.FIFFV_COIL_MCG_42,
+        FIFF.FIFFV_COIL_EYETRACK_POS,
+        FIFF.FIFFV_COIL_EYETRACK_PUPIL,
+        FIFF.FIFFV_COIL_POINT_MAGNETOMETER,
+        FIFF.FIFFV_COIL_AXIAL_GRAD_5CM,
+        FIFF.FIFFV_COIL_VV_PLANAR_W,
+        FIFF.FIFFV_COIL_VV_PLANAR_T1,
+        FIFF.FIFFV_COIL_VV_PLANAR_T2,
+        FIFF.FIFFV_COIL_VV_PLANAR_T3,
+        FIFF.FIFFV_COIL_VV_PLANAR_T4,
+        FIFF.FIFFV_COIL_VV_MAG_W,
+        FIFF.FIFFV_COIL_VV_MAG_T1,
+        FIFF.FIFFV_COIL_VV_MAG_T2,
+        FIFF.FIFFV_COIL_VV_MAG_T3,
+        FIFF.FIFFV_COIL_VV_MAG_T4,
+        FIFF.FIFFV_COIL_MAGNES_MAG,
+        FIFF.FIFFV_COIL_MAGNES_GRAD,
+        FIFF.FIFFV_COIL_MAGNES_REF_MAG,
+        FIFF.FIFFV_COIL_MAGNES_REF_GRAD,
+        FIFF.FIFFV_COIL_MAGNES_OFFDIAG_REF_GRAD,
+        FIFF.FIFFV_COIL_CTF_GRAD,
+        FIFF.FIFFV_COIL_CTF_REF_MAG,
+        FIFF.FIFFV_COIL_CTF_REF_GRAD,
+        FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD,
+        FIFF.FIFFV_COIL_KIT_GRAD,
+        FIFF.FIFFV_COIL_KIT_REF_MAG,
+        FIFF.FIFFV_COIL_BABY_GRAD,
+        FIFF.FIFFV_COIL_BABY_MAG,
+        FIFF.FIFFV_COIL_BABY_REF_MAG,
+        FIFF.FIFFV_COIL_BABY_REF_MAG2,
+        FIFF.FIFFV_COIL_ARTEMIS123_GRAD,
+        FIFF.FIFFV_COIL_ARTEMIS123_REF_MAG,
+        FIFF.FIFFV_COIL_ARTEMIS123_REF_GRAD,
+        FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG,
+        FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG2,
+        FIFF.FIFFV_COIL_FIELDLINE_OPM_MAG_GEN1,
+        FIFF.FIFFV_COIL_KERNEL_OPM_MAG_GEN1,
+        FIFF.FIFFV_COIL_KRISS_GRAD,
+        FIFF.FIFFV_COIL_COMPUMEDICS_ADULT_GRAD,
+        FIFF.FIFFV_COIL_COMPUMEDICS_PEDIATRIC_GRAD,
+    )
+}
 
 # MNE RealTime
-FIFF.FIFF_MNE_RT_COMMAND = 3700  # realtime command
-FIFF.FIFF_MNE_RT_CLIENT_ID = 3701  # realtime client
+FIFF.FIFF_MNE_RT_COMMAND = 3700  # realtime command
+FIFF.FIFF_MNE_RT_CLIENT_ID = 3701  # realtime client
 
 # MNE epochs bookkeeping
-FIFF.FIFF_MNE_EPOCHS_SELECTION = 3800  # the epochs selection
-FIFF.FIFF_MNE_EPOCHS_DROP_LOG = 3801  # the drop log
-FIFF.FIFF_MNE_EPOCHS_REJECT_FLAT = 3802  # rejection and flat params
-FIFF.FIFF_MNE_EPOCHS_RAW_SFREQ = 3803  # original raw sfreq
+FIFF.FIFF_MNE_EPOCHS_SELECTION = 3800  # the epochs selection
+FIFF.FIFF_MNE_EPOCHS_DROP_LOG = 3801  # the drop log
+FIFF.FIFF_MNE_EPOCHS_REJECT_FLAT = 3802  # rejection and flat params
+FIFF.FIFF_MNE_EPOCHS_RAW_SFREQ = 3803  # original raw sfreq
 
 # MNE annotations
-FIFF.FIFFB_MNE_ANNOTATIONS = 3810  # annotations block
+FIFF.FIFFB_MNE_ANNOTATIONS = 3810  # annotations block
 
 # MNE Metadata Dataframes
-FIFF.FIFFB_MNE_METADATA = 3811  # metadata dataframes block
+FIFF.FIFFB_MNE_METADATA = 3811  # metadata dataframes block
 
 # Table to match unrecognized channel location names to their known aliases
 CHANNEL_LOC_ALIASES = {
     # this set of aliases are published in doi:10.1097/WNP.0000000000000316 and
     # doi:10.1016/S1388-2457(00)00527-7.
-    'Cb1': 'POO7',
-    'Cb2': 'POO8',
-    'CB1': 'POO7',
-    'CB2': 'POO8',
-    'T1': 'T9',
-    'T2': 'T10',
-    'T3': 'T7',
-    'T4': 'T8',
-    'T5': 'T9',
-    'T6': 'T10',
-    'M1': 'TP9',
-    'M2': 'TP10',
+    "Cb1": "POO7",
+    "Cb2": "POO8",
+    "CB1": "POO7",
+    "CB2": "POO8",
+    "T1": "T9",
+    "T2": "T10",
+    "T3": "T7",
+    "T4": "T8",
+    "T5": "T9",
+    "T6": "T10",
+    "M1": "TP9",
+    "M2": "TP10",
     # EGI ref chan is named VREF/Vertex Ref.
     # In the standard montages for EGI, the ref is named Cz
-    'VREF': 'Cz',
-    'Vertex Reference': 'Cz'
+    "VREF": "Cz",
+    "Vertex Reference": "Cz"
     # add a comment here (with doi of a published source) above any new
     # aliases, as they are added
 }
diff --git a/mne/io/ctf/constants.py b/mne/io/ctf/constants.py
index c8dc99880d6..f7176232f4e 100644
--- a/mne/io/ctf/constants.py
+++ b/mne/io/ctf/constants.py
@@ -36,4 +36,4 @@
 # read_write_data.c
 CTF.HEADER_SIZE = 8
 CTF.BLOCK_SIZE = 2000
-CTF.SYSTEM_CLOCK_CH = 'SCLK01-177'
+CTF.SYSTEM_CLOCK_CH = "SCLK01-177"
diff --git a/mne/io/ctf/ctf.py b/mne/io/ctf/ctf.py
index d06e48d8bb5..14d55c05cb3 100644
--- a/mne/io/ctf/ctf.py
+++ b/mne/io/ctf/ctf.py
@@ -10,8 +10,14 @@
 import numpy as np
 
 from .._digitization import _format_dig_points
-from ...utils import (verbose, logger, _clean_names, fill_doc, _check_option,
-                      _check_fname)
+from ...utils import (
+    verbose,
+    logger,
+    _clean_names,
+    fill_doc,
+    _check_option,
+    _check_fname,
+)
 from ..base import BaseRaw
 from ..utils import _mult_cal_one, _blk_read_lims
@@ -26,8 +32,9 @@
 
 
 @fill_doc
-def read_raw_ctf(directory, system_clock='truncate', preload=False,
-                 clean_names=False, verbose=None):
+def read_raw_ctf(
+    directory, system_clock="truncate", preload=False, clean_names=False, verbose=None
+):
     """Raw object from CTF directory.
 
     Parameters
@@ -64,8 +71,13 @@ def read_raw_ctf(directory, system_clock='truncate', preload=False,
     points will then automatically be read into the `mne.io.Raw` instance
     via `mne.io.read_raw_ctf`.
""" - return RawCTF(directory, system_clock, preload=preload, - clean_names=clean_names, verbose=verbose) + return RawCTF( + directory, + system_clock, + preload=preload, + clean_names=clean_names, + verbose=verbose, + ) @fill_doc @@ -93,17 +105,24 @@ class RawCTF(BaseRaw): """ @verbose - def __init__(self, directory, system_clock='truncate', preload=False, - verbose=None, clean_names=False): # noqa: D102 + def __init__( + self, + directory, + system_clock="truncate", + preload=False, + verbose=None, + clean_names=False, + ): # noqa: D102 # adapted from mne_ctf2fiff.c directory = str( _check_fname(directory, "read", True, "directory", need_dir=True) ) - if not directory.endswith('.ds'): - raise TypeError('directory must be a directory ending with ".ds", ' - f'got {directory}') - _check_option('system_clock', system_clock, ['ignore', 'truncate']) - logger.info('ds directory : %s' % directory) + if not directory.endswith(".ds"): + raise TypeError( + 'directory must be a directory ending with ".ds", ' f"got {directory}" + ) + _check_option("system_clock", system_clock, ["ignore", "truncate"]) + logger.info("ds directory : %s" % directory) res4 = _read_res4(directory) # Read the magical res4 file coils = _read_hc(directory) # Read the coil locations eeg = _read_eeg(directory) # Read the EEG electrode loc info @@ -116,9 +135,9 @@ def __init__(self, directory, system_clock='truncate', preload=False, # Compose a structure which makes fiff writing a piece of cake info = _compose_meas_info(res4, coils, coord_trans, eeg) with info._unlock(): - info['dig'] += digs - info['dig'] = _format_dig_points(info['dig']) - info['bads'] += _read_bad_chans(directory, info) + info["dig"] += digs + info["dig"] = _format_dig_points(info["dig"]) + info["bads"] += _read_bad_chans(directory, info) # Determine how our data is distributed across files fnames = list() @@ -127,45 +146,50 @@ def __init__(self, directory, system_clock='truncate', preload=False, missing_names = list() no_samps = list() while True: - suffix = 'meg4' if len(fnames) == 0 else ('%d_meg4' % len(fnames)) - meg4_name, found = _make_ctf_name( - directory, suffix, raise_error=False) + suffix = "meg4" if len(fnames) == 0 else ("%d_meg4" % len(fnames)) + meg4_name, found = _make_ctf_name(directory, suffix, raise_error=False) if not found: missing_names.append(os.path.relpath(meg4_name, directory)) break # check how much data is in the file sample_info = _get_sample_info(meg4_name, res4, system_clock) - if sample_info['n_samp'] == 0: + if sample_info["n_samp"] == 0: no_samps.append(os.path.relpath(meg4_name, directory)) break if len(fnames) == 0: - buffer_size_sec = sample_info['block_size'] / info['sfreq'] + buffer_size_sec = sample_info["block_size"] / info["sfreq"] else: - buffer_size_sec = 1. 
+ buffer_size_sec = 1.0 fnames.append(meg4_name) - last_samps.append(sample_info['n_samp'] - 1) + last_samps.append(sample_info["n_samp"] - 1) raw_extras.append(sample_info) first_samps = [0] * len(last_samps) if len(fnames) == 0: raise OSError( - f'Could not find any data, could not find the following ' - f'file(s): {missing_names}, and the following file(s) had no ' - f'valid samples: {no_samps}') + f"Could not find any data, could not find the following " + f"file(s): {missing_names}, and the following file(s) had no " + f"valid samples: {no_samps}" + ) super(RawCTF, self).__init__( - info, preload, first_samps=first_samps, - last_samps=last_samps, filenames=fnames, - raw_extras=raw_extras, orig_format='int', - buffer_size_sec=buffer_size_sec, verbose=verbose) + info, + preload, + first_samps=first_samps, + last_samps=last_samps, + filenames=fnames, + raw_extras=raw_extras, + orig_format="int", + buffer_size_sec=buffer_size_sec, + verbose=verbose, + ) # Add bad segments as Annotations (correct for start time) - start_time = -res4['pre_trig_pts'] / float(info['sfreq']) - annot = _annotate_bad_segments(directory, start_time, - info['meas_date']) + start_time = -res4["pre_trig_pts"] / float(info["sfreq"]) + annot = _annotate_bad_segments(directory, start_time, info["meas_date"]) marker_annot = _read_annotations_ctf_call( directory=directory, - total_offset=(res4['pre_trig_pts'] / res4['sfreq']), - trial_duration=(res4['nsamp'] / res4['sfreq']), - meas_date=info['meas_date'] + total_offset=(res4["pre_trig_pts"] / res4["sfreq"]), + trial_duration=(res4["nsamp"] / res4["sfreq"]), + meas_date=info["meas_date"], ) annot = marker_annot if annot is None else annot + marker_annot self.set_annotations(annot) @@ -177,24 +201,24 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of raw data.""" si = self._raw_extras[fi] offset = 0 - trial_start_idx, r_lims, d_lims = _blk_read_lims(start, stop, - int(si['block_size'])) - with open(self._filenames[fi], 'rb') as fid: + trial_start_idx, r_lims, d_lims = _blk_read_lims( + start, stop, int(si["block_size"]) + ) + with open(self._filenames[fi], "rb") as fid: for bi in range(len(r_lims)): - samp_offset = (bi + trial_start_idx) * si['res4_nsamp'] - n_read = min(si['n_samp_tot'] - samp_offset, si['block_size']) + samp_offset = (bi + trial_start_idx) * si["res4_nsamp"] + n_read = min(si["n_samp_tot"] - samp_offset, si["block_size"]) # read the chunk of data # have to be careful on Windows and make sure we are using # 64-bit integers here - with np.errstate(over='raise'): + with np.errstate(over="raise"): pos = np.int64(CTF.HEADER_SIZE) - pos += np.int64(samp_offset) * si['n_chan'] * 4 + pos += np.int64(samp_offset) * si["n_chan"] * 4 fid.seek(pos, 0) - this_data = np.fromfile(fid, '>i4', - count=si['n_chan'] * n_read) - this_data.shape = (si['n_chan'], n_read) - this_data = this_data[:, r_lims[bi, 0]:r_lims[bi, 1]] - data_view = data[:, d_lims[bi, 0]:d_lims[bi, 1]] + this_data = np.fromfile(fid, ">i4", count=si["n_chan"] * n_read) + this_data.shape = (si["n_chan"], n_read) + this_data = this_data[:, r_lims[bi, 0] : r_lims[bi, 1]] + data_view = data[:, d_lims[bi, 0] : d_lims[bi, 1]] _mult_cal_one(data_view, this_data, idx, cals, mult) offset += n_read @@ -204,66 +228,80 @@ def _clean_names(self): self.rename_channels(mapping) - for comp in self.info['comps']: - for key in ('row_names', 'col_names'): - comp['data'][key] = _clean_names(comp['data'][key]) + for comp in self.info["comps"]: + for key in ("row_names", 
"col_names"): + comp["data"][key] = _clean_names(comp["data"][key]) def _get_sample_info(fname, res4, system_clock): """Determine the number of valid samples.""" - logger.info('Finding samples for %s: ' % (fname,)) - if CTF.SYSTEM_CLOCK_CH in res4['ch_names']: - clock_ch = res4['ch_names'].index(CTF.SYSTEM_CLOCK_CH) + logger.info("Finding samples for %s: " % (fname,)) + if CTF.SYSTEM_CLOCK_CH in res4["ch_names"]: + clock_ch = res4["ch_names"].index(CTF.SYSTEM_CLOCK_CH) else: clock_ch = None - for k, ch in enumerate(res4['chs']): - if ch['ch_name'] == CTF.SYSTEM_CLOCK_CH: + for k, ch in enumerate(res4["chs"]): + if ch["ch_name"] == CTF.SYSTEM_CLOCK_CH: clock_ch = k break - with open(fname, 'rb') as fid: + with open(fname, "rb") as fid: fid.seek(0, os.SEEK_END) st_size = fid.tell() fid.seek(0, 0) - if (st_size - CTF.HEADER_SIZE) % (4 * res4['nsamp'] * - res4['nchan']) != 0: - raise RuntimeError('The number of samples is not an even multiple ' - 'of the trial size') - n_samp_tot = (st_size - CTF.HEADER_SIZE) // (4 * res4['nchan']) - n_trial = n_samp_tot // res4['nsamp'] + if (st_size - CTF.HEADER_SIZE) % (4 * res4["nsamp"] * res4["nchan"]) != 0: + raise RuntimeError( + "The number of samples is not an even multiple " "of the trial size" + ) + n_samp_tot = (st_size - CTF.HEADER_SIZE) // (4 * res4["nchan"]) + n_trial = n_samp_tot // res4["nsamp"] n_samp = n_samp_tot if clock_ch is None: - logger.info(' System clock channel is not available, assuming ' - 'all samples to be valid.') - elif system_clock == 'ignore': - logger.info(' System clock channel is available, but ignored.') + logger.info( + " System clock channel is not available, assuming " + "all samples to be valid." + ) + elif system_clock == "ignore": + logger.info(" System clock channel is available, but ignored.") else: # use it - logger.info(' System clock channel is available, checking ' - 'which samples are valid.') + logger.info( + " System clock channel is available, checking " + "which samples are valid." 
+ ) for t in range(n_trial): # Skip to the correct trial - samp_offset = t * res4['nsamp'] - offset = CTF.HEADER_SIZE + (samp_offset * res4['nchan'] + - (clock_ch * res4['nsamp'])) * 4 + samp_offset = t * res4["nsamp"] + offset = ( + CTF.HEADER_SIZE + + (samp_offset * res4["nchan"] + (clock_ch * res4["nsamp"])) * 4 + ) fid.seek(offset, 0) - this_data = np.fromfile(fid, '>i4', res4['nsamp']) - if len(this_data) != res4['nsamp']: - raise RuntimeError('Cannot read data for trial %d' - % (t + 1)) + this_data = np.fromfile(fid, ">i4", res4["nsamp"]) + if len(this_data) != res4["nsamp"]: + raise RuntimeError("Cannot read data for trial %d" % (t + 1)) end = np.where(this_data == 0)[0] if len(end) > 0: n_samp = samp_offset + end[0] break - if n_samp < res4['nsamp']: + if n_samp < res4["nsamp"]: n_trial = 1 - logger.info(' %d x %d = %d samples from %d chs' - % (n_trial, n_samp, n_samp, res4['nchan'])) + logger.info( + " %d x %d = %d samples from %d chs" + % (n_trial, n_samp, n_samp, res4["nchan"]) + ) else: - n_trial = n_samp // res4['nsamp'] + n_trial = n_samp // res4["nsamp"] n_omit = n_samp_tot - n_samp - logger.info(' %d x %d = %d samples from %d chs' - % (n_trial, res4['nsamp'], n_samp, res4['nchan'])) + logger.info( + " %d x %d = %d samples from %d chs" + % (n_trial, res4["nsamp"], n_samp, res4["nchan"]) + ) if n_omit != 0: - logger.info(' %d samples omitted at the end' % n_omit) + logger.info(" %d samples omitted at the end" % n_omit) - return dict(n_samp=n_samp, n_samp_tot=n_samp_tot, block_size=res4['nsamp'], - res4_nsamp=res4['nsamp'], n_chan=res4['nchan']) + return dict( + n_samp=n_samp, + n_samp_tot=n_samp_tot, + block_size=res4["nsamp"], + res4_nsamp=res4["nsamp"], + n_chan=res4["nchan"], + ) diff --git a/mne/io/ctf/eeg.py b/mne/io/ctf/eeg.py index 4e6091abefd..86713c33629 100644 --- a/mne/io/ctf/eeg.py +++ b/mne/io/ctf/eeg.py @@ -14,62 +14,71 @@ from ...transforms import apply_trans -_cardinal_dict = dict(nasion=FIFF.FIFFV_POINT_NASION, - lpa=FIFF.FIFFV_POINT_LPA, left=FIFF.FIFFV_POINT_LPA, - rpa=FIFF.FIFFV_POINT_RPA, right=FIFF.FIFFV_POINT_RPA) +_cardinal_dict = dict( + nasion=FIFF.FIFFV_POINT_NASION, + lpa=FIFF.FIFFV_POINT_LPA, + left=FIFF.FIFFV_POINT_LPA, + rpa=FIFF.FIFFV_POINT_RPA, + right=FIFF.FIFFV_POINT_RPA, +) def _read_eeg(directory): """Read the .eeg file.""" # Missing file is ok - fname, found = _make_ctf_name(directory, 'eeg', raise_error=False) + fname, found = _make_ctf_name(directory, "eeg", raise_error=False) if not found: - logger.info(' Separate EEG position data file not present.') + logger.info(" Separate EEG position data file not present.") return - eeg = dict(labels=list(), kinds=list(), ids=list(), rr=list(), np=0, - assign_to_chs=True, coord_frame=FIFF.FIFFV_MNE_COORD_CTF_HEAD) - with open(fname, 'rb') as fid: + eeg = dict( + labels=list(), + kinds=list(), + ids=list(), + rr=list(), + np=0, + assign_to_chs=True, + coord_frame=FIFF.FIFFV_MNE_COORD_CTF_HEAD, + ) + with open(fname, "rb") as fid: for line in fid: line = line.strip() if len(line) > 0: - parts = line.decode('utf-8').split() + parts = line.decode("utf-8").split() if len(parts) != 5: - raise RuntimeError('Illegal data in EEG position file: %s' - % line) - r = np.array([float(p) for p in parts[2:]]) / 100. 
+ raise RuntimeError("Illegal data in EEG position file: %s" % line) + r = np.array([float(p) for p in parts[2:]]) / 100.0 if (r * r).sum() > 1e-4: label = parts[1] - eeg['labels'].append(label) - eeg['rr'].append(r) + eeg["labels"].append(label) + eeg["rr"].append(r) id_ = _cardinal_dict.get(label.lower(), int(parts[0])) if label.lower() in _cardinal_dict: kind = FIFF.FIFFV_POINT_CARDINAL else: kind = FIFF.FIFFV_POINT_EXTRA - eeg['ids'].append(id_) - eeg['kinds'].append(kind) - eeg['np'] += 1 - logger.info(' Separate EEG position data file read.') + eeg["ids"].append(id_) + eeg["kinds"].append(kind) + eeg["np"] += 1 + logger.info(" Separate EEG position data file read.") return eeg def _read_pos(directory, transformations): """Read the .pos file and return eeg positions as dig extra points.""" - fname = [join(directory, f) for f in listdir(directory) if - f.endswith('.pos')] + fname = [join(directory, f) for f in listdir(directory) if f.endswith(".pos")] if len(fname) < 1: return list() elif len(fname) > 1: - warn(' Found multiple pos files. Extra digitizer points not added.') + warn(" Found multiple pos files. Extra digitizer points not added.") return list() - logger.info(' Reading digitizer points from %s...' % fname) - if transformations['t_ctf_head_head'] is None: - warn(' No transformation found. Extra digitizer points not added.') + logger.info(" Reading digitizer points from %s..." % fname) + if transformations["t_ctf_head_head"] is None: + warn(" No transformation found. Extra digitizer points not added.") return list() fname = fname[0] digs = list() i = 2000 - with open(fname, 'r') as fid: + with open(fname, "r") as fid: for line in fid: line = line.strip() if len(line) > 0: @@ -85,11 +94,15 @@ def _read_pos(directory, transformations): except ValueError: # if id is not an int ident = i i += 1 - dig = dict(kind=FIFF.FIFFV_POINT_EXTRA, ident=ident, r=list(), - coord_frame=FIFF.FIFFV_COORD_HEAD) - r = np.array([float(p) for p in parts[-3:]]) / 100. 
# cm to m + dig = dict( + kind=FIFF.FIFFV_POINT_EXTRA, + ident=ident, + r=list(), + coord_frame=FIFF.FIFFV_COORD_HEAD, + ) + r = np.array([float(p) for p in parts[-3:]]) / 100.0 # cm to m if (r * r).sum() > 1e-4: - r = apply_trans(transformations['t_ctf_head_head'], r) - dig['r'] = r + r = apply_trans(transformations["t_ctf_head_head"], r) + dig["r"] = r digs.append(dig) return digs diff --git a/mne/io/ctf/hc.py b/mne/io/ctf/hc.py index 1911fc84055..7e05cc1ce0a 100644 --- a/mne/io/ctf/hc.py +++ b/mne/io/ctf/hc.py @@ -12,73 +12,79 @@ from ..constants import FIFF -_kind_dict = {'nasion': CTF.CTFV_COIL_NAS, 'left ear': CTF.CTFV_COIL_LPA, - 'right ear': CTF.CTFV_COIL_RPA, 'spare': CTF.CTFV_COIL_SPARE} +_kind_dict = { + "nasion": CTF.CTFV_COIL_NAS, + "left ear": CTF.CTFV_COIL_LPA, + "right ear": CTF.CTFV_COIL_RPA, + "spare": CTF.CTFV_COIL_SPARE, +} -_coord_dict = {'relative to dewar': FIFF.FIFFV_MNE_COORD_CTF_DEVICE, - 'relative to head': FIFF.FIFFV_MNE_COORD_CTF_HEAD} +_coord_dict = { + "relative to dewar": FIFF.FIFFV_MNE_COORD_CTF_DEVICE, + "relative to head": FIFF.FIFFV_MNE_COORD_CTF_HEAD, +} def _read_one_coil_point(fid): """Read coil coordinate information from the hc file.""" # Descriptor - one = '#' - while len(one) > 0 and one[0] == '#': + one = "#" + while len(one) > 0 and one[0] == "#": one = fid.readline() if len(one) == 0: return None - one = one.strip().decode('utf-8') - if 'Unable' in one: + one = one.strip().decode("utf-8") + if "Unable" in one: raise RuntimeError("HPI information not available") # Hopefully this is an unambiguous interpretation p = dict() - p['valid'] = ('measured' in one) + p["valid"] = "measured" in one for key, val in _coord_dict.items(): if key in one: - p['coord_frame'] = val + p["coord_frame"] = val break else: - p['coord_frame'] = -1 + p["coord_frame"] = -1 for key, val in _kind_dict.items(): if key in one: - p['kind'] = val + p["kind"] = val break else: - p['kind'] = -1 + p["kind"] = -1 # Three coordinates - p['r'] = np.empty(3) - for ii, coord in enumerate('xyz'): - sp = fid.readline().decode('utf-8').strip() + p["r"] = np.empty(3) + for ii, coord in enumerate("xyz"): + sp = fid.readline().decode("utf-8").strip() if len(sp) == 0: # blank line continue - sp = sp.split(' ') - if len(sp) != 3 or sp[0] != coord or sp[1] != '=': - raise RuntimeError('Bad line: %s' % one) + sp = sp.split(" ") + if len(sp) != 3 or sp[0] != coord or sp[1] != "=": + raise RuntimeError("Bad line: %s" % one) # We do not deal with centimeters - p['r'][ii] = float(sp[2]) / 100.0 + p["r"][ii] = float(sp[2]) / 100.0 return p def _read_hc(directory): """Read the hc file to get the HPI info and to prepare for coord trans.""" - fname, found = _make_ctf_name(directory, 'hc', raise_error=False) + fname, found = _make_ctf_name(directory, "hc", raise_error=False) if not found: - logger.info(' hc data not present') + logger.info(" hc data not present") return None s = list() - with open(fname, 'rb') as fid: + with open(fname, "rb") as fid: while True: p = _read_one_coil_point(fid) if p is None: # First point bad indicates that the file is empty if len(s) == 0: - logger.info('hc file empty, no data present') + logger.info("hc file empty, no data present") return None # Returns None if at EOF - logger.info(' hc data read.') + logger.info(" hc data read.") return s - if p['valid']: + if p["valid"]: s.append(p) diff --git a/mne/io/ctf/info.py b/mne/io/ctf/info.py index 0afbe3e2836..6f04d23ea57 100644 --- a/mne/io/ctf/info.py +++ b/mne/io/ctf/info.py @@ -11,8 +11,12 @@ import numpy as np from 
...utils import logger, warn, _clean_names -from ...transforms import (apply_trans, _coord_frame_name, invert_transform, - combine_transforms) +from ...transforms import ( + apply_trans, + _coord_frame_name, + invert_transform, + combine_transforms, +) from ...annotations import Annotations from ..meas_info import _empty_info @@ -23,9 +27,11 @@ from .constants import CTF -_ctf_to_fiff = {CTF.CTFV_COIL_LPA: FIFF.FIFFV_POINT_LPA, - CTF.CTFV_COIL_RPA: FIFF.FIFFV_POINT_RPA, - CTF.CTFV_COIL_NAS: FIFF.FIFFV_POINT_NASION} +_ctf_to_fiff = { + CTF.CTFV_COIL_LPA: FIFF.FIFFV_POINT_LPA, + CTF.CTFV_COIL_RPA: FIFF.FIFFV_POINT_RPA, + CTF.CTFV_COIL_NAS: FIFF.FIFFV_POINT_NASION, +} def _pick_isotrak_and_hpi_coils(res4, coils, t): @@ -37,47 +43,59 @@ def _pick_isotrak_and_hpi_coils(res4, coils, t): n_coil_dev = 0 n_coil_head = 0 for p in coils: - if p['valid']: - if p['kind'] in [CTF.CTFV_COIL_LPA, CTF.CTFV_COIL_RPA, - CTF.CTFV_COIL_NAS]: + if p["valid"]: + if p["kind"] in [CTF.CTFV_COIL_LPA, CTF.CTFV_COIL_RPA, CTF.CTFV_COIL_NAS]: kind = FIFF.FIFFV_POINT_CARDINAL - ident = _ctf_to_fiff[p['kind']] + ident = _ctf_to_fiff[p["kind"]] else: # CTF.CTFV_COIL_SPARE kind = FIFF.FIFFV_POINT_HPI - ident = p['kind'] - if p['coord_frame'] == FIFF.FIFFV_MNE_COORD_CTF_DEVICE: - if t is None or t['t_ctf_dev_dev'] is None: - raise RuntimeError('No coordinate transformation ' - 'available for HPI coil locations') - d = dict(kind=kind, ident=ident, - r=apply_trans(t['t_ctf_dev_dev'], p['r']), - coord_frame=FIFF.FIFFV_COORD_UNKNOWN) - hpi_result['dig_points'].append(d) + ident = p["kind"] + if p["coord_frame"] == FIFF.FIFFV_MNE_COORD_CTF_DEVICE: + if t is None or t["t_ctf_dev_dev"] is None: + raise RuntimeError( + "No coordinate transformation " + "available for HPI coil locations" + ) + d = dict( + kind=kind, + ident=ident, + r=apply_trans(t["t_ctf_dev_dev"], p["r"]), + coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + ) + hpi_result["dig_points"].append(d) n_coil_dev += 1 - elif p['coord_frame'] == FIFF.FIFFV_MNE_COORD_CTF_HEAD: - if t is None or t['t_ctf_head_head'] is None: - raise RuntimeError('No coordinate transformation ' - 'available for (virtual) Polhemus data') - d = dict(kind=kind, ident=ident, - r=apply_trans(t['t_ctf_head_head'], p['r']), - coord_frame=FIFF.FIFFV_COORD_HEAD) + elif p["coord_frame"] == FIFF.FIFFV_MNE_COORD_CTF_HEAD: + if t is None or t["t_ctf_head_head"] is None: + raise RuntimeError( + "No coordinate transformation " + "available for (virtual) Polhemus data" + ) + d = dict( + kind=kind, + ident=ident, + r=apply_trans(t["t_ctf_head_head"], p["r"]), + coord_frame=FIFF.FIFFV_COORD_HEAD, + ) dig.append(d) n_coil_head += 1 if n_coil_head > 0: - logger.info(' Polhemus data for %d HPI coils added' % n_coil_head) + logger.info(" Polhemus data for %d HPI coils added" % n_coil_head) if n_coil_dev > 0: - logger.info(' Device coordinate locations for %d HPI coils added' - % n_coil_dev) + logger.info( + " Device coordinate locations for %d HPI coils added" % n_coil_dev + ) return dig, [hpi_result] def _convert_time(date_str, time_str): """Convert date and time strings to float time.""" - if date_str == time_str == '': - date_str = '01/01/1970' - time_str = '00:00:00' - logger.info('No date or time found, setting to the start of the ' - 'POSIX epoch (1970/01/01 midnight)') + if date_str == time_str == "": + date_str = "01/01/1970" + time_str = "00:00:00" + logger.info( + "No date or time found, setting to the start of the " + "POSIX epoch (1970/01/01 midnight)" + ) for fmt in ("%d/%m/%Y", "%d-%b-%Y", "%a, %b %d, %Y", 
"%Y/%m/%d"): try: @@ -88,12 +106,13 @@ def _convert_time(date_str, time_str): break else: raise RuntimeError( - 'Illegal date: %s.\nIf the language of the date does not ' - 'correspond to your local machine\'s language try to set the ' - 'locale to the language of the date string:\n' - 'locale.setlocale(locale.LC_ALL, "en_US")' % date_str) + "Illegal date: %s.\nIf the language of the date does not " + "correspond to your local machine's language try to set the " + "locale to the language of the date string:\n" + 'locale.setlocale(locale.LC_ALL, "en_US")' % date_str + ) - for fmt in ('%H:%M:%S', '%H:%M'): + for fmt in ("%H:%M:%S", "%H:%M"): try: time = strptime(time_str, fmt) except ValueError: @@ -101,15 +120,25 @@ def _convert_time(date_str, time_str): else: break else: - raise RuntimeError('Illegal time: %s' % time_str) + raise RuntimeError("Illegal time: %s" % time_str) # MNE-C uses mktime which uses local time, but here we instead decouple # conversion location from the process, and instead assume that the # acquisition was in GMT. This will be wrong for most sites, but at least # the value we obtain here won't depend on the geographical location # that the file was converted. - res = timegm((date.tm_year, date.tm_mon, date.tm_mday, - time.tm_hour, time.tm_min, time.tm_sec, - date.tm_wday, date.tm_yday, date.tm_isdst)) + res = timegm( + ( + date.tm_year, + date.tm_mon, + date.tm_mday, + time.tm_hour, + time.tm_min, + time.tm_sec, + date.tm_wday, + date.tm_yday, + date.tm_isdst, + ) + ) return res @@ -118,15 +147,15 @@ def _get_plane_vectors(ez): assert ez.shape == (3,) ez_len = np.sqrt(np.sum(ez * ez)) if ez_len == 0: - raise RuntimeError('Zero length normal. Cannot proceed.') + raise RuntimeError("Zero length normal. Cannot proceed.") if np.abs(ez_len - np.abs(ez[2])) < 1e-5: # ez already in z-direction - ex = np.array([1., 0., 0.]) + ex = np.array([1.0, 0.0, 0.0]) else: ex = np.zeros(3) if ez[1] < ez[2]: - ex[0 if ez[0] < ez[1] else 1] = 1. + ex[0 if ez[0] < ez[1] else 1] = 1.0 else: - ex[0 if ez[0] < ez[2] else 2] = 1. + ex[0 if ez[0] < ez[2] else 2] = 1.0 ez /= ez_len ex -= np.dot(ez, ex) * ez ex /= np.sqrt(np.sum(ex * ex)) @@ -136,16 +165,17 @@ def _get_plane_vectors(ez): def _at_origin(x): """Determine if a vector is at the origin.""" - return (np.sum(x * x) < 1e-8) + return np.sum(x * x) < 1e-8 def _check_comp_ch(cch, kind, desired=None): if desired is None: - desired = cch['grad_order_no'] - if cch['grad_order_no'] != desired: - raise RuntimeError('%s channel with inconsistent compensation ' - 'grade %s, should be %s' - % (kind, cch['grad_order_no'], desired)) + desired = cch["grad_order_no"] + if cch["grad_order_no"] != desired: + raise RuntimeError( + "%s channel with inconsistent compensation " + "grade %s, should be %s" % (kind, cch["grad_order_no"], desired) + ) return desired @@ -154,49 +184,64 @@ def _convert_channel_info(res4, t, use_eeg_pos): nmeg = neeg = nstim = nmisc = nref = 0 chs = list() this_comp = None - for k, cch in enumerate(res4['chs']): - cal = float(1. 
/ (cch['proper_gain'] * cch['qgain'])) - ch = dict(scanno=k + 1, range=1., cal=cal, loc=np.full(12, np.nan), - unit_mul=FIFF.FIFF_UNITM_NONE, ch_name=cch['ch_name'][:15], - coil_type=FIFF.FIFFV_COIL_NONE) + for k, cch in enumerate(res4["chs"]): + cal = float(1.0 / (cch["proper_gain"] * cch["qgain"])) + ch = dict( + scanno=k + 1, + range=1.0, + cal=cal, + loc=np.full(12, np.nan), + unit_mul=FIFF.FIFF_UNITM_NONE, + ch_name=cch["ch_name"][:15], + coil_type=FIFF.FIFFV_COIL_NONE, + ) del k chs.append(ch) # Create the channel position information - if cch['sensor_type_index'] in (CTF.CTFV_REF_MAG_CH, - CTF.CTFV_REF_GRAD_CH, - CTF.CTFV_MEG_CH): + if cch["sensor_type_index"] in ( + CTF.CTFV_REF_MAG_CH, + CTF.CTFV_REF_GRAD_CH, + CTF.CTFV_MEG_CH, + ): # Extra check for a valid MEG channel - if np.sum(cch['coil']['pos'][0] ** 2) < 1e-6 or \ - np.sum(cch['coil']['norm'][0] ** 2) < 1e-6: + if ( + np.sum(cch["coil"]["pos"][0] ** 2) < 1e-6 + or np.sum(cch["coil"]["norm"][0] ** 2) < 1e-6 + ): nmisc += 1 - ch.update(logno=nmisc, coord_frame=FIFF.FIFFV_COORD_UNKNOWN, - kind=FIFF.FIFFV_MISC_CH, unit=FIFF.FIFF_UNIT_V) - text = 'MEG' - if cch['sensor_type_index'] != CTF.CTFV_MEG_CH: - text += ' ref' - warn('%s channel %s did not have position assigned, so ' - 'it was changed to a MISC channel' - % (text, ch['ch_name'])) + ch.update( + logno=nmisc, + coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + kind=FIFF.FIFFV_MISC_CH, + unit=FIFF.FIFF_UNIT_V, + ) + text = "MEG" + if cch["sensor_type_index"] != CTF.CTFV_MEG_CH: + text += " ref" + warn( + "%s channel %s did not have position assigned, so " + "it was changed to a MISC channel" % (text, ch["ch_name"]) + ) continue - ch['unit'] = FIFF.FIFF_UNIT_T + ch["unit"] = FIFF.FIFF_UNIT_T # Set up the local coordinate frame - r0 = cch['coil']['pos'][0].copy() - ez = cch['coil']['norm'][0].copy() + r0 = cch["coil"]["pos"][0].copy() + ez = cch["coil"]["norm"][0].copy() # It turns out that positive proper_gain requires swapping # of the normal direction - if cch['proper_gain'] > 0.0: + if cch["proper_gain"] > 0.0: ez *= -1 # Check how the other vectors should be defined off_diag = False # Default: ex and ey are arbitrary in the plane normal to ez - if cch['sensor_type_index'] == CTF.CTFV_REF_GRAD_CH: + if cch["sensor_type_index"] == CTF.CTFV_REF_GRAD_CH: # The off-diagonal gradiometers are an exception: # # We use the same convention for ex as for Neuromag planar # gradiometers: ex pointing in the positive gradient direction - diff = cch['coil']['pos'][0] - cch['coil']['pos'][1] + diff = cch["coil"]["pos"][0] - cch["coil"]["pos"][1] size = np.sqrt(np.sum(diff * diff)) - if size > 0.: + if size > 0.0: diff /= size # Is ez normal to the line joining the coils? 
if np.abs(np.dot(diff, ez)) < 1e-3: @@ -210,69 +255,84 @@ def _convert_channel_info(res4, t, use_eeg_pos): else: ex, ey = _get_plane_vectors(ez) # Transform into a Neuromag-like device coordinate system - ch['loc'] = np.concatenate([ - apply_trans(t['t_ctf_dev_dev'], r0), - apply_trans(t['t_ctf_dev_dev'], ex, move=False), - apply_trans(t['t_ctf_dev_dev'], ey, move=False), - apply_trans(t['t_ctf_dev_dev'], ez, move=False)]) + ch["loc"] = np.concatenate( + [ + apply_trans(t["t_ctf_dev_dev"], r0), + apply_trans(t["t_ctf_dev_dev"], ex, move=False), + apply_trans(t["t_ctf_dev_dev"], ey, move=False), + apply_trans(t["t_ctf_dev_dev"], ez, move=False), + ] + ) del r0, ex, ey, ez # Set the coil type - if cch['sensor_type_index'] == CTF.CTFV_REF_MAG_CH: - ch['kind'] = FIFF.FIFFV_REF_MEG_CH - ch['coil_type'] = FIFF.FIFFV_COIL_CTF_REF_MAG + if cch["sensor_type_index"] == CTF.CTFV_REF_MAG_CH: + ch["kind"] = FIFF.FIFFV_REF_MEG_CH + ch["coil_type"] = FIFF.FIFFV_COIL_CTF_REF_MAG nref += 1 - ch['logno'] = nref - elif cch['sensor_type_index'] == CTF.CTFV_REF_GRAD_CH: - ch['kind'] = FIFF.FIFFV_REF_MEG_CH + ch["logno"] = nref + elif cch["sensor_type_index"] == CTF.CTFV_REF_GRAD_CH: + ch["kind"] = FIFF.FIFFV_REF_MEG_CH if off_diag: - ch['coil_type'] = FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD + ch["coil_type"] = FIFF.FIFFV_COIL_CTF_OFFDIAG_REF_GRAD else: - ch['coil_type'] = FIFF.FIFFV_COIL_CTF_REF_GRAD + ch["coil_type"] = FIFF.FIFFV_COIL_CTF_REF_GRAD nref += 1 - ch['logno'] = nref + ch["logno"] = nref else: - this_comp = _check_comp_ch(cch, 'Gradiometer', this_comp) - ch['kind'] = FIFF.FIFFV_MEG_CH - ch['coil_type'] = FIFF.FIFFV_COIL_CTF_GRAD + this_comp = _check_comp_ch(cch, "Gradiometer", this_comp) + ch["kind"] = FIFF.FIFFV_MEG_CH + ch["coil_type"] = FIFF.FIFFV_COIL_CTF_GRAD nmeg += 1 - ch['logno'] = nmeg + ch["logno"] = nmeg # Encode the software gradiometer order - ch['coil_type'] = int( - ch['coil_type'] | (cch['grad_order_no'] << 16)) - ch['coord_frame'] = FIFF.FIFFV_COORD_DEVICE - elif cch['sensor_type_index'] == CTF.CTFV_EEG_CH: + ch["coil_type"] = int(ch["coil_type"] | (cch["grad_order_no"] << 16)) + ch["coord_frame"] = FIFF.FIFFV_COORD_DEVICE + elif cch["sensor_type_index"] == CTF.CTFV_EEG_CH: coord_frame = FIFF.FIFFV_COORD_HEAD if use_eeg_pos: # EEG electrode coordinates may be present but in the # CTF head frame - ch['loc'][:3] = cch['coil']['pos'][0] - if not _at_origin(ch['loc'][:3]): - if t['t_ctf_head_head'] is None: - warn('EEG electrode (%s) location omitted because of ' - 'missing HPI information' % ch['ch_name']) - ch['loc'].fill(np.nan) + ch["loc"][:3] = cch["coil"]["pos"][0] + if not _at_origin(ch["loc"][:3]): + if t["t_ctf_head_head"] is None: + warn( + "EEG electrode (%s) location omitted because of " + "missing HPI information" % ch["ch_name"] + ) + ch["loc"].fill(np.nan) coord_frame = FIFF.FIFFV_MNE_COORD_CTF_HEAD else: - ch['loc'][:3] = apply_trans( - t['t_ctf_head_head'], ch['loc'][:3]) + ch["loc"][:3] = apply_trans(t["t_ctf_head_head"], ch["loc"][:3]) neeg += 1 - ch.update(logno=neeg, kind=FIFF.FIFFV_EEG_CH, - unit=FIFF.FIFF_UNIT_V, coord_frame=coord_frame, - coil_type=FIFF.FIFFV_COIL_EEG) - elif cch['sensor_type_index'] == CTF.CTFV_STIM_CH: + ch.update( + logno=neeg, + kind=FIFF.FIFFV_EEG_CH, + unit=FIFF.FIFF_UNIT_V, + coord_frame=coord_frame, + coil_type=FIFF.FIFFV_COIL_EEG, + ) + elif cch["sensor_type_index"] == CTF.CTFV_STIM_CH: nstim += 1 - ch.update(logno=nstim, coord_frame=FIFF.FIFFV_COORD_UNKNOWN, - kind=FIFF.FIFFV_STIM_CH, unit=FIFF.FIFF_UNIT_V) + ch.update( + logno=nstim, + 
coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + kind=FIFF.FIFFV_STIM_CH, + unit=FIFF.FIFF_UNIT_V, + ) else: nmisc += 1 - ch.update(logno=nmisc, coord_frame=FIFF.FIFFV_COORD_UNKNOWN, - kind=FIFF.FIFFV_MISC_CH, unit=FIFF.FIFF_UNIT_V) + ch.update( + logno=nmisc, + coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + kind=FIFF.FIFFV_MISC_CH, + unit=FIFF.FIFF_UNIT_V, + ) return chs def _comp_sort_keys(c): """Sort the compensation data.""" - return (int(c['coeff_type']), int(c['scanno'])) + return (int(c["coeff_type"]), int(c["scanno"])) def _check_comp(comp): @@ -280,77 +340,87 @@ def _check_comp(comp): ref_sens = None kind = -1 for k, c_k in enumerate(comp): - if c_k['coeff_type'] != kind: + if c_k["coeff_type"] != kind: c_ref = c_k - ref_sens = c_ref['sensors'] - kind = c_k['coeff_type'] - elif not c_k['sensors'] == ref_sens: - raise RuntimeError('Cannot use an uneven compensation matrix') + ref_sens = c_ref["sensors"] + kind = c_k["coeff_type"] + elif not c_k["sensors"] == ref_sens: + raise RuntimeError("Cannot use an uneven compensation matrix") def _conv_comp(comp, first, last, chs): """Add a new converted compensation data item.""" - ch_names = [c['ch_name'] for c in chs] - n_col = comp[first]['ncoeff'] - col_names = comp[first]['sensors'][:n_col] - row_names = [comp[p]['sensor_name'] for p in range(first, last + 1)] + ch_names = [c["ch_name"] for c in chs] + n_col = comp[first]["ncoeff"] + col_names = comp[first]["sensors"][:n_col] + row_names = [comp[p]["sensor_name"] for p in range(first, last + 1)] mask = np.in1d(col_names, ch_names) # missing channels excluded col_names = np.array(col_names)[mask].tolist() n_col = len(col_names) n_row = len(row_names) - ccomp = dict(ctfkind=comp[first]['coeff_type'], save_calibrated=False) + ccomp = dict(ctfkind=comp[first]["coeff_type"], save_calibrated=False) _add_kind(ccomp) data = np.empty((n_row, n_col)) - for ii, coeffs in enumerate(comp[first:last + 1]): + for ii, coeffs in enumerate(comp[first : last + 1]): # Pick the elements to the matrix - data[ii, :] = coeffs['coeffs'][mask] - ccomp['data'] = dict(row_names=row_names, col_names=col_names, - data=data, nrow=len(row_names), ncol=len(col_names)) - mk = ('proper_gain', 'qgain') + data[ii, :] = coeffs["coeffs"][mask] + ccomp["data"] = dict( + row_names=row_names, + col_names=col_names, + data=data, + nrow=len(row_names), + ncol=len(col_names), + ) + mk = ("proper_gain", "qgain") _calibrate_comp(ccomp, chs, row_names, col_names, mult_keys=mk, flip=True) return ccomp def _convert_comp_data(res4): """Convert the compensation data into named matrices.""" - if res4['ncomp'] == 0: + if res4["ncomp"] == 0: return # Sort the coefficients in our favorite order - res4['comp'] = sorted(res4['comp'], key=_comp_sort_keys) + res4["comp"] = sorted(res4["comp"], key=_comp_sort_keys) # Check that all items for a given compensation type have the correct # number of channels - _check_comp(res4['comp']) + _check_comp(res4["comp"]) # Create named matrices first = 0 kind = -1 comps = list() - for k in range(len(res4['comp'])): - if res4['comp'][k]['coeff_type'] != kind: + for k in range(len(res4["comp"])): + if res4["comp"][k]["coeff_type"] != kind: if k > 0: - comps.append(_conv_comp(res4['comp'], first, k - 1, - res4['chs'])) - kind = res4['comp'][k]['coeff_type'] + comps.append(_conv_comp(res4["comp"], first, k - 1, res4["chs"])) + kind = res4["comp"][k]["coeff_type"] first = k - comps.append(_conv_comp(res4['comp'], first, k, res4['chs'])) + comps.append(_conv_comp(res4["comp"], first, k, res4["chs"])) return comps def 
_pick_eeg_pos(c): """Pick EEG positions.""" - eeg = dict(coord_frame=FIFF.FIFFV_COORD_HEAD, assign_to_chs=False, - labels=list(), ids=list(), rr=list(), kinds=list(), np=0) - for ch in c['chs']: - if ch['kind'] == FIFF.FIFFV_EEG_CH and not _at_origin(ch['loc'][:3]): - eeg['labels'].append(ch['ch_name']) - eeg['ids'].append(ch['logno']) - eeg['rr'].append(ch['loc'][:3]) - eeg['kinds'].append(FIFF.FIFFV_POINT_EEG) - eeg['np'] += 1 - if eeg['np'] == 0: + eeg = dict( + coord_frame=FIFF.FIFFV_COORD_HEAD, + assign_to_chs=False, + labels=list(), + ids=list(), + rr=list(), + kinds=list(), + np=0, + ) + for ch in c["chs"]: + if ch["kind"] == FIFF.FIFFV_EEG_CH and not _at_origin(ch["loc"][:3]): + eeg["labels"].append(ch["ch_name"]) + eeg["ids"].append(ch["logno"]) + eeg["rr"].append(ch["loc"][:3]) + eeg["kinds"].append(FIFF.FIFFV_POINT_EEG) + eeg["np"] += 1 + if eeg["np"] == 0: return None - logger.info('Picked positions of %d EEG channels from channel info' - % eeg['np']) + logger.info("Picked positions of %d EEG channels from channel info" % eeg["np"]) return eeg @@ -358,96 +428,104 @@ def _add_eeg_pos(eeg, t, c): """Pick the (virtual) EEG position data.""" if eeg is None: return - if t is None or t['t_ctf_head_head'] is None: - raise RuntimeError('No coordinate transformation available for EEG ' - 'position data') + if t is None or t["t_ctf_head_head"] is None: + raise RuntimeError( + "No coordinate transformation available for EEG " "position data" + ) eeg_assigned = 0 - if eeg['assign_to_chs']: - for k in range(eeg['np']): + if eeg["assign_to_chs"]: + for k in range(eeg["np"]): # Look for a channel name match - for ch in c['chs']: - if ch['ch_name'].lower() == eeg['labels'][k].lower(): - r0 = ch['loc'][:3] - r0[:] = eeg['rr'][k] - if eeg['coord_frame'] == FIFF.FIFFV_MNE_COORD_CTF_HEAD: - r0[:] = apply_trans(t['t_ctf_head_head'], r0) - elif eeg['coord_frame'] != FIFF.FIFFV_COORD_HEAD: + for ch in c["chs"]: + if ch["ch_name"].lower() == eeg["labels"][k].lower(): + r0 = ch["loc"][:3] + r0[:] = eeg["rr"][k] + if eeg["coord_frame"] == FIFF.FIFFV_MNE_COORD_CTF_HEAD: + r0[:] = apply_trans(t["t_ctf_head_head"], r0) + elif eeg["coord_frame"] != FIFF.FIFFV_COORD_HEAD: raise RuntimeError( - 'Illegal coordinate frame for EEG electrode ' - 'positions : %s' - % _coord_frame_name(eeg['coord_frame'])) + "Illegal coordinate frame for EEG electrode " + "positions : %s" % _coord_frame_name(eeg["coord_frame"]) + ) # Use the logical channel number as an identifier - eeg['ids'][k] = ch['logno'] - eeg['kinds'][k] = FIFF.FIFFV_POINT_EEG + eeg["ids"][k] = ch["logno"] + eeg["kinds"][k] = FIFF.FIFFV_POINT_EEG eeg_assigned += 1 break # Add these to the Polhemus data fid_count = eeg_count = extra_count = 0 - for k in range(eeg['np']): - d = dict(r=eeg['rr'][k].copy(), kind=eeg['kinds'][k], - ident=eeg['ids'][k], coord_frame=FIFF.FIFFV_COORD_HEAD) - c['dig'].append(d) - if eeg['coord_frame'] == FIFF.FIFFV_MNE_COORD_CTF_HEAD: - d['r'] = apply_trans(t['t_ctf_head_head'], d['r']) - elif eeg['coord_frame'] != FIFF.FIFFV_COORD_HEAD: - raise RuntimeError('Illegal coordinate frame for EEG electrode ' - 'positions: %s' - % _coord_frame_name(eeg['coord_frame'])) - if eeg['kinds'][k] == FIFF.FIFFV_POINT_CARDINAL: + for k in range(eeg["np"]): + d = dict( + r=eeg["rr"][k].copy(), + kind=eeg["kinds"][k], + ident=eeg["ids"][k], + coord_frame=FIFF.FIFFV_COORD_HEAD, + ) + c["dig"].append(d) + if eeg["coord_frame"] == FIFF.FIFFV_MNE_COORD_CTF_HEAD: + d["r"] = apply_trans(t["t_ctf_head_head"], d["r"]) + elif eeg["coord_frame"] != 
FIFF.FIFFV_COORD_HEAD: + raise RuntimeError( + "Illegal coordinate frame for EEG electrode " + "positions: %s" % _coord_frame_name(eeg["coord_frame"]) + ) + if eeg["kinds"][k] == FIFF.FIFFV_POINT_CARDINAL: fid_count += 1 - elif eeg['kinds'][k] == FIFF.FIFFV_POINT_EEG: + elif eeg["kinds"][k] == FIFF.FIFFV_POINT_EEG: eeg_count += 1 else: extra_count += 1 if eeg_assigned > 0: - logger.info(' %d EEG electrode locations assigned to channel info.' - % eeg_assigned) - for count, kind in zip((fid_count, eeg_count, extra_count), - ('fiducials', 'EEG locations', 'extra points')): + logger.info( + " %d EEG electrode locations assigned to channel info." % eeg_assigned + ) + for count, kind in zip( + (fid_count, eeg_count, extra_count), + ("fiducials", "EEG locations", "extra points"), + ): if count > 0: - logger.info(' %d %s added to Polhemus data.' % (count, kind)) + logger.info(" %d %s added to Polhemus data." % (count, kind)) -_filt_map = {CTF.CTFV_FILTER_LOWPASS: 'lowpass', - CTF.CTFV_FILTER_HIGHPASS: 'highpass'} +_filt_map = {CTF.CTFV_FILTER_LOWPASS: "lowpass", CTF.CTFV_FILTER_HIGHPASS: "highpass"} def _compose_meas_info(res4, coils, trans, eeg): """Create meas info from CTF data.""" - info = _empty_info(res4['sfreq']) + info = _empty_info(res4["sfreq"]) # Collect all the necessary data from the structures read - info['meas_id'] = get_new_file_id() - info['meas_id']['usecs'] = 0 - info['meas_id']['secs'] = _convert_time(res4['data_date'], - res4['data_time']) - info['meas_date'] = (info['meas_id']['secs'], info['meas_id']['usecs']) - info['experimenter'] = res4['nf_operator'] - info['subject_info'] = dict(his_id=res4['nf_subject_id']) - for filt in res4['filters']: - if filt['type'] in _filt_map: - info[_filt_map[filt['type']]] = filt['freq'] - info['dig'], info['hpi_results'] = _pick_isotrak_and_hpi_coils( - res4, coils, trans) + info["meas_id"] = get_new_file_id() + info["meas_id"]["usecs"] = 0 + info["meas_id"]["secs"] = _convert_time(res4["data_date"], res4["data_time"]) + info["meas_date"] = (info["meas_id"]["secs"], info["meas_id"]["usecs"]) + info["experimenter"] = res4["nf_operator"] + info["subject_info"] = dict(his_id=res4["nf_subject_id"]) + for filt in res4["filters"]: + if filt["type"] in _filt_map: + info[_filt_map[filt["type"]]] = filt["freq"] + info["dig"], info["hpi_results"] = _pick_isotrak_and_hpi_coils(res4, coils, trans) if trans is not None: - if len(info['hpi_results']) > 0: - info['hpi_results'][0]['coord_trans'] = trans['t_ctf_head_head'] - if trans['t_dev_head'] is not None: - info['dev_head_t'] = trans['t_dev_head'] - info['dev_ctf_t'] = combine_transforms( - trans['t_dev_head'], - invert_transform(trans['t_ctf_head_head']), - FIFF.FIFFV_COORD_DEVICE, FIFF.FIFFV_MNE_COORD_CTF_HEAD) - if trans['t_ctf_head_head'] is not None: - info['ctf_head_t'] = trans['t_ctf_head_head'] - info['chs'] = _convert_channel_info(res4, trans, eeg is None) - info['comps'] = _convert_comp_data(res4) + if len(info["hpi_results"]) > 0: + info["hpi_results"][0]["coord_trans"] = trans["t_ctf_head_head"] + if trans["t_dev_head"] is not None: + info["dev_head_t"] = trans["t_dev_head"] + info["dev_ctf_t"] = combine_transforms( + trans["t_dev_head"], + invert_transform(trans["t_ctf_head_head"]), + FIFF.FIFFV_COORD_DEVICE, + FIFF.FIFFV_MNE_COORD_CTF_HEAD, + ) + if trans["t_ctf_head_head"] is not None: + info["ctf_head_t"] = trans["t_ctf_head_head"] + info["chs"] = _convert_channel_info(res4, trans, eeg is None) + info["comps"] = _convert_comp_data(res4) if eeg is None: # Pick EEG locations from chan 
info if not read from a separate file eeg = _pick_eeg_pos(info) _add_eeg_pos(eeg, trans, info) - logger.info(' Measurement info composed.') + logger.info(" Measurement info composed.") info._unlocked = False info._update_redundant() return info @@ -455,17 +533,17 @@ def _compose_meas_info(res4, coils, trans, eeg): def _read_bad_chans(directory, info): """Read Bad channel list and match to internal names.""" - fname = op.join(directory, 'BadChannels') + fname = op.join(directory, "BadChannels") if not op.exists(fname): return [] - mapping = dict(zip(_clean_names(info['ch_names']), info['ch_names'])) - with open(fname, 'r') as fid: + mapping = dict(zip(_clean_names(info["ch_names"]), info["ch_names"])) + with open(fname, "r") as fid: bad_chans = [mapping[f.strip()] for f in fid.readlines()] return bad_chans def _annotate_bad_segments(directory, start_time, meas_date): - fname = op.join(directory, 'bad.segments') + fname = op.join(directory, "bad.segments") if not op.exists(fname): return None @@ -473,10 +551,10 @@ def _annotate_bad_segments(directory, start_time, meas_date): onsets = [] durations = [] desc = [] - with open(fname, 'r') as fid: + with open(fname, "r") as fid: for f in fid.readlines(): tmp = f.strip().split() - desc.append('bad_%s' % tmp[0]) + desc.append("bad_%s" % tmp[0]) onsets.append(np.float64(tmp[1]) - start_time) durations.append(np.float64(tmp[2]) - np.float64(tmp[1])) # return None if there are no bad segments diff --git a/mne/io/ctf/markers.py b/mne/io/ctf/markers.py index e129e8db505..1d000c00aff 100644 --- a/mne/io/ctf/markers.py +++ b/mne/io/ctf/markers.py @@ -18,27 +18,25 @@ def consume(fid, predicate): # just a consumer to move around conveniently def parse_marker(string): # XXX: there should be a nicer way to do that data = np.genfromtxt( - BytesIO(string.encode()), dtype=[('trial', int), ('sync', float)]) - return int(data['trial']), float(data['sync']) + BytesIO(string.encode()), dtype=[("trial", int), ("sync", float)] + ) + return int(data["trial"]), float(data["sync"]) markers = dict() with open(fname) as fid: - consume(fid, lambda line: not line.startswith('NUMBER OF MARKERS:')) + consume(fid, lambda line: not line.startswith("NUMBER OF MARKERS:")) num_of_markers = int(fid.readline()) for _ in range(num_of_markers): - consume(fid, lambda line: not line.startswith('NAME:')) - label = fid.readline().strip('\n') + consume(fid, lambda line: not line.startswith("NAME:")) + label = fid.readline().strip("\n") - consume( - fid, lambda line: not line.startswith('NUMBER OF SAMPLES:')) + consume(fid, lambda line: not line.startswith("NUMBER OF SAMPLES:")) n_markers = int(fid.readline()) - consume(fid, lambda line: not line.startswith('LIST OF SAMPLES:')) + consume(fid, lambda line: not line.startswith("LIST OF SAMPLES:")) next(fid) # skip the samples header - markers[label] = [ - parse_marker(next(fid)) for _ in range(n_markers) - ] + markers[label] = [parse_marker(next(fid)) for _ in range(n_markers)] return markers @@ -49,35 +47,42 @@ def _get_res4_info_needed_by_markers(directory): # instead of parsing the entire res4 file. 
res4 = _read_res4(directory) - total_offset_duration = res4['pre_trig_pts'] / res4['sfreq'] - trial_duration = res4['nsamp'] / res4['sfreq'] + total_offset_duration = res4["pre_trig_pts"] / res4["sfreq"] + trial_duration = res4["nsamp"] / res4["sfreq"] - meas_date = (_convert_time(res4['data_date'], - res4['data_time']), 0) + meas_date = (_convert_time(res4["data_date"], res4["data_time"]), 0) return total_offset_duration, trial_duration, meas_date def _read_annotations_ctf(directory): - total_offset, trial_duration, meas_date \ - = _get_res4_info_needed_by_markers(directory) - return _read_annotations_ctf_call(directory, total_offset, trial_duration, - meas_date) + total_offset, trial_duration, meas_date = _get_res4_info_needed_by_markers( + directory + ) + return _read_annotations_ctf_call( + directory, total_offset, trial_duration, meas_date + ) -def _read_annotations_ctf_call(directory, total_offset, trial_duration, - meas_date): - fname = op.join(directory, 'MarkerFile.mrk') +def _read_annotations_ctf_call(directory, total_offset, trial_duration, meas_date): + fname = op.join(directory, "MarkerFile.mrk") if not op.exists(fname): return Annotations(list(), list(), list(), orig_time=meas_date) else: markers = _get_markers(fname) - onset = [synctime + (trialnum * trial_duration) + total_offset - for _, m in markers.items() for (trialnum, synctime) in m] - - description = np.concatenate([ - np.repeat(label, len(m)) for label, m in markers.items() - ]) - - return Annotations(onset=onset, duration=np.zeros_like(onset), - description=description, orig_time=meas_date) + onset = [ + synctime + (trialnum * trial_duration) + total_offset + for _, m in markers.items() + for (trialnum, synctime) in m + ] + + description = np.concatenate( + [np.repeat(label, len(m)) for label, m in markers.items()] + ) + + return Annotations( + onset=onset, + duration=np.zeros_like(onset), + description=description, + orig_time=meas_date, + ) diff --git a/mne/io/ctf/res4.py b/mne/io/ctf/res4.py index 7b0a4e2b9e6..b5c0f884c99 100644 --- a/mne/io/ctf/res4.py +++ b/mne/io/ctf/res4.py @@ -15,40 +15,40 @@ def _make_ctf_name(directory, extra, raise_error=True): """Make a CTF name.""" - fname = op.join(directory, op.basename(directory)[:-3] + '.' + extra) + fname = op.join(directory, op.basename(directory)[:-3] + "." 
+ extra) found = True if not op.isfile(fname): if raise_error: - raise OSError('Standard file %s not found' % fname) + raise OSError("Standard file %s not found" % fname) found = False return fname, found def _read_double(fid, n=1): """Read a double.""" - return np.fromfile(fid, '>f8', n) + return np.fromfile(fid, ">f8", n) def _read_string(fid, n_bytes, decode=True): """Read string.""" s0 = fid.read(n_bytes) - s = s0.split(b'\x00')[0] - return s.decode('utf-8') if decode else s + s = s0.split(b"\x00")[0] + return s.decode("utf-8") if decode else s def _read_ustring(fid, n_bytes): """Read unsigned character string.""" - return np.fromfile(fid, '>B', n_bytes) + return np.fromfile(fid, ">B", n_bytes) def _read_int2(fid): """Read int from short.""" - return np.fromfile(fid, '>i2', 1)[0] + return np.fromfile(fid, ">i2", 1)[0] def _read_int(fid): """Read a 32-bit integer.""" - return np.fromfile(fid, '>i4', 1)[0] + return np.fromfile(fid, ">i4", 1)[0] def _move_to_next(fid, byte=8): @@ -62,140 +62,157 @@ def _move_to_next(fid, byte=8): def _read_filter(fid): """Read filter information.""" f = dict() - f['freq'] = _read_double(fid)[0] - f['class'] = _read_int(fid) - f['type'] = _read_int(fid) - f['npar'] = _read_int2(fid) - f['pars'] = _read_double(fid, f['npar']) + f["freq"] = _read_double(fid)[0] + f["class"] = _read_int(fid) + f["type"] = _read_int(fid) + f["npar"] = _read_int2(fid) + f["pars"] = _read_double(fid, f["npar"]) return f def _read_comp_coeff(fid, d): """Read compensation coefficients.""" # Read the coefficients and initialize - d['ncomp'] = _read_int2(fid) - d['comp'] = list() + d["ncomp"] = _read_int2(fid) + d["comp"] = list() # Read each record - dt = np.dtype([ - ('sensor_name', 'S32'), - ('coeff_type', '>i4'), ('d0', '>i4'), - ('ncoeff', '>i2'), - ('sensors', 'S%s' % CTF.CTFV_SENSOR_LABEL, CTF.CTFV_MAX_BALANCING), - ('coeffs', '>f8', CTF.CTFV_MAX_BALANCING)]) - comps = np.fromfile(fid, dt, d['ncomp']) - for k in range(d['ncomp']): + dt = np.dtype( + [ + ("sensor_name", "S32"), + ("coeff_type", ">i4"), + ("d0", ">i4"), + ("ncoeff", ">i2"), + ("sensors", "S%s" % CTF.CTFV_SENSOR_LABEL, CTF.CTFV_MAX_BALANCING), + ("coeffs", ">f8", CTF.CTFV_MAX_BALANCING), + ] + ) + comps = np.fromfile(fid, dt, d["ncomp"]) + for k in range(d["ncomp"]): comp = dict() - d['comp'].append(comp) - comp['sensor_name'] = \ - comps['sensor_name'][k].split(b'\x00')[0].decode('utf-8') - comp['coeff_type'] = comps['coeff_type'][k].item() - comp['ncoeff'] = comps['ncoeff'][k].item() - comp['sensors'] = [s.split(b'\x00')[0].decode('utf-8') - for s in comps['sensors'][k][:comp['ncoeff']]] - comp['coeffs'] = comps['coeffs'][k][:comp['ncoeff']] - comp['scanno'] = d['ch_names'].index(comp['sensor_name']) + d["comp"].append(comp) + comp["sensor_name"] = comps["sensor_name"][k].split(b"\x00")[0].decode("utf-8") + comp["coeff_type"] = comps["coeff_type"][k].item() + comp["ncoeff"] = comps["ncoeff"][k].item() + comp["sensors"] = [ + s.split(b"\x00")[0].decode("utf-8") + for s in comps["sensors"][k][: comp["ncoeff"]] + ] + comp["coeffs"] = comps["coeffs"][k][: comp["ncoeff"]] + comp["scanno"] = d["ch_names"].index(comp["sensor_name"]) def _read_res4(dsdir): """Read the magical res4 file.""" # adapted from read_res4.c - name, _ = _make_ctf_name(dsdir, 'res4') + name, _ = _make_ctf_name(dsdir, "res4") res = dict() - with open(name, 'rb') as fid: + with open(name, "rb") as fid: # Read the fields - res['head'] = _read_string(fid, 8) - res['appname'] = _read_string(fid, 256) - res['origin'] = _read_string(fid, 256) - 
res['desc'] = _read_string(fid, 256) - res['nave'] = _read_int2(fid) - res['data_time'] = _read_string(fid, 255) - res['data_date'] = _read_string(fid, 255) + res["head"] = _read_string(fid, 8) + res["appname"] = _read_string(fid, 256) + res["origin"] = _read_string(fid, 256) + res["desc"] = _read_string(fid, 256) + res["nave"] = _read_int2(fid) + res["data_time"] = _read_string(fid, 255) + res["data_date"] = _read_string(fid, 255) # Seems that date and time can be swapped # (are they entered manually?!) - if '/' in res['data_time'] and ':' in res['data_date']: - data_date = res['data_date'] - res['data_date'] = res['data_time'] - res['data_time'] = data_date - res['nsamp'] = _read_int(fid) - res['nchan'] = _read_int2(fid) + if "/" in res["data_time"] and ":" in res["data_date"]: + data_date = res["data_date"] + res["data_date"] = res["data_time"] + res["data_time"] = data_date + res["nsamp"] = _read_int(fid) + res["nchan"] = _read_int2(fid) _move_to_next(fid, 8) - res['sfreq'] = _read_double(fid)[0] - res['epoch_time'] = _read_double(fid)[0] - res['no_trials'] = _read_int2(fid) + res["sfreq"] = _read_double(fid)[0] + res["epoch_time"] = _read_double(fid)[0] + res["no_trials"] = _read_int2(fid) _move_to_next(fid, 4) - res['pre_trig_pts'] = _read_int(fid) - res['no_trials_done'] = _read_int2(fid) - res['no_trials_bst_message_windowlay'] = _read_int2(fid) + res["pre_trig_pts"] = _read_int(fid) + res["no_trials_done"] = _read_int2(fid) + res["no_trials_bst_message_windowlay"] = _read_int2(fid) _move_to_next(fid, 4) - res['save_trials'] = _read_int(fid) - res['primary_trigger'] = fid.read(1) - res['secondary_trigger'] = [fid.read(1) - for k in range(CTF.CTFV_MAX_AVERAGE_BINS)] - res['trigger_polarity_mask'] = fid.read(1) - res['trigger_mode'] = _read_int2(fid) + res["save_trials"] = _read_int(fid) + res["primary_trigger"] = fid.read(1) + res["secondary_trigger"] = [ + fid.read(1) for k in range(CTF.CTFV_MAX_AVERAGE_BINS) + ] + res["trigger_polarity_mask"] = fid.read(1) + res["trigger_mode"] = _read_int2(fid) _move_to_next(fid, 4) - res['accept_reject'] = _read_int(fid) - res['run_time_bst_message_windowlay'] = _read_int2(fid) + res["accept_reject"] = _read_int(fid) + res["run_time_bst_message_windowlay"] = _read_int2(fid) _move_to_next(fid, 4) - res['zero_head'] = _read_int(fid) + res["zero_head"] = _read_int(fid) _move_to_next(fid, 4) - res['artifact_mode'] = _read_int(fid) + res["artifact_mode"] = _read_int(fid) _read_int(fid) # padding - res['nf_run_name'] = _read_string(fid, 32) - res['nf_run_title'] = _read_string(fid, 256) - res['nf_instruments'] = _read_string(fid, 32) - res['nf_collect_descriptor'] = _read_string(fid, 32) - res['nf_subject_id'] = _read_string(fid, 32) - res['nf_operator'] = _read_string(fid, 32) - if len(res['nf_operator']) == 0: - res['nf_operator'] = None - res['nf_sensor_file_name'] = _read_ustring(fid, 60) + res["nf_run_name"] = _read_string(fid, 32) + res["nf_run_title"] = _read_string(fid, 256) + res["nf_instruments"] = _read_string(fid, 32) + res["nf_collect_descriptor"] = _read_string(fid, 32) + res["nf_subject_id"] = _read_string(fid, 32) + res["nf_operator"] = _read_string(fid, 32) + if len(res["nf_operator"]) == 0: + res["nf_operator"] = None + res["nf_sensor_file_name"] = _read_ustring(fid, 60) _move_to_next(fid, 4) - res['rdlen'] = _read_int(fid) + res["rdlen"] = _read_int(fid) fid.seek(CTF.FUNNY_POS, 0) - if res['rdlen'] > 0: - res['run_desc'] = _read_string(fid, res['rdlen']) + if res["rdlen"] > 0: + res["run_desc"] = _read_string(fid, res["rdlen"]) # 
Filters - res['nfilt'] = _read_int2(fid) - res['filters'] = list() - for k in range(res['nfilt']): - res['filters'].append(_read_filter(fid)) + res["nfilt"] = _read_int2(fid) + res["filters"] = list() + for k in range(res["nfilt"]): + res["filters"].append(_read_filter(fid)) # Channel information (names, then data) - res['ch_names'] = list() - for k in range(res['nchan']): + res["ch_names"] = list() + for k in range(res["nchan"]): ch_name = _read_string(fid, 32) - res['ch_names'].append(ch_name) - _coil_dt = np.dtype([ - ('pos', '>f8', 3), ('d0', '>f8'), - ('norm', '>f8', 3), ('d1', '>f8'), - ('turns', '>i2'), ('d2', '>i4'), ('d3', '>i2'), - ('area', '>f8')]) - _ch_dt = np.dtype([ - ('sensor_type_index', '>i2'), - ('original_run_no', '>i2'), - ('coil_type', '>i4'), - ('proper_gain', '>f8'), - ('qgain', '>f8'), - ('io_gain', '>f8'), - ('io_offset', '>f8'), - ('num_coils', '>i2'), - ('grad_order_no', '>i2'), ('d0', '>i4'), - ('coil', _coil_dt, CTF.CTFV_MAX_COILS), - ('head_coil', _coil_dt, CTF.CTFV_MAX_COILS)]) - chs = np.fromfile(fid, _ch_dt, res['nchan']) - for coil in (chs['coil'], chs['head_coil']): - coil['pos'] /= 100. - coil['area'] *= 1e-4 + res["ch_names"].append(ch_name) + _coil_dt = np.dtype( + [ + ("pos", ">f8", 3), + ("d0", ">f8"), + ("norm", ">f8", 3), + ("d1", ">f8"), + ("turns", ">i2"), + ("d2", ">i4"), + ("d3", ">i2"), + ("area", ">f8"), + ] + ) + _ch_dt = np.dtype( + [ + ("sensor_type_index", ">i2"), + ("original_run_no", ">i2"), + ("coil_type", ">i4"), + ("proper_gain", ">f8"), + ("qgain", ">f8"), + ("io_gain", ">f8"), + ("io_offset", ">f8"), + ("num_coils", ">i2"), + ("grad_order_no", ">i2"), + ("d0", ">i4"), + ("coil", _coil_dt, CTF.CTFV_MAX_COILS), + ("head_coil", _coil_dt, CTF.CTFV_MAX_COILS), + ] + ) + chs = np.fromfile(fid, _ch_dt, res["nchan"]) + for coil in (chs["coil"], chs["head_coil"]): + coil["pos"] /= 100.0 + coil["area"] *= 1e-4 # convert to dict chs = [dict(zip(chs.dtype.names, x)) for x in chs] - res['chs'] = chs - for k in range(res['nchan']): - res['chs'][k]['ch_name'] = res['ch_names'][k] + res["chs"] = chs + for k in range(res["nchan"]): + res["chs"][k]["ch_name"] = res["ch_names"][k] # The compensation coefficients _read_comp_coeff(fid, res) - logger.info(' res4 data read.') + logger.info(" res4 data read.") return res diff --git a/mne/io/ctf/tests/test_ctf.py b/mne/io/ctf/tests/test_ctf.py index 1e699007714..0b550a02c19 100644 --- a/mne/io/ctf/tests/test_ctf.py +++ b/mne/io/ctf/tests/test_ctf.py @@ -15,8 +15,13 @@ import mne import mne.io.ctf.info -from mne import (pick_types, read_annotations, create_info, - events_from_annotations, make_forward_solution) +from mne import ( + pick_types, + read_annotations, + create_info, + events_from_annotations, + make_forward_solution, +) from mne.transforms import apply_trans from mne.io import read_raw_fif, read_raw_ctf, RawArray from mne.io.compensator import get_current_comp @@ -24,21 +29,22 @@ from mne.io.ctf.info import _convert_time from mne.io.tests.test_raw import _test_raw_reader from mne.tests.test_annotations import _assert_annotations_equal -from mne.utils import (_clean_names, catch_logging, _stamp_to_dt, - _record_warnings) +from mne.utils import _clean_names, catch_logging, _stamp_to_dt, _record_warnings from mne.datasets import testing, spm_face, brainstorm from mne.io.constants import FIFF -ctf_dir = testing.data_path(download=False) / 'CTF' -ctf_fname_continuous = 'testdata_ctf.ds' -ctf_fname_1_trial = 'testdata_ctf_short.ds' -ctf_fname_2_trials = 'testdata_ctf_pseudocontinuous.ds' 
-ctf_fname_discont = 'testdata_ctf_short_discontinuous.ds' -ctf_fname_somato = 'somMDYO-18av.ds' -ctf_fname_catch = 'catch-alp-good-f.ds' +ctf_dir = testing.data_path(download=False) / "CTF" +ctf_fname_continuous = "testdata_ctf.ds" +ctf_fname_1_trial = "testdata_ctf_short.ds" +ctf_fname_2_trials = "testdata_ctf_pseudocontinuous.ds" +ctf_fname_discont = "testdata_ctf_short_discontinuous.ds" +ctf_fname_somato = "somMDYO-18av.ds" +ctf_fname_catch = "catch-alp-good-f.ds" somato_fname = op.join( - brainstorm.bst_raw.data_path(download=False), 'MEG', 'bst_raw', - 'subj001_somatosensory_20111109_01_AUX-f.ds' + brainstorm.bst_raw.data_path(download=False), + "MEG", + "bst_raw", + "subj001_somatosensory_20111109_01_AUX-f.ds", ) spm_path = spm_face.data_path(download=False) @@ -63,53 +69,62 @@ def test_read_ctf(tmp_path): """Test CTF reader.""" temp_dir = str(tmp_path) - out_fname = op.join(temp_dir, 'test_py_raw.fif') + out_fname = op.join(temp_dir, "test_py_raw.fif") # Create a dummy .eeg file so we can test our reading/application of it - os.mkdir(op.join(temp_dir, 'randpos')) - ctf_eeg_fname = op.join(temp_dir, 'randpos', ctf_fname_catch) + os.mkdir(op.join(temp_dir, "randpos")) + ctf_eeg_fname = op.join(temp_dir, "randpos", ctf_fname_catch) shutil.copytree(op.join(ctf_dir, ctf_fname_catch), ctf_eeg_fname) - with pytest.warns(RuntimeWarning, match='RMSP .* changed to a MISC ch'): + with pytest.warns(RuntimeWarning, match="RMSP .* changed to a MISC ch"): raw = _test_raw_reader(read_raw_ctf, directory=ctf_eeg_fname) picks = pick_types(raw.info, meg=False, eeg=True) pos = np.random.RandomState(42).randn(len(picks), 3) - fake_eeg_fname = op.join(ctf_eeg_fname, 'catch-alp-good-f.eeg') + fake_eeg_fname = op.join(ctf_eeg_fname, "catch-alp-good-f.eeg") # Create a bad file - with open(fake_eeg_fname, 'wb') as fid: - fid.write('foo\n'.encode('ascii')) + with open(fake_eeg_fname, "wb") as fid: + fid.write("foo\n".encode("ascii")) pytest.raises(RuntimeError, read_raw_ctf, ctf_eeg_fname) # Create a good file - with open(fake_eeg_fname, 'wb') as fid: + with open(fake_eeg_fname, "wb") as fid: for ii, ch_num in enumerate(picks): - args = (str(ch_num + 1), raw.ch_names[ch_num],) + tuple( - '%0.5f' % x for x in 100 * pos[ii]) # convert to cm - fid.write(('\t'.join(args) + '\n').encode('ascii')) - pos_read_old = np.array([raw.info['chs'][p]['loc'][:3] for p in picks]) - with pytest.warns(RuntimeWarning, match='RMSP .* changed to a MISC ch'): + args = ( + str(ch_num + 1), + raw.ch_names[ch_num], + ) + tuple( + "%0.5f" % x for x in 100 * pos[ii] + ) # convert to cm + fid.write(("\t".join(args) + "\n").encode("ascii")) + pos_read_old = np.array([raw.info["chs"][p]["loc"][:3] for p in picks]) + with pytest.warns(RuntimeWarning, match="RMSP .* changed to a MISC ch"): raw = read_raw_ctf(ctf_eeg_fname) # read modified data - pos_read = np.array([raw.info['chs'][p]['loc'][:3] for p in picks]) - assert_allclose(apply_trans(raw.info['ctf_head_t'], pos), pos_read, - rtol=1e-5, atol=1e-5) + pos_read = np.array([raw.info["chs"][p]["loc"][:3] for p in picks]) + assert_allclose( + apply_trans(raw.info["ctf_head_t"], pos), pos_read, rtol=1e-5, atol=1e-5 + ) assert (pos_read == pos_read_old).mean() < 0.1 - shutil.copy(op.join(ctf_dir, 'catch-alp-good-f.ds_randpos_raw.fif'), - op.join(temp_dir, 'randpos', 'catch-alp-good-f.ds_raw.fif')) + shutil.copy( + op.join(ctf_dir, "catch-alp-good-f.ds_randpos_raw.fif"), + op.join(temp_dir, "randpos", "catch-alp-good-f.ds_raw.fif"), + ) # Create a version with no hc, starting out *with* EEG 
pos (error) - os.mkdir(op.join(temp_dir, 'nohc')) - ctf_no_hc_fname = op.join(temp_dir, 'no_hc', ctf_fname_catch) + os.mkdir(op.join(temp_dir, "nohc")) + ctf_no_hc_fname = op.join(temp_dir, "no_hc", ctf_fname_catch) shutil.copytree(ctf_eeg_fname, ctf_no_hc_fname) remove_base = op.join(ctf_no_hc_fname, op.basename(ctf_fname_catch[:-3])) - os.remove(remove_base + '.hc') - with pytest.warns(RuntimeWarning, match='MISC channel'): + os.remove(remove_base + ".hc") + with pytest.warns(RuntimeWarning, match="MISC channel"): pytest.raises(RuntimeError, read_raw_ctf, ctf_no_hc_fname) - os.remove(remove_base + '.eeg') - shutil.copy(op.join(ctf_dir, 'catch-alp-good-f.ds_nohc_raw.fif'), - op.join(temp_dir, 'no_hc', 'catch-alp-good-f.ds_raw.fif')) + os.remove(remove_base + ".eeg") + shutil.copy( + op.join(ctf_dir, "catch-alp-good-f.ds_nohc_raw.fif"), + op.join(temp_dir, "no_hc", "catch-alp-good-f.ds_raw.fif"), + ) # All our files use_fnames = [op.join(ctf_dir, c) for c in ctf_fnames] for fname in use_fnames: - raw_c = read_raw_fif(fname + '_raw.fif', preload=True) + raw_c = read_raw_fif(fname + "_raw.fif", preload=True) # sometimes matches "MISC channel" with _record_warnings(): raw = read_raw_ctf(fname) @@ -118,172 +133,221 @@ def test_read_ctf(tmp_path): assert_array_equal(raw.ch_names, raw_c.ch_names) assert_allclose(raw.times, raw_c.times) assert_allclose(raw._cals, raw_c._cals) - assert (raw.info['meas_id']['version'] == - raw_c.info['meas_id']['version'] + 1) - for t in ('dev_head_t', 'dev_ctf_t', 'ctf_head_t'): - assert_allclose(raw.info[t]['trans'], raw_c.info[t]['trans'], - rtol=1e-4, atol=1e-7) + assert raw.info["meas_id"]["version"] == raw_c.info["meas_id"]["version"] + 1 + for t in ("dev_head_t", "dev_ctf_t", "ctf_head_t"): + assert_allclose( + raw.info[t]["trans"], raw_c.info[t]["trans"], rtol=1e-4, atol=1e-7 + ) # XXX 2019/11/29 : MNC-C FIF conversion files don't have meas_date set. 
# Consider adding meas_date to below checks once this is addressed in # MNE-C - for key in ('acq_pars', 'acq_stim', 'bads', - 'ch_names', 'custom_ref_applied', 'description', - 'events', 'experimenter', 'highpass', 'line_freq', - 'lowpass', 'nchan', 'proj_id', 'proj_name', - 'projs', 'sfreq', 'subject_info'): + for key in ( + "acq_pars", + "acq_stim", + "bads", + "ch_names", + "custom_ref_applied", + "description", + "events", + "experimenter", + "highpass", + "line_freq", + "lowpass", + "nchan", + "proj_id", + "proj_name", + "projs", + "sfreq", + "subject_info", + ): assert raw.info[key] == raw_c.info[key], key if op.basename(fname) not in single_trials: # We don't force buffer size to be smaller like MNE-C assert raw.buffer_size_sec == raw_c.buffer_size_sec - assert len(raw.info['comps']) == len(raw_c.info['comps']) - for c1, c2 in zip(raw.info['comps'], raw_c.info['comps']): - for key in ('colcals', 'rowcals'): + assert len(raw.info["comps"]) == len(raw_c.info["comps"]) + for c1, c2 in zip(raw.info["comps"], raw_c.info["comps"]): + for key in ("colcals", "rowcals"): assert_allclose(c1[key], c2[key]) - assert c1['save_calibrated'] == c2['save_calibrated'] - for key in ('row_names', 'col_names', 'nrow', 'ncol'): - assert_array_equal(c1['data'][key], c2['data'][key]) - assert_allclose(c1['data']['data'], c2['data']['data'], atol=1e-7, - rtol=1e-5) - assert_allclose(raw.info['hpi_results'][0]['coord_trans']['trans'], - raw_c.info['hpi_results'][0]['coord_trans']['trans'], - rtol=1e-5, atol=1e-7) - assert len(raw.info['chs']) == len(raw_c.info['chs']) - for ii, (c1, c2) in enumerate(zip(raw.info['chs'], raw_c.info['chs'])): - for key in ('kind', 'scanno', 'unit', 'ch_name', 'unit_mul', - 'range', 'coord_frame', 'coil_type', 'logno'): - if c1['ch_name'] == 'RMSP' and \ - 'catch-alp-good-f' in fname and \ - key in ('kind', 'unit', 'coord_frame', 'coil_type', - 'logno'): + assert c1["save_calibrated"] == c2["save_calibrated"] + for key in ("row_names", "col_names", "nrow", "ncol"): + assert_array_equal(c1["data"][key], c2["data"][key]) + assert_allclose( + c1["data"]["data"], c2["data"]["data"], atol=1e-7, rtol=1e-5 + ) + assert_allclose( + raw.info["hpi_results"][0]["coord_trans"]["trans"], + raw_c.info["hpi_results"][0]["coord_trans"]["trans"], + rtol=1e-5, + atol=1e-7, + ) + assert len(raw.info["chs"]) == len(raw_c.info["chs"]) + for ii, (c1, c2) in enumerate(zip(raw.info["chs"], raw_c.info["chs"])): + for key in ( + "kind", + "scanno", + "unit", + "ch_name", + "unit_mul", + "range", + "coord_frame", + "coil_type", + "logno", + ): + if ( + c1["ch_name"] == "RMSP" + and "catch-alp-good-f" in fname + and key in ("kind", "unit", "coord_frame", "coil_type", "logno") + ): continue # XXX see below... 
- if key == 'coil_type' and c1[key] == FIFF.FIFFV_COIL_EEG: + if key == "coil_type" and c1[key] == FIFF.FIFFV_COIL_EEG: # XXX MNE-C bug that this is not set assert c2[key] == FIFF.FIFFV_COIL_NONE continue assert c1[key] == c2[key], key - for key in ('cal',): - assert_allclose(c1[key], c2[key], atol=1e-6, rtol=1e-4, - err_msg='raw.info["chs"][%d][%s]' % (ii, key)) + for key in ("cal",): + assert_allclose( + c1[key], + c2[key], + atol=1e-6, + rtol=1e-4, + err_msg='raw.info["chs"][%d][%s]' % (ii, key), + ) # XXX 2016/02/24: fixed bug with normal computation that used # to exist, once mne-C tools are updated we should update our FIF # conversion files, then the slices can go away (and the check # can be combined with that for "cal") - for key in ('loc',): - if c1['ch_name'] == 'RMSP' and 'catch-alp-good-f' in fname: + for key in ("loc",): + if c1["ch_name"] == "RMSP" and "catch-alp-good-f" in fname: continue - if (c2[key][:3] == 0.).all(): + if (c2[key][:3] == 0.0).all(): check = [np.nan] * 3 else: check = c2[key][:3] - assert_allclose(c1[key][:3], check, atol=1e-6, rtol=1e-4, - err_msg='raw.info["chs"][%d][%s]' % (ii, key)) - if (c2[key][3:] == 0.).all(): + assert_allclose( + c1[key][:3], + check, + atol=1e-6, + rtol=1e-4, + err_msg='raw.info["chs"][%d][%s]' % (ii, key), + ) + if (c2[key][3:] == 0.0).all(): check = [np.nan] * 3 else: check = c2[key][9:12] - assert_allclose(c1[key][9:12], check, atol=1e-6, rtol=1e-4, - err_msg='raw.info["chs"][%d][%s]' % (ii, key)) + assert_allclose( + c1[key][9:12], + check, + atol=1e-6, + rtol=1e-4, + err_msg='raw.info["chs"][%d][%s]' % (ii, key), + ) # Make sure all digitization points are in the MNE head coord frame - for p in raw.info['dig']: - assert p['coord_frame'] == FIFF.FIFFV_COORD_HEAD, \ - 'dig points must be in FIFF.FIFFV_COORD_HEAD' + for p in raw.info["dig"]: + assert ( + p["coord_frame"] == FIFF.FIFFV_COORD_HEAD + ), "dig points must be in FIFF.FIFFV_COORD_HEAD" - if fname.endswith('catch-alp-good-f.ds'): # omit points from .pos file + if fname.endswith("catch-alp-good-f.ds"): # omit points from .pos file with raw.info._unlock(): - raw.info['dig'] = raw.info['dig'][:-10] + raw.info["dig"] = raw.info["dig"][:-10] # XXX: Next test would fail because c-tools assign the fiducials from # CTF data as HPI. Should eventually clarify/unify with Matti. # assert_dig_allclose(raw.info, raw_c.info) # check data match - raw_c.save(out_fname, overwrite=True, buffer_size_sec=1.) 
+ raw_c.save(out_fname, overwrite=True, buffer_size_sec=1.0) raw_read = read_raw_fif(out_fname) # so let's check tricky cases based on sample boundaries rng = np.random.RandomState(0) pick_ch = rng.permutation(np.arange(len(raw.ch_names)))[:10] - bnd = int(round(raw.info['sfreq'] * raw.buffer_size_sec)) - assert bnd == raw._raw_extras[0]['block_size'] + bnd = int(round(raw.info["sfreq"] * raw.buffer_size_sec)) + assert bnd == raw._raw_extras[0]["block_size"] assert bnd == block_sizes[op.basename(fname)] - slices = (slice(0, bnd), slice(bnd - 1, bnd), slice(3, bnd), - slice(3, 300), slice(None)) + slices = ( + slice(0, bnd), + slice(bnd - 1, bnd), + slice(3, bnd), + slice(3, 300), + slice(None), + ) if len(raw.times) >= 2 * bnd: # at least two complete blocks - slices = slices + (slice(bnd, 2 * bnd), slice(bnd, bnd + 1), - slice(0, bnd + 100)) + slices = slices + ( + slice(bnd, 2 * bnd), + slice(bnd, bnd + 1), + slice(0, bnd + 100), + ) for sl_time in slices: - assert_allclose(raw[pick_ch, sl_time][0], - raw_c[pick_ch, sl_time][0]) - assert_allclose(raw_read[pick_ch, sl_time][0], - raw_c[pick_ch, sl_time][0]) + assert_allclose(raw[pick_ch, sl_time][0], raw_c[pick_ch, sl_time][0]) + assert_allclose(raw_read[pick_ch, sl_time][0], raw_c[pick_ch, sl_time][0]) # all data / preload raw.load_data() assert_allclose(raw[:][0], raw_c[:][0], atol=1e-15) # test bad segment annotations - if 'testdata_ctf_short.ds' in fname: - assert 'bad' in raw.annotations.description[0] + if "testdata_ctf_short.ds" in fname: + assert "bad" in raw.annotations.description[0] assert_allclose(raw.annotations.onset, [2.15]) assert_allclose(raw.annotations.duration, [0.0225]) - with pytest.raises(TypeError, match='path-like'): + with pytest.raises(TypeError, match="path-like"): read_raw_ctf(1) - with pytest.raises(FileNotFoundError, match='does not exist'): - read_raw_ctf(ctf_fname_continuous + 'foo.ds') + with pytest.raises(FileNotFoundError, match="does not exist"): + read_raw_ctf(ctf_fname_continuous + "foo.ds") # test ignoring of system clock - read_raw_ctf(op.join(ctf_dir, ctf_fname_continuous), 'ignore') - with pytest.raises(ValueError, match='system_clock'): - read_raw_ctf(op.join(ctf_dir, ctf_fname_continuous), 'foo') + read_raw_ctf(op.join(ctf_dir, ctf_fname_continuous), "ignore") + with pytest.raises(ValueError, match="system_clock"): + read_raw_ctf(op.join(ctf_dir, ctf_fname_continuous), "foo") @testing.requires_testing_data def test_rawctf_clean_names(): """Test RawCTF _clean_names method.""" # read test data - with pytest.warns(RuntimeWarning, match='ref channel RMSP did not'): + with pytest.warns(RuntimeWarning, match="ref channel RMSP did not"): raw = read_raw_ctf(op.join(ctf_dir, ctf_fname_catch)) - raw_cleaned = read_raw_ctf(op.join(ctf_dir, ctf_fname_catch), - clean_names=True) + raw_cleaned = read_raw_ctf(op.join(ctf_dir, ctf_fname_catch), clean_names=True) test_channel_names = _clean_names(raw.ch_names) - test_info_comps = copy.deepcopy(raw.info['comps']) + test_info_comps = copy.deepcopy(raw.info["comps"]) # channel names should not be cleaned by default assert raw.ch_names != test_channel_names - chs_ch_names = [ch['ch_name'] for ch in raw.info['chs']] + chs_ch_names = [ch["ch_name"] for ch in raw.info["chs"]] assert chs_ch_names != test_channel_names - for test_comp, comp in zip(test_info_comps, raw.info['comps']): - for key in ('row_names', 'col_names'): - assert not array_equal(_clean_names(test_comp['data'][key]), - comp['data'][key]) + for test_comp, comp in zip(test_info_comps, 
raw.info["comps"]): + for key in ("row_names", "col_names"): + assert not array_equal( + _clean_names(test_comp["data"][key]), comp["data"][key] + ) # channel names should be cleaned if clean_names=True assert raw_cleaned.ch_names == test_channel_names - for ch, test_ch_name in zip(raw_cleaned.info['chs'], test_channel_names): - assert ch['ch_name'] == test_ch_name + for ch, test_ch_name in zip(raw_cleaned.info["chs"], test_channel_names): + assert ch["ch_name"] == test_ch_name - for test_comp, comp in zip(test_info_comps, raw_cleaned.info['comps']): - for key in ('row_names', 'col_names'): - assert _clean_names(test_comp['data'][key]) == comp['data'][key] + for test_comp, comp in zip(test_info_comps, raw_cleaned.info["comps"]): + for key in ("row_names", "col_names"): + assert _clean_names(test_comp["data"][key]) == comp["data"][key] @spm_face.requires_spm_data def test_read_spm_ctf(): """Test CTF reader with omitted samples.""" - raw_fname = op.join(spm_path, 'MEG', 'spm', - 'SPM_CTF_MEG_example_faces1_3D.ds') + raw_fname = op.join(spm_path, "MEG", "spm", "SPM_CTF_MEG_example_faces1_3D.ds") raw = read_raw_ctf(raw_fname) extras = raw._raw_extras[0] - assert extras['n_samp'] == raw.n_times - assert extras['n_samp'] != extras['n_samp_tot'] + assert extras["n_samp"] == raw.n_times + assert extras["n_samp"] != extras["n_samp_tot"] # Test that LPA, nasion and RPA are correct. - coord_frames = np.array([d['coord_frame'] for d in raw.info['dig']]) + coord_frames = np.array([d["coord_frame"] for d in raw.info["dig"]]) assert np.all(coord_frames == FIFF.FIFFV_COORD_HEAD) - cardinals = {d['ident']: d['r'] for d in raw.info['dig']} + cardinals = {d["ident"]: d["r"] for d in raw.info["dig"]} assert cardinals[1][0] < cardinals[2][0] < cardinals[3][0] # x coord assert cardinals[1][1] < cardinals[2][1] # y coord assert cardinals[3][1] < cardinals[2][1] # y coord @@ -292,75 +356,259 @@ def test_read_spm_ctf(): @testing.requires_testing_data -@pytest.mark.parametrize('comp_grade', [0, 1]) +@pytest.mark.parametrize("comp_grade", [0, 1]) def test_saving_picked(tmp_path, comp_grade): """Test saving picked CTF instances.""" temp_dir = str(tmp_path) - out_fname = op.join(temp_dir, 'test_py_raw.fif') + out_fname = op.join(temp_dir, "test_py_raw.fif") raw = read_raw_ctf(op.join(ctf_dir, ctf_fname_1_trial)) - assert raw.info['meas_date'] == _stamp_to_dt((1367228160, 0)) + assert raw.info["meas_date"] == _stamp_to_dt((1367228160, 0)) raw.crop(0, 1).load_data() assert raw.compensation_grade == get_current_comp(raw.info) == 0 - assert len(raw.info['comps']) == 5 + assert len(raw.info["comps"]) == 5 pick_kwargs = dict(meg=True, ref_meg=False, verbose=True) raw.apply_gradient_compensation(comp_grade) with catch_logging() as log: raw_pick = raw.copy().pick_types(**pick_kwargs) - assert len(raw.info['comps']) == 5 - assert len(raw_pick.info['comps']) == 0 + assert len(raw.info["comps"]) == 5 + assert len(raw_pick.info["comps"]) == 0 log = log.getvalue() - assert 'Removing 5 compensators' in log + assert "Removing 5 compensators" in log raw_pick.save(out_fname, overwrite=True) # should work raw2 = read_raw_fif(out_fname) - assert (raw_pick.ch_names == raw2.ch_names) + assert raw_pick.ch_names == raw2.ch_names assert_array_equal(raw_pick.times, raw2.times) - assert_allclose(raw2[0:20][0], raw_pick[0:20][0], rtol=1e-6, - atol=1e-20) # atol is very small but > 0 + assert_allclose( + raw2[0:20][0], raw_pick[0:20][0], rtol=1e-6, atol=1e-20 + ) # atol is very small but > 0 raw2 = read_raw_fif(out_fname, preload=True) - assert 
(raw_pick.ch_names == raw2.ch_names) + assert raw_pick.ch_names == raw2.ch_names assert_array_equal(raw_pick.times, raw2.times) - assert_allclose(raw2[0:20][0], raw_pick[0:20][0], rtol=1e-6, - atol=1e-20) # atol is very small but > 0 + assert_allclose( + raw2[0:20][0], raw_pick[0:20][0], rtol=1e-6, atol=1e-20 + ) # atol is very small but > 0 @brainstorm.bst_raw.requires_bstraw_data def test_read_ctf_annotations(): """Test reading CTF marker file.""" - EXPECTED_LATENCIES = np.array([ - 5640, 7950, 9990, 12253, 14171, 16557, 18896, 20846, # noqa - 22702, 24990, 26830, 28974, 30906, 33077, 34985, 36907, # noqa - 38922, 40760, 42881, 45222, 47457, 49618, 51802, 54227, # noqa - 56171, 58274, 60394, 62375, 64444, 66767, 68827, 71109, # noqa - 73499, 75807, 78146, 80415, 82554, 84508, 86403, 88426, # noqa - 90746, 92893, 94779, 96822, 98996, 99001, 100949, 103325, # noqa - 105322, 107678, 109667, 111844, 113682, 115817, 117691, 119663, # noqa - 121966, 123831, 126110, 128490, 130521, 132808, 135204, 137210, # noqa - 139130, 141390, 143660, 145748, 147889, 150205, 152528, 154646, # noqa - 156897, 159191, 161446, 163722, 166077, 168467, 170624, 172519, # noqa - 174719, 176886, 179062, 181405, 183709, 186034, 188454, 190330, # noqa - 192660, 194682, 196834, 199161, 201035, 203008, 204999, 207409, # noqa - 209661, 211895, 213957, 216005, 218040, 220178, 222137, 224305, # noqa - 226297, 228654, 230755, 232909, 235205, 237373, 239723, 241762, # noqa - 243748, 245762, 247801, 250055, 251886, 254252, 256441, 258354, # noqa - 260680, 263026, 265048, 267073, 269235, 271556, 273927, 276197, # noqa - 278436, 280536, 282691, 284933, 287061, 288936, 290941, 293183, # noqa - 295369, 297729, 299626, 301546, 303449, 305548, 307882, 310124, # noqa - 312374, 314509, 316815, 318789, 320981, 322879, 324878, 326959, # noqa - 329341, 331200, 331201, 333469, 335584, 337984, 340143, 342034, # noqa - 344360, 346309, 348544, 350970, 353052, 355227, 357449, 359603, # noqa - 361725, 363676, 365735, 367799, 369777, 371904, 373856, 376204, # noqa - 378391, 380800, 382859, 385161, 387093, 389434, 391624, 393785, # noqa - 396093, 398214, 400198, 402166, 404104, 406047, 408372, 410686, # noqa - 413029, 414975, 416850, 418797, 420824, 422959, 425026, 427215, # noqa - 429278, 431668 # noqa - ]) - 1 # Fieldtrip has 1 sample difference with MNE + EXPECTED_LATENCIES = ( + np.array( + [ + 5640, + 7950, + 9990, + 12253, + 14171, + 16557, + 18896, + 20846, # noqa + 22702, + 24990, + 26830, + 28974, + 30906, + 33077, + 34985, + 36907, # noqa + 38922, + 40760, + 42881, + 45222, + 47457, + 49618, + 51802, + 54227, # noqa + 56171, + 58274, + 60394, + 62375, + 64444, + 66767, + 68827, + 71109, # noqa + 73499, + 75807, + 78146, + 80415, + 82554, + 84508, + 86403, + 88426, # noqa + 90746, + 92893, + 94779, + 96822, + 98996, + 99001, + 100949, + 103325, # noqa + 105322, + 107678, + 109667, + 111844, + 113682, + 115817, + 117691, + 119663, # noqa + 121966, + 123831, + 126110, + 128490, + 130521, + 132808, + 135204, + 137210, # noqa + 139130, + 141390, + 143660, + 145748, + 147889, + 150205, + 152528, + 154646, # noqa + 156897, + 159191, + 161446, + 163722, + 166077, + 168467, + 170624, + 172519, # noqa + 174719, + 176886, + 179062, + 181405, + 183709, + 186034, + 188454, + 190330, # noqa + 192660, + 194682, + 196834, + 199161, + 201035, + 203008, + 204999, + 207409, # noqa + 209661, + 211895, + 213957, + 216005, + 218040, + 220178, + 222137, + 224305, # noqa + 226297, + 228654, + 230755, + 232909, + 235205, + 237373, + 239723, + 241762, # noqa + 
243748, + 245762, + 247801, + 250055, + 251886, + 254252, + 256441, + 258354, # noqa + 260680, + 263026, + 265048, + 267073, + 269235, + 271556, + 273927, + 276197, # noqa + 278436, + 280536, + 282691, + 284933, + 287061, + 288936, + 290941, + 293183, # noqa + 295369, + 297729, + 299626, + 301546, + 303449, + 305548, + 307882, + 310124, # noqa + 312374, + 314509, + 316815, + 318789, + 320981, + 322879, + 324878, + 326959, # noqa + 329341, + 331200, + 331201, + 333469, + 335584, + 337984, + 340143, + 342034, # noqa + 344360, + 346309, + 348544, + 350970, + 353052, + 355227, + 357449, + 359603, # noqa + 361725, + 363676, + 365735, + 367799, + 369777, + 371904, + 373856, + 376204, # noqa + 378391, + 380800, + 382859, + 385161, + 387093, + 389434, + 391624, + 393785, # noqa + 396093, + 398214, + 400198, + 402166, + 404104, + 406047, + 408372, + 410686, # noqa + 413029, + 414975, + 416850, + 418797, + 420824, + 422959, + 425026, + 427215, # noqa + 429278, + 431668, # noqa + ] + ) + - 1 + ) # Fieldtrip has 1 sample difference with MNE raw = RawArray( data=np.empty((1, 432000), dtype=np.float64), - info=create_info(ch_names=1, sfreq=1200.0)) - raw.set_meas_date(read_raw_ctf(somato_fname).info['meas_date']) + info=create_info(ch_names=1, sfreq=1200.0), + ) + raw.set_meas_date(read_raw_ctf(somato_fname).info["meas_date"]) raw.set_annotations(read_annotations(somato_fname)) events, _ = events_from_annotations(raw) @@ -376,14 +624,44 @@ def test_read_ctf_annotations_smoke_test(): of whatever is in the MarkerFile.mrk. """ EXPECTED_ONSET = [ - 0., 0.1425, 0.285, 0.42833333, 0.57083333, 0.71416667, 0.85666667, - 0.99916667, 1.1425, 1.285, 1.4275, 1.57083333, 1.71333333, 1.85666667, - 1.99916667, 2.14166667, 2.285, 2.4275, 2.57083333, 2.71333333, - 2.85583333, 2.99916667, 3.14166667, 3.28416667, 3.4275, 3.57, - 3.71333333, 3.85583333, 3.99833333, 4.14166667, 4.28416667, 4.42666667, - 4.57, 4.7125, 4.85583333, 4.99833333 + 0.0, + 0.1425, + 0.285, + 0.42833333, + 0.57083333, + 0.71416667, + 0.85666667, + 0.99916667, + 1.1425, + 1.285, + 1.4275, + 1.57083333, + 1.71333333, + 1.85666667, + 1.99916667, + 2.14166667, + 2.285, + 2.4275, + 2.57083333, + 2.71333333, + 2.85583333, + 2.99916667, + 3.14166667, + 3.28416667, + 3.4275, + 3.57, + 3.71333333, + 3.85583333, + 3.99833333, + 4.14166667, + 4.28416667, + 4.42666667, + 4.57, + 4.7125, + 4.85583333, + 4.99833333, ] - fname = op.join(ctf_dir, 'testdata_ctf_mc.ds') + fname = op.join(ctf_dir, "testdata_ctf_mc.ds") annot = read_annotations(fname) assert_allclose(annot.onset, EXPECTED_ONSET) @@ -393,17 +671,17 @@ def test_read_ctf_annotations_smoke_test(): def _read_res4_mag_comp(dsdir): res = mne.io.ctf.res4._read_res4(dsdir) - for ch in res['chs']: - if ch['sensor_type_index'] == CTF.CTFV_REF_MAG_CH: - ch['grad_order_no'] = 1 + for ch in res["chs"]: + if ch["sensor_type_index"] == CTF.CTFV_REF_MAG_CH: + ch["grad_order_no"] = 1 return res def _bad_res4_grad_comp(dsdir): res = mne.io.ctf.res4._read_res4(dsdir) - for ch in res['chs']: - if ch['sensor_type_index'] == CTF.CTFV_MEG_CH: - ch['grad_order_no'] = 1 + for ch in res["chs"]: + if ch["sensor_type_index"] == CTF.CTFV_MEG_CH: + ch["grad_order_no"] = 1 break return res @@ -412,11 +690,10 @@ def _bad_res4_grad_comp(dsdir): def test_missing_res4(tmp_path): """Test that res4 missing is handled gracefully.""" use_ds = tmp_path / ctf_fname_continuous - shutil.copytree(ctf_dir / ctf_fname_continuous, - tmp_path / ctf_fname_continuous) + shutil.copytree(ctf_dir / ctf_fname_continuous, tmp_path / ctf_fname_continuous) 
read_raw_ctf(use_ds) - os.remove(use_ds / (ctf_fname_continuous[:-2] + 'meg4')) - with pytest.raises(OSError, match='could not find the following'): + os.remove(use_ds / (ctf_fname_continuous[:-2] + "meg4")) + with pytest.raises(OSError, match="could not find the following"): read_raw_ctf(use_ds) @@ -426,33 +703,35 @@ def test_read_ctf_mag_bad_comp(tmp_path, monkeypatch): path = op.join(ctf_dir, ctf_fname_continuous) raw_orig = read_raw_ctf(path) assert raw_orig.compensation_grade == 0 - monkeypatch.setattr(mne.io.ctf.ctf, '_read_res4', _read_res4_mag_comp) + monkeypatch.setattr(mne.io.ctf.ctf, "_read_res4", _read_res4_mag_comp) raw_mag_comp = read_raw_ctf(path) assert raw_mag_comp.compensation_grade == 0 sphere = mne.make_sphere_model() - src = mne.setup_volume_source_space(pos=50., exclude=5., bem=sphere) - assert src[0]['nuse'] == 26 + src = mne.setup_volume_source_space(pos=50.0, exclude=5.0, bem=sphere) + assert src[0]["nuse"] == 26 for grade in (0, 1): raw_orig.apply_gradient_compensation(grade) raw_mag_comp.apply_gradient_compensation(grade) args = (None, src, sphere, True, False) fwd_orig = make_forward_solution(raw_orig.info, *args) fwd_mag_comp = make_forward_solution(raw_mag_comp.info, *args) - assert_allclose(fwd_orig['sol']['data'], fwd_mag_comp['sol']['data']) - monkeypatch.setattr(mne.io.ctf.ctf, '_read_res4', _bad_res4_grad_comp) - with pytest.raises(RuntimeError, match='inconsistent compensation grade'): + assert_allclose(fwd_orig["sol"]["data"], fwd_mag_comp["sol"]["data"]) + monkeypatch.setattr(mne.io.ctf.ctf, "_read_res4", _bad_res4_grad_comp) + with pytest.raises(RuntimeError, match="inconsistent compensation grade"): read_raw_ctf(path) @testing.requires_testing_data def test_invalid_meas_date(monkeypatch): """Test handling of invalid meas_date.""" + def _convert_time_bad(date_str, time_str): - return _convert_time('', '') - monkeypatch.setattr(mne.io.ctf.info, '_convert_time', _convert_time_bad) + return _convert_time("", "") + + monkeypatch.setattr(mne.io.ctf.info, "_convert_time", _convert_time_bad) with catch_logging() as log: raw = read_raw_ctf(ctf_dir / ctf_fname_continuous, verbose=True) log = log.getvalue() - assert 'No date or time found' in log - assert raw.info['meas_date'] == datetime.fromtimestamp(0, tz=timezone.utc) + assert "No date or time found" in log + assert raw.info["meas_date"] == datetime.fromtimestamp(0, tz=timezone.utc) diff --git a/mne/io/ctf/trans.py b/mne/io/ctf/trans.py index 0497518a314..8f7443cfcc1 100644 --- a/mne/io/ctf/trans.py +++ b/mne/io/ctf/trans.py @@ -6,9 +6,15 @@ import numpy as np -from ...transforms import (combine_transforms, invert_transform, Transform, - _quat_to_affine, _fit_matched_points, apply_trans, - get_ras_to_neuromag_trans) +from ...transforms import ( + combine_transforms, + invert_transform, + Transform, + _quat_to_affine, + _fit_matched_points, + apply_trans, + get_ras_to_neuromag_trans, +) from ...utils import logger from ..constants import FIFF from .constants import CTF @@ -16,8 +22,9 @@ def _make_transform_card(fro, to, r_lpa, r_nasion, r_rpa): """Make a transform from cardinal landmarks.""" - return invert_transform(Transform( - to, fro, get_ras_to_neuromag_trans(r_nasion, r_lpa, r_rpa))) + return invert_transform( + Transform(to, fro, get_ras_to_neuromag_trans(r_nasion, r_lpa, r_rpa)) + ) def _quaternion_align(from_frame, to_frame, from_pts, to_pts, diff_tol=1e-4): @@ -26,17 +33,19 @@ def _quaternion_align(from_frame, to_frame, from_pts, to_pts, diff_tol=1e-4): trans = 
_quat_to_affine(_fit_matched_points(from_pts, to_pts)[0]) # Test the transformation and print the results - logger.info(' Quaternion matching (desired vs. transformed):') + logger.info(" Quaternion matching (desired vs. transformed):") for fro, to in zip(from_pts, to_pts): rr = apply_trans(trans, fro) diff = np.linalg.norm(to - rr) - logger.info(' %7.2f %7.2f %7.2f mm <-> %7.2f %7.2f %7.2f mm ' - '(orig : %7.2f %7.2f %7.2f mm) diff = %8.3f mm' - % (tuple(1000 * to) + tuple(1000 * rr) + - tuple(1000 * fro) + (1000 * diff,))) + logger.info( + " %7.2f %7.2f %7.2f mm <-> %7.2f %7.2f %7.2f mm " + "(orig : %7.2f %7.2f %7.2f mm) diff = %8.3f mm" + % (tuple(1000 * to) + tuple(1000 * rr) + tuple(1000 * fro) + (1000 * diff,)) + ) if diff > diff_tol: - raise RuntimeError('Something is wrong: quaternion matching did ' - 'not work (see above)') + raise RuntimeError( + "Something is wrong: quaternion matching did " "not work (see above)" + ) return Transform(from_frame, to_frame, trans) @@ -46,19 +55,18 @@ def _make_ctf_coord_trans_set(res4, coils): lpa = rpa = nas = T1 = T2 = T3 = T5 = None if coils is not None: for p in coils: - if p['valid'] and (p['coord_frame'] == - FIFF.FIFFV_MNE_COORD_CTF_HEAD): - if lpa is None and p['kind'] == CTF.CTFV_COIL_LPA: + if p["valid"] and (p["coord_frame"] == FIFF.FIFFV_MNE_COORD_CTF_HEAD): + if lpa is None and p["kind"] == CTF.CTFV_COIL_LPA: lpa = p - elif rpa is None and p['kind'] == CTF.CTFV_COIL_RPA: + elif rpa is None and p["kind"] == CTF.CTFV_COIL_RPA: rpa = p - elif nas is None and p['kind'] == CTF.CTFV_COIL_NAS: + elif nas is None and p["kind"] == CTF.CTFV_COIL_NAS: nas = p if lpa is None or rpa is None or nas is None: - raise RuntimeError('Some of the mandatory HPI device-coordinate ' - 'info was not there.') - t = _make_transform_card('head', 'ctf_head', - lpa['r'], nas['r'], rpa['r']) + raise RuntimeError( + "Some of the mandatory HPI device-coordinate " "info was not there." + ) + t = _make_transform_card("head", "ctf_head", lpa["r"], nas["r"], rpa["r"]) T3 = invert_transform(t) # CTF device -> Neuromag device @@ -67,48 +75,58 @@ def _make_ctf_coord_trans_set(res4, coils): # in z direction to get a coordinate system comparable to the Neuromag one # R = np.eye(4) - R[:3, 3] = [0., 0., 0.19] - val = 0.5 * np.sqrt(2.) + R[:3, 3] = [0.0, 0.0, 0.19] + val = 0.5 * np.sqrt(2.0) R[0, 0] = val R[0, 1] = -val R[1, 0] = val R[1, 1] = val - T4 = Transform('ctf_meg', 'meg', R) + T4 = Transform("ctf_meg", "meg", R) # CTF device -> CTF head # We need to make the implicit transform explicit! 
h_pts = dict() d_pts = dict() - kinds = (CTF.CTFV_COIL_LPA, CTF.CTFV_COIL_RPA, CTF.CTFV_COIL_NAS, - CTF.CTFV_COIL_SPARE) + kinds = ( + CTF.CTFV_COIL_LPA, + CTF.CTFV_COIL_RPA, + CTF.CTFV_COIL_NAS, + CTF.CTFV_COIL_SPARE, + ) if coils is not None: for p in coils: - if p['valid']: - if p['coord_frame'] == FIFF.FIFFV_MNE_COORD_CTF_HEAD: + if p["valid"]: + if p["coord_frame"] == FIFF.FIFFV_MNE_COORD_CTF_HEAD: for kind in kinds: - if kind not in h_pts and p['kind'] == kind: - h_pts[kind] = p['r'] - elif p['coord_frame'] == FIFF.FIFFV_MNE_COORD_CTF_DEVICE: + if kind not in h_pts and p["kind"] == kind: + h_pts[kind] = p["r"] + elif p["coord_frame"] == FIFF.FIFFV_MNE_COORD_CTF_DEVICE: for kind in kinds: - if kind not in d_pts and p['kind'] == kind: - d_pts[kind] = p['r'] + if kind not in d_pts and p["kind"] == kind: + d_pts[kind] = p["r"] if any(kind not in h_pts for kind in kinds[:-1]): - raise RuntimeError('Some of the mandatory HPI device-coordinate ' - 'info was not there.') + raise RuntimeError( + "Some of the mandatory HPI device-coordinate " "info was not there." + ) if any(kind not in d_pts for kind in kinds[:-1]): - raise RuntimeError('Some of the mandatory HPI head-coordinate ' - 'info was not there.') - use_kinds = [kind for kind in kinds - if (kind in h_pts and kind in d_pts)] + raise RuntimeError( + "Some of the mandatory HPI head-coordinate " "info was not there." + ) + use_kinds = [kind for kind in kinds if (kind in h_pts and kind in d_pts)] r_head = np.array([h_pts[kind] for kind in use_kinds]) r_dev = np.array([d_pts[kind] for kind in use_kinds]) - T2 = _quaternion_align('ctf_meg', 'ctf_head', r_dev, r_head) + T2 = _quaternion_align("ctf_meg", "ctf_head", r_dev, r_head) # The final missing transform if T3 is not None and T2 is not None: - T5 = combine_transforms(T2, T3, 'ctf_meg', 'head') - T1 = combine_transforms(invert_transform(T4), T5, 'meg', 'head') - s = dict(t_dev_head=T1, t_ctf_dev_ctf_head=T2, t_ctf_head_head=T3, - t_ctf_dev_dev=T4, t_ctf_dev_head=T5) - logger.info(' Coordinate transformations established.') + T5 = combine_transforms(T2, T3, "ctf_meg", "head") + T1 = combine_transforms(invert_transform(T4), T5, "meg", "head") + s = dict( + t_dev_head=T1, + t_ctf_dev_ctf_head=T2, + t_ctf_head_head=T3, + t_ctf_dev_dev=T4, + t_ctf_dev_head=T5, + ) + logger.info(" Coordinate transformations established.") return s diff --git a/mne/io/ctf_comp.py b/mne/io/ctf_comp.py index 04198e45c58..a16d1c79607 100644 --- a/mne/io/ctf_comp.py +++ b/mne/io/ctf_comp.py @@ -19,39 +19,41 @@ def _add_kind(one): """Convert CTF kind to MNE kind.""" - if one['ctfkind'] == int('47314252', 16): - one['kind'] = 1 - elif one['ctfkind'] == int('47324252', 16): - one['kind'] = 2 - elif one['ctfkind'] == int('47334252', 16): - one['kind'] = 3 + if one["ctfkind"] == int("47314252", 16): + one["kind"] = 1 + elif one["ctfkind"] == int("47324252", 16): + one["kind"] = 2 + elif one["ctfkind"] == int("47334252", 16): + one["kind"] = 3 else: - one['kind'] = int(one['ctfkind']) + one["kind"] = int(one["ctfkind"]) -def _calibrate_comp(comp, chs, row_names, col_names, - mult_keys=('range', 'cal'), flip=False): +def _calibrate_comp( + comp, chs, row_names, col_names, mult_keys=("range", "cal"), flip=False +): """Get row and column cals.""" - ch_names = [c['ch_name'] for c in chs] + ch_names = [c["ch_name"] for c in chs] row_cals = np.zeros(len(row_names)) col_cals = np.zeros(len(col_names)) - for names, cals, inv in zip((row_names, col_names), (row_cals, col_cals), - (False, True)): + for names, cals, inv in zip( + 
(row_names, col_names), (row_cals, col_cals), (False, True) + ): for ii in range(len(cals)): p = ch_names.count(names[ii]) if p != 1: - raise RuntimeError('Channel %s does not appear exactly once ' - 'in data, found %d instance%s' - % (names[ii], p, _pl(p))) + raise RuntimeError( + "Channel %s does not appear exactly once " + "in data, found %d instance%s" % (names[ii], p, _pl(p)) + ) idx = ch_names.index(names[ii]) val = chs[idx][mult_keys[0]] * chs[idx][mult_keys[1]] - val = float(1. / val) if inv else float(val) - val = 1. / val if flip else val + val = float(1.0 / val) if inv else float(val) + val = 1.0 / val if flip else val cals[ii] = val - comp['rowcals'] = row_cals - comp['colcals'] = col_cals - comp['data']['data'] = (row_cals[:, None] * - comp['data']['data'] * col_cals[None, :]) + comp["rowcals"] = row_cals + comp["colcals"] = col_cals + comp["data"]["data"] = row_cals[:, None] * comp["data"]["data"] * col_cals[None, :] @verbose @@ -99,6 +101,7 @@ def _read_ctf_comp(fid, node, chs, ch_names_mapping): The compensation data """ from .meas_info import _rename_comps + ch_names_mapping = dict() if ch_names_mapping is None else ch_names_mapping compdata = [] comps = dir_tree_find(node, FIFF.FIFFB_MNE_CTF_COMP_DATA) @@ -106,22 +109,22 @@ def _read_ctf_comp(fid, node, chs, ch_names_mapping): for node in comps: # Read the data we need mat = _read_named_matrix(fid, node, FIFF.FIFF_MNE_CTF_COMP_DATA) - for p in range(node['nent']): - kind = node['directory'][p].kind - pos = node['directory'][p].pos + for p in range(node["nent"]): + kind = node["directory"][p].kind + pos = node["directory"][p].pos if kind == FIFF.FIFF_MNE_CTF_COMP_KIND: tag = read_tag(fid, pos) break else: - raise Exception('Compensation type not found') + raise Exception("Compensation type not found") # Get the compensation kind and map it to a simple number one = dict(ctfkind=tag.data.item()) del tag _add_kind(one) - for p in range(node['nent']): - kind = node['directory'][p].kind - pos = node['directory'][p].pos + for p in range(node["nent"]): + kind = node["directory"][p].kind + pos = node["directory"][p].pos if kind == FIFF.FIFF_MNE_CTF_COMP_CALIBRATED: tag = read_tag(fid, pos) calibrated = tag.data @@ -129,20 +132,20 @@ def _read_ctf_comp(fid, node, chs, ch_names_mapping): else: calibrated = False - one['save_calibrated'] = bool(calibrated) - one['data'] = mat + one["save_calibrated"] = bool(calibrated) + one["data"] = mat _rename_comps([one], ch_names_mapping) if not calibrated: # Calibrate... - _calibrate_comp(one, chs, mat['row_names'], mat['col_names']) + _calibrate_comp(one, chs, mat["row_names"], mat["col_names"]) else: - one['rowcals'] = np.ones(mat['data'].shape[0], dtype=np.float64) - one['colcals'] = np.ones(mat['data'].shape[1], dtype=np.float64) + one["rowcals"] = np.ones(mat["data"].shape[0], dtype=np.float64) + one["colcals"] = np.ones(mat["data"].shape[1], dtype=np.float64) compdata.append(one) if len(compdata) > 0: - logger.info(' Read %d compensation matrices' % len(compdata)) + logger.info(" Read %d compensation matrices" % len(compdata)) return compdata @@ -150,6 +153,7 @@ def _read_ctf_comp(fid, node, chs, ch_names_mapping): ############################################################################### # Writing + def write_ctf_comp(fid, comps): """Write the CTF compensation data into a fif file. 
@@ -169,18 +173,20 @@ def write_ctf_comp(fid, comps): for comp in comps: start_block(fid, FIFF.FIFFB_MNE_CTF_COMP_DATA) # Write the compensation kind - write_int(fid, FIFF.FIFF_MNE_CTF_COMP_KIND, comp['ctfkind']) - if comp.get('save_calibrated', False): - write_int(fid, FIFF.FIFF_MNE_CTF_COMP_CALIBRATED, - comp['save_calibrated']) + write_int(fid, FIFF.FIFF_MNE_CTF_COMP_KIND, comp["ctfkind"]) + if comp.get("save_calibrated", False): + write_int(fid, FIFF.FIFF_MNE_CTF_COMP_CALIBRATED, comp["save_calibrated"]) - if not comp.get('save_calibrated', True): + if not comp.get("save_calibrated", True): # Undo calibration comp = deepcopy(comp) - data = ((1. / comp['rowcals'][:, None]) * comp['data']['data'] * - (1. / comp['colcals'][None, :])) - comp['data']['data'] = data - write_named_matrix(fid, FIFF.FIFF_MNE_CTF_COMP_DATA, comp['data']) + data = ( + (1.0 / comp["rowcals"][:, None]) + * comp["data"]["data"] + * (1.0 / comp["colcals"][None, :]) + ) + comp["data"]["data"] = data + write_named_matrix(fid, FIFF.FIFF_MNE_CTF_COMP_DATA, comp["data"]) end_block(fid, FIFF.FIFFB_MNE_CTF_COMP_DATA) end_block(fid, FIFF.FIFFB_MNE_CTF_COMP) diff --git a/mne/io/curry/curry.py b/mne/io/curry/curry.py index 6c1e5d79821..619cb20278b 100644 --- a/mne/io/curry/curry.py +++ b/mne/io/curry/curry.py @@ -21,9 +21,15 @@ from ..constants import FIFF from ..ctf.trans import _quaternion_align from ...surface import _normal_orth -from ...transforms import (apply_trans, Transform, get_ras_to_neuromag_trans, - combine_transforms, invert_transform, - _angle_between_quats, rot_to_quat) +from ...transforms import ( + apply_trans, + Transform, + get_ras_to_neuromag_trans, + combine_transforms, + invert_transform, + _angle_between_quats, + rot_to_quat, +) from ...utils import check_fname, logger, verbose, _check_fname from ...annotations import Annotations @@ -43,19 +49,26 @@ "events_cef": ".cdt.cef", "events_ceo": ".cdt.ceo", "hpi": ".cdt.hpi", - } + }, } CHANTYPES = {"meg": "_MAG1", "eeg": "", "misc": "_OTHERS"} -FIFFV_CHANTYPES = {"meg": FIFF.FIFFV_MEG_CH, "eeg": FIFF.FIFFV_EEG_CH, - "misc": FIFF.FIFFV_MISC_CH} -FIFFV_COILTYPES = {"meg": FIFF.FIFFV_COIL_CTF_GRAD, "eeg": FIFF.FIFFV_COIL_EEG, - "misc": FIFF.FIFFV_COIL_NONE} +FIFFV_CHANTYPES = { + "meg": FIFF.FIFFV_MEG_CH, + "eeg": FIFF.FIFFV_EEG_CH, + "misc": FIFF.FIFFV_MISC_CH, +} +FIFFV_COILTYPES = { + "meg": FIFF.FIFFV_COIL_CTF_GRAD, + "eeg": FIFF.FIFFV_COIL_EEG, + "misc": FIFF.FIFFV_COIL_NONE, +} SI_UNITS = dict(V=FIFF.FIFF_UNIT_V, T=FIFF.FIFF_UNIT_T) SI_UNIT_SCALE = dict(c=1e-2, m=1e-3, u=1e-6, µ=1e-6, n=1e-9, p=1e-12, f=1e-15) -CurryParameters = namedtuple('CurryParameters', - 'n_samples, sfreq, is_ascii, unit_dict, ' - 'n_chans, dt_start, chanidx_in_file') +CurryParameters = namedtuple( + "CurryParameters", + "n_samples, sfreq, is_ascii, unit_dict, " "n_chans, dt_start, chanidx_in_file", +) def _get_curry_version(file_extension): @@ -65,24 +78,26 @@ def _get_curry_version(file_extension): def _get_curry_file_structure(fname, required=()): """Store paths to a dict and check for required files.""" - _msg = "The following required files cannot be found: {0}.\nPlease make " \ - "sure all required files are located in the same directory as {1}." - fname = Path(_check_fname(fname, 'read', True, 'fname')) + _msg = ( + "The following required files cannot be found: {0}.\nPlease make " + "sure all required files are located in the same directory as {1}." 
+ ) + fname = Path(_check_fname(fname, "read", True, "fname")) # we don't use os.path.splitext to also handle extensions like .cdt.dpa # this won't handle a dot in the filename, but it should handle it in # the parent directories - fname_base = fname.name.split('.', maxsplit=1)[0] - ext = fname.name[len(fname_base):] + fname_base = fname.name.split(".", maxsplit=1)[0] + ext = fname.name[len(fname_base) :] fname_base = str(fname) - fname_base = fname_base[:len(fname_base) - len(ext)] + fname_base = fname_base[: len(fname_base) - len(ext)] del fname version = _get_curry_version(ext) my_curry = dict() - for key in ('info', 'data', 'labels', 'events_cef', 'events_ceo', 'hpi'): + for key in ("info", "data", "labels", "events_cef", "events_ceo", "hpi"): fname = fname_base + FILE_EXTENSIONS[version][key] if op.isfile(fname): - _key = 'events' if key.startswith('events') else key + _key = "events" if key.startswith("events") else key my_curry[_key] = fname missing = [field for field in required if field not in my_curry] @@ -140,20 +155,38 @@ def _read_curry_lines(fname, regex_list): def _read_curry_parameters(fname): """Extract Curry params from a Curry info file.""" - _msg_match = "The sampling frequency and the time steps extracted from " \ - "the parameter file do not match." + _msg_match = ( + "The sampling frequency and the time steps extracted from " + "the parameter file do not match." + ) _msg_invalid = "sfreq must be greater than 0. Got sfreq = {0}" - var_names = ['NumSamples', 'SampleFreqHz', - 'DataFormat', 'SampleTimeUsec', - 'NumChannels', - 'StartYear', 'StartMonth', 'StartDay', 'StartHour', - 'StartMin', 'StartSec', 'StartMillisec', - 'NUM_SAMPLES', 'SAMPLE_FREQ_HZ', - 'DATA_FORMAT', 'SAMPLE_TIME_USEC', - 'NUM_CHANNELS', - 'START_YEAR', 'START_MONTH', 'START_DAY', 'START_HOUR', - 'START_MIN', 'START_SEC', 'START_MILLISEC'] + var_names = [ + "NumSamples", + "SampleFreqHz", + "DataFormat", + "SampleTimeUsec", + "NumChannels", + "StartYear", + "StartMonth", + "StartDay", + "StartHour", + "StartMin", + "StartSec", + "StartMillisec", + "NUM_SAMPLES", + "SAMPLE_FREQ_HZ", + "DATA_FORMAT", + "SAMPLE_TIME_USEC", + "NUM_CHANNELS", + "START_YEAR", + "START_MONTH", + "START_DAY", + "START_HOUR", + "START_MIN", + "START_SEC", + "START_MILLISEC", + ] param_dict = dict() unit_dict = dict() @@ -166,14 +199,15 @@ def _read_curry_parameters(fname): for type in CHANTYPES: if "DEVICE_PARAMETERS" + CHANTYPES[type] + " START" in line: data_unit = next(fid) - unit_dict[type] = data_unit.replace(" ", "") \ - .replace("\n", "").split("=")[-1] + unit_dict[type] = ( + data_unit.replace(" ", "").replace("\n", "").split("=")[-1] + ) # look for CHAN_IN_FILE sections, which may or may not exist; issue #8391 types = ["meg", "eeg", "misc"] - chanidx_in_file = _read_curry_lines(fname, - ["CHAN_IN_FILE" + - CHANTYPES[key] for key in types]) + chanidx_in_file = _read_curry_lines( + fname, ["CHAN_IN_FILE" + CHANTYPES[key] for key in types] + ) n_samples = int(param_dict["numsamples"]) sfreq = float(param_dict["samplefreqhz"]) @@ -181,14 +215,16 @@ def _read_curry_parameters(fname): is_ascii = param_dict["dataformat"] == "ASCII" n_channels = int(param_dict["numchannels"]) try: - dt_start = datetime(int(param_dict["startyear"]), - int(param_dict["startmonth"]), - int(param_dict["startday"]), - int(param_dict["starthour"]), - int(param_dict["startmin"]), - int(param_dict["startsec"]), - int(param_dict["startmillisec"]) * 1000, - timezone.utc) + dt_start = datetime( + int(param_dict["startyear"]), + 
int(param_dict["startmonth"]), + int(param_dict["startday"]), + int(param_dict["starthour"]), + int(param_dict["startmin"]), + int(param_dict["startsec"]), + int(param_dict["startmillisec"]) * 1000, + timezone.utc, + ) # Note that the time zone information is not stored in the Curry info # file, and it seems the start time info is in the local timezone # of the acquisition system (which is unknown); therefore, just set @@ -209,88 +245,102 @@ def _read_curry_parameters(fname): if true_sfreq <= 0: raise ValueError(_msg_invalid.format(true_sfreq)) - return CurryParameters(n_samples, true_sfreq, is_ascii, unit_dict, - n_channels, dt_start, chanidx_in_file) + return CurryParameters( + n_samples, + true_sfreq, + is_ascii, + unit_dict, + n_channels, + dt_start, + chanidx_in_file, + ) def _read_curry_info(curry_paths): """Extract info from curry parameter files.""" - curry_params = _read_curry_parameters(curry_paths['info']) + curry_params = _read_curry_parameters(curry_paths["info"]) R = np.eye(4) R[[0, 1], [0, 1]] = -1 # rotate 180 deg # shift down and back # (chosen by eyeballing to make the CTF helmet look roughly correct) - R[:3, 3] = [0., -0.015, -0.12] - curry_dev_dev_t = Transform('ctf_meg', 'meg', R) + R[:3, 3] = [0.0, -0.015, -0.12] + curry_dev_dev_t = Transform("ctf_meg", "meg", R) # read labels from label files - label_fname = curry_paths['labels'] + label_fname = curry_paths["labels"] types = ["meg", "eeg", "misc"] - labels = _read_curry_lines(label_fname, - ["LABELS" + CHANTYPES[key] for key in types]) - sensors = _read_curry_lines(label_fname, - ["SENSORS" + CHANTYPES[key] for key in types]) - normals = _read_curry_lines(label_fname, - ['NORMALS' + CHANTYPES[key] for key in types]) + labels = _read_curry_lines( + label_fname, ["LABELS" + CHANTYPES[key] for key in types] + ) + sensors = _read_curry_lines( + label_fname, ["SENSORS" + CHANTYPES[key] for key in types] + ) + normals = _read_curry_lines( + label_fname, ["NORMALS" + CHANTYPES[key] for key in types] + ) assert len(labels) == len(sensors) == len(normals) all_chans = list() dig_ch_pos = dict() for key in ["meg", "eeg", "misc"]: - chanidx_is_explicit = (len(curry_params.chanidx_in_file["CHAN_IN_FILE" - + CHANTYPES[key]]) > 0) # channel index + chanidx_is_explicit = ( + len(curry_params.chanidx_in_file["CHAN_IN_FILE" + CHANTYPES[key]]) > 0 + ) # channel index # position in the datafile may or may not be explicitly declared, # based on the CHAN_IN_FILE section in info file for ind, chan in enumerate(labels["LABELS" + CHANTYPES[key]]): - chanidx = len(all_chans) + 1 # by default, just assume the + chanidx = len(all_chans) + 1 # by default, just assume the # channel index in the datafile is in order of the channel # names as we found them in the labels file if chanidx_is_explicit: # but, if explicitly declared, use # that index number - chanidx = int(curry_params.chanidx_in_file["CHAN_IN_FILE" - + CHANTYPES[key]][ind]) - if chanidx <= 0: # if chanidx was explicitly declared to be ' 0', + chanidx = int( + curry_params.chanidx_in_file["CHAN_IN_FILE" + CHANTYPES[key]][ind] + ) + if chanidx <= 0: # if chanidx was explicitly declared to be ' 0', # it means the channel is not actually saved in the data file # (e.g. the "Ref" channel), so don't add it to our list. 
# Git issue #8391 continue - ch = {"ch_name": chan, - "unit": curry_params.unit_dict[key], - "kind": FIFFV_CHANTYPES[key], - "coil_type": FIFFV_COILTYPES[key], - "ch_idx": chanidx - } + ch = { + "ch_name": chan, + "unit": curry_params.unit_dict[key], + "kind": FIFFV_CHANTYPES[key], + "coil_type": FIFFV_COILTYPES[key], + "ch_idx": chanidx, + } if key == "eeg": loc = np.array(sensors["SENSORS" + CHANTYPES[key]][ind], float) # XXX just the sensor, where is ref (next 3)? assert loc.shape == (3,) - loc /= 1000. # to meters + loc /= 1000.0 # to meters loc = np.concatenate([loc, np.zeros(9)]) - ch['loc'] = loc + ch["loc"] = loc # XXX need to check/ensure this - ch['coord_frame'] = FIFF.FIFFV_COORD_HEAD + ch["coord_frame"] = FIFF.FIFFV_COORD_HEAD dig_ch_pos[chan] = loc[:3] - elif key == 'meg': + elif key == "meg": pos = np.array(sensors["SENSORS" + CHANTYPES[key]][ind], float) - pos /= 1000. # to meters + pos /= 1000.0 # to meters pos = pos[:3] # just the inner coil pos = apply_trans(curry_dev_dev_t, pos) nn = np.array(normals["NORMALS" + CHANTYPES[key]][ind], float) - assert np.isclose(np.linalg.norm(nn), 1., atol=1e-4) + assert np.isclose(np.linalg.norm(nn), 1.0, atol=1e-4) nn /= np.linalg.norm(nn) nn = apply_trans(curry_dev_dev_t, nn, move=False) trans = np.eye(4) trans[:3, 3] = pos trans[:3, :3] = _normal_orth(nn).T - ch['loc'] = _coil_trans_to_loc(trans) - ch['coord_frame'] = FIFF.FIFFV_COORD_DEVICE + ch["loc"] = _coil_trans_to_loc(trans) + ch["coord_frame"] = FIFF.FIFFV_COORD_DEVICE all_chans.append(ch) dig = _make_dig_points( - dig_ch_pos=dig_ch_pos, coord_frame='head', add_missing_fiducials=True) + dig_ch_pos=dig_ch_pos, coord_frame="head", add_missing_fiducials=True + ) del dig_ch_pos ch_count = len(all_chans) - assert (ch_count == curry_params.n_chans) # ensure that we have assembled + assert ch_count == curry_params.n_chans # ensure that we have assembled # the same number of channels as declared in the info (.DAP) file in the # DATA_PARAMETERS section. Git issue #8391 @@ -298,49 +348,51 @@ def _read_curry_info(curry_paths): # recorded in the datafile. In general they most likely are already in # the correct order, but if the channel index in the data file was # explicitly declared we might as well use it. 
- all_chans = sorted(all_chans, key=lambda ch: ch['ch_idx']) + all_chans = sorted(all_chans, key=lambda ch: ch["ch_idx"]) ch_names = [chan["ch_name"] for chan in all_chans] info = create_info(ch_names, curry_params.sfreq) with info._unlock(): - info['meas_date'] = curry_params.dt_start # for Git issue #8398 - info['dig'] = dig + info["meas_date"] = curry_params.dt_start # for Git issue #8398 + info["dig"] = dig _make_trans_dig(curry_paths, info, curry_dev_dev_t) for ind, ch_dict in enumerate(info["chs"]): - all_chans[ind].pop('ch_idx') + all_chans[ind].pop("ch_idx") ch_dict.update(all_chans[ind]) - assert ch_dict['loc'].shape == (12,) - ch_dict['unit'] = SI_UNITS[all_chans[ind]['unit'][1]] - ch_dict['cal'] = SI_UNIT_SCALE[all_chans[ind]['unit'][0]] + assert ch_dict["loc"].shape == (12,) + ch_dict["unit"] = SI_UNITS[all_chans[ind]["unit"][1]] + ch_dict["cal"] = SI_UNIT_SCALE[all_chans[ind]["unit"][0]] return info, curry_params.n_samples, curry_params.is_ascii -_card_dict = {'Left ear': FIFF.FIFFV_POINT_LPA, - 'Nasion': FIFF.FIFFV_POINT_NASION, - 'Right ear': FIFF.FIFFV_POINT_RPA} +_card_dict = { + "Left ear": FIFF.FIFFV_POINT_LPA, + "Nasion": FIFF.FIFFV_POINT_NASION, + "Right ear": FIFF.FIFFV_POINT_RPA, +} def _make_trans_dig(curry_paths, info, curry_dev_dev_t): # Coordinate frame transformations and definitions - no_msg = 'Leaving device<->head transform as None' - info['dev_head_t'] = None - label_fname = curry_paths['labels'] - key = 'LANDMARKS' + CHANTYPES['meg'] + no_msg = "Leaving device<->head transform as None" + info["dev_head_t"] = None + label_fname = curry_paths["labels"] + key = "LANDMARKS" + CHANTYPES["meg"] lm = _read_curry_lines(label_fname, [key])[key] lm = np.array(lm, float) lm.shape = (-1, 3) if len(lm) == 0: # no dig - logger.info(no_msg + ' (no landmarks found)') + logger.info(no_msg + " (no landmarks found)") return - lm /= 1000. 
- key = 'LM_REMARKS' + CHANTYPES['meg'] + lm /= 1000.0 + key = "LM_REMARKS" + CHANTYPES["meg"] remarks = _read_curry_lines(label_fname, [key])[key] assert len(remarks) == len(lm) with info._unlock(): - info['dig'] = list() + info["dig"] = list() cards = dict() for remark, r in zip(remarks, lm): kind = ident = None @@ -348,70 +400,83 @@ def _make_trans_dig(curry_paths, info, curry_dev_dev_t): kind = FIFF.FIFFV_POINT_CARDINAL ident = _card_dict[remark] cards[ident] = r - elif remark.startswith('HPI'): + elif remark.startswith("HPI"): kind = FIFF.FIFFV_POINT_HPI ident = int(remark[3:]) - 1 if kind is not None: - info['dig'].append(dict( - kind=kind, ident=ident, r=r, - coord_frame=FIFF.FIFFV_COORD_UNKNOWN)) + info["dig"].append( + dict(kind=kind, ident=ident, r=r, coord_frame=FIFF.FIFFV_COORD_UNKNOWN) + ) with info._unlock(): - info['dig'].sort(key=lambda x: (x['kind'], x['ident'])) + info["dig"].sort(key=lambda x: (x["kind"], x["ident"])) has_cards = len(cards) == 3 - has_hpi = 'hpi' in curry_paths + has_hpi = "hpi" in curry_paths if has_cards and has_hpi: # have all three - logger.info('Composing device<->head transformation from dig points') - hpi_u = np.array([d['r'] for d in info['dig'] - if d['kind'] == FIFF.FIFFV_POINT_HPI], float) - hpi_c = np.ascontiguousarray( - _first_hpi(curry_paths['hpi'])[:len(hpi_u), 1:4]) - unknown_curry_t = _quaternion_align( - 'unknown', 'ctf_meg', hpi_u, hpi_c, 1e-2) - angle = np.rad2deg(_angle_between_quats( - np.zeros(3), rot_to_quat(unknown_curry_t['trans'][:3, :3]))) - dist = 1000 * np.linalg.norm(unknown_curry_t['trans'][:3, 3]) - logger.info(' Fit a %0.1f° rotation, %0.1f mm translation' - % (angle, dist)) + logger.info("Composing device<->head transformation from dig points") + hpi_u = np.array( + [d["r"] for d in info["dig"] if d["kind"] == FIFF.FIFFV_POINT_HPI], float + ) + hpi_c = np.ascontiguousarray(_first_hpi(curry_paths["hpi"])[: len(hpi_u), 1:4]) + unknown_curry_t = _quaternion_align("unknown", "ctf_meg", hpi_u, hpi_c, 1e-2) + angle = np.rad2deg( + _angle_between_quats( + np.zeros(3), rot_to_quat(unknown_curry_t["trans"][:3, :3]) + ) + ) + dist = 1000 * np.linalg.norm(unknown_curry_t["trans"][:3, 3]) + logger.info(" Fit a %0.1f° rotation, %0.1f mm translation" % (angle, dist)) unknown_dev_t = combine_transforms( - unknown_curry_t, curry_dev_dev_t, 'unknown', 'meg') + unknown_curry_t, curry_dev_dev_t, "unknown", "meg" + ) unknown_head_t = Transform( - 'unknown', 'head', + "unknown", + "head", get_ras_to_neuromag_trans( - *(cards[key] for key in (FIFF.FIFFV_POINT_NASION, - FIFF.FIFFV_POINT_LPA, - FIFF.FIFFV_POINT_RPA)))) + *( + cards[key] + for key in ( + FIFF.FIFFV_POINT_NASION, + FIFF.FIFFV_POINT_LPA, + FIFF.FIFFV_POINT_RPA, + ) + ) + ), + ) with info._unlock(): - info['dev_head_t'] = combine_transforms( - invert_transform(unknown_dev_t), unknown_head_t, 'meg', 'head') - for d in info['dig']: - d.update(coord_frame=FIFF.FIFFV_COORD_HEAD, - r=apply_trans(unknown_head_t, d['r'])) + info["dev_head_t"] = combine_transforms( + invert_transform(unknown_dev_t), unknown_head_t, "meg", "head" + ) + for d in info["dig"]: + d.update( + coord_frame=FIFF.FIFFV_COORD_HEAD, + r=apply_trans(unknown_head_t, d["r"]), + ) else: if has_cards: - no_msg += ' (no .hpi file found)' + no_msg += " (no .hpi file found)" elif has_hpi: - no_msg += ' (not all cardinal points found)' + no_msg += " (not all cardinal points found)" else: - no_msg += ' (neither cardinal points nor .hpi file found)' + no_msg += " (neither cardinal points nor .hpi file found)" 
logger.info(no_msg) def _first_hpi(fname): # Get the first HPI result - with open(fname, 'r') as fid: + with open(fname, "r") as fid: for line in fid: line = line.strip() - if any(x in line for x in ('FileVersion', 'NumCoils')) or not line: + if any(x in line for x in ("FileVersion", "NumCoils")) or not line: continue hpi = np.array(line.split(), float) break else: - raise RuntimeError('Could not find valid HPI in %s' % (fname,)) + raise RuntimeError("Could not find valid HPI in %s" % (fname,)) # t is the first entry assert hpi.ndim == 1 hpi = hpi[1:] hpi.shape = (-1, 5) - hpi /= 1000. + hpi /= 1000.0 return hpi @@ -429,8 +494,12 @@ def _read_events_curry(fname): events : ndarray, shape (n_events, 3) The array of events. """ - check_fname(fname, 'curry event', ('.cef', '.ceo', '.cdt.cef', '.cdt.ceo'), - endings_err=('.cef', '.ceo', '.cdt.cef', '.cdt.ceo')) + check_fname( + fname, + "curry event", + (".cef", ".ceo", ".cdt.cef", ".cdt.ceo"), + endings_err=(".cef", ".ceo", ".cdt.cef", ".cdt.ceo"), + ) events_dict = _read_curry_lines(fname, ["NUMBER_LIST"]) # The first 3 column seem to contain the event information @@ -439,7 +508,7 @@ def _read_events_curry(fname): return curry_events -def _read_annotations_curry(fname, sfreq='auto'): +def _read_annotations_curry(fname, sfreq="auto"): r"""Read events from Curry event files. Parameters @@ -457,12 +526,12 @@ def _read_annotations_curry(fname, sfreq='auto'): annot : instance of Annotations | None The annotations. """ - required = ["events", "info"] if sfreq == 'auto' else ["events"] + required = ["events", "info"] if sfreq == "auto" else ["events"] curry_paths = _get_curry_file_structure(fname, required) - events = _read_events_curry(curry_paths['events']) + events = _read_events_curry(curry_paths["events"]) - if sfreq == 'auto': - sfreq = _read_curry_parameters(curry_paths['info']).sfreq + if sfreq == "auto": + sfreq = _read_curry_parameters(curry_paths["info"]).sfreq onset = events[:, 0] / sfreq duration = np.zeros(events.shape[0]) @@ -515,11 +584,11 @@ class RawCurry(BaseRaw): @verbose def __init__(self, fname, preload=False, verbose=None): - curry_paths = _get_curry_file_structure( - fname, required=["info", "data", "labels"]) + fname, required=["info", "data", "labels"] + ) - data_fname = op.abspath(curry_paths['data']) + data_fname = op.abspath(curry_paths["data"]) info, n_samples, is_ascii = _read_curry_info(curry_paths) @@ -527,28 +596,38 @@ def __init__(self, fname, preload=False, verbose=None): raw_extras = dict(is_ascii=is_ascii) super(RawCurry, self).__init__( - info, preload, filenames=[data_fname], last_samps=last_samps, - orig_format='int', raw_extras=[raw_extras], verbose=verbose) - - if 'events' in curry_paths: - logger.info('Event file found. Extracting Annotations from' - ' %s...' % curry_paths['events']) - annots = _read_annotations_curry(curry_paths['events'], - sfreq=self.info["sfreq"]) + info, + preload, + filenames=[data_fname], + last_samps=last_samps, + orig_format="int", + raw_extras=[raw_extras], + verbose=verbose, + ) + + if "events" in curry_paths: + logger.info( + "Event file found. Extracting Annotations from" + " %s..." % curry_paths["events"] + ) + annots = _read_annotations_curry( + curry_paths["events"], sfreq=self.info["sfreq"] + ) self.set_annotations(annots) else: - logger.info('Event file not found. No Annotations set.') + logger.info("Event file not found. 
No Annotations set.") def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of raw data.""" - if self._raw_extras[fi]['is_ascii']: + if self._raw_extras[fi]["is_ascii"]: if isinstance(idx, slice): idx = np.arange(idx.start, idx.stop) block = np.loadtxt( - self._filenames[0], skiprows=start, max_rows=stop - start, - ndmin=2).T + self._filenames[0], skiprows=start, max_rows=stop - start, ndmin=2 + ).T _mult_cal_one(data, block, idx, cals, mult) else: _read_segments_file( - self, data, idx, fi, start, stop, cals, mult, dtype=" 0: + if len(edf_info["tal_idx"]) > 0: # Read TAL data exploiting the header info (no regexp) idx = np.empty(0, int) tal_data = self._read_segment_file( - np.empty((0, self.n_times)), idx, 0, 0, int(self.n_times), - np.ones((len(idx), 1)), None) + np.empty((0, self.n_times)), + idx, + 0, + 0, + int(self.n_times), + np.ones((len(idx), 1)), + None, + ) onset, duration, desc = _read_annotations_edf( tal_data[0], encoding=encoding, ) - self.set_annotations(Annotations(onset=onset, duration=duration, - description=desc, orig_time=None)) + self.set_annotations( + Annotations( + onset=onset, duration=duration, description=desc, orig_time=None + ) + ) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of raw data.""" - return _read_segment_file(data, idx, fi, start, stop, - self._raw_extras[fi], self._filenames[fi], - cals, mult) + return _read_segment_file( + data, + idx, + fi, + start, + stop, + self._raw_extras[fi], + self._filenames[fi], + cals, + mult, + ) @fill_doc @@ -240,48 +276,70 @@ class RawGDF(BaseRaw): """ @verbose - def __init__(self, input_fname, eog=None, misc=None, - stim_channel='auto', exclude=(), preload=False, include=None, - verbose=None): - logger.info('Extracting EDF parameters from {}...'.format(input_fname)) + def __init__( + self, + input_fname, + eog=None, + misc=None, + stim_channel="auto", + exclude=(), + preload=False, + include=None, + verbose=None, + ): + logger.info("Extracting EDF parameters from {}...".format(input_fname)) input_fname = os.path.abspath(input_fname) - info, edf_info, orig_units = _get_info(input_fname, stim_channel, eog, - misc, exclude, True, preload, - include) - logger.info('Creating raw.info structure...') + info, edf_info, orig_units = _get_info( + input_fname, stim_channel, eog, misc, exclude, True, preload, include + ) + logger.info("Creating raw.info structure...") # Raw attributes - last_samps = [edf_info['nsamples'] - 1] - super().__init__(info, preload, filenames=[input_fname], - raw_extras=[edf_info], last_samps=last_samps, - orig_format='int', orig_units=orig_units, - verbose=verbose) + last_samps = [edf_info["nsamples"] - 1] + super().__init__( + info, + preload, + filenames=[input_fname], + raw_extras=[edf_info], + last_samps=last_samps, + orig_format="int", + orig_units=orig_units, + verbose=verbose, + ) # Read annotations from file and set it - onset, duration, desc = _get_annotations_gdf(edf_info, - self.info['sfreq']) + onset, duration, desc = _get_annotations_gdf(edf_info, self.info["sfreq"]) - self.set_annotations(Annotations(onset=onset, duration=duration, - description=desc, orig_time=None)) + self.set_annotations( + Annotations( + onset=onset, duration=duration, description=desc, orig_time=None + ) + ) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of raw data.""" - return _read_segment_file(data, idx, fi, start, stop, - self._raw_extras[fi], self._filenames[fi], - cals, mult) + return 
_read_segment_file( + data, + idx, + fi, + start, + stop, + self._raw_extras[fi], + self._filenames[fi], + cals, + mult, + ) def _read_ch(fid, subtype, samp, dtype_byte, dtype=None): """Read a number of samples for a single channel.""" # BDF - if subtype == 'bdf': + if subtype == "bdf": ch_data = np.fromfile(fid, dtype=dtype, count=samp * dtype_byte) ch_data = ch_data.reshape(-1, 3).astype(INT32) - ch_data = ((ch_data[:, 0]) + - (ch_data[:, 1] << 8) + - (ch_data[:, 2] << 16)) + ch_data = (ch_data[:, 0]) + (ch_data[:, 1] << 8) + (ch_data[:, 2] << 16) # 24th bit determines the sign - ch_data[ch_data >= (1 << 23)] -= (1 << 24) + ch_data[ch_data >= (1 << 23)] -= 1 << 24 # GDF data and EDF data else: @@ -290,23 +348,22 @@ def _read_ch(fid, subtype, samp, dtype_byte, dtype=None): return ch_data -def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, - cals, mult): +def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, cals, mult): """Read a chunk of raw data.""" from scipy.interpolate import interp1d - n_samps = raw_extras['n_samps'] - buf_len = int(raw_extras['max_samp']) - dtype = raw_extras['dtype_np'] - dtype_byte = raw_extras['dtype_byte'] - data_offset = raw_extras['data_offset'] - stim_channel_idxs = raw_extras['stim_channel_idxs'] - orig_sel = raw_extras['sel'] - tal_idx = raw_extras.get('tal_idx', np.empty(0, int)) - subtype = raw_extras['subtype'] - cal = raw_extras['cal'] - offsets = raw_extras['offsets'] - gains = raw_extras['units'] + n_samps = raw_extras["n_samps"] + buf_len = int(raw_extras["max_samp"]) + dtype = raw_extras["dtype_np"] + dtype_byte = raw_extras["dtype_byte"] + data_offset = raw_extras["data_offset"] + stim_channel_idxs = raw_extras["stim_channel_idxs"] + orig_sel = raw_extras["sel"] + tal_idx = raw_extras.get("tal_idx", np.empty(0, int)) + subtype = raw_extras["subtype"] + cal = raw_extras["cal"] + offsets = raw_extras["offsets"] + gains = raw_extras["units"] read_sel = np.concatenate([orig_sel[idx], tal_idx]) tal_data = [] @@ -322,27 +379,25 @@ def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, # Otherwise we can end up with e.g. 18,181 chunks for a 20 MB file! # Let's do ~10 MB chunks: n_per = max(10 * 1024 * 1024 // (ch_offsets[-1] * dtype_byte), 1) - with open(filenames, 'rb', buffering=0) as fid: - + with open(filenames, "rb", buffering=0) as fid: # Extract data - start_offset = (data_offset + - block_start_idx * ch_offsets[-1] * dtype_byte) + start_offset = data_offset + block_start_idx * ch_offsets[-1] * dtype_byte for ai in range(0, len(r_lims), n_per): block_offset = ai * ch_offsets[-1] * dtype_byte n_read = min(len(r_lims) - ai, n_per) fid.seek(start_offset + block_offset, 0) # Read and reshape to (n_chunks_read, ch0_ch1_ch2_ch3...) 
- many_chunk = _read_ch(fid, subtype, ch_offsets[-1] * n_read, - dtype_byte, dtype).reshape(n_read, -1) + many_chunk = _read_ch( + fid, subtype, ch_offsets[-1] * n_read, dtype_byte, dtype + ).reshape(n_read, -1) r_sidx = r_lims[ai][0] - r_eidx = (buf_len * (n_read - 1) + r_lims[ai + n_read - 1][1]) + r_eidx = buf_len * (n_read - 1) + r_lims[ai + n_read - 1][1] d_sidx = d_lims[ai][0] d_eidx = d_lims[ai + n_read - 1][1] one = np.zeros((len(orig_sel), d_eidx - d_sidx), dtype=data.dtype) for ii, ci in enumerate(read_sel): # This now has size (n_chunks_read, n_samp[ci]) - ch_data = many_chunk[:, - ch_offsets[ci]:ch_offsets[ci + 1]].copy() + ch_data = many_chunk[:, ch_offsets[ci] : ch_offsets[ci + 1]].copy() if ci in tal_idx: tal_data.append(ch_data) @@ -360,18 +415,20 @@ def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, # Stim channel will be interpolated old = np.linspace(0, 1, n_samps[ci] + 1, True) new = np.linspace(0, 1, buf_len, False) - ch_data = np.append( - ch_data, np.zeros((len(ch_data), 1)), -1) - ch_data = interp1d(old, ch_data, - kind='zero', axis=-1)(new) + ch_data = np.append(ch_data, np.zeros((len(ch_data), 1)), -1) + ch_data = interp1d(old, ch_data, kind="zero", axis=-1)(new) else: # XXX resampling each chunk isn't great, # it forces edge artifacts to appear at # each buffer boundary :( # it can also be very slow... ch_data = resample( - ch_data.astype(np.float64), buf_len, n_samps[ci], - npad=0, axis=-1) + ch_data.astype(np.float64), + buf_len, + n_samps[ci], + npad=0, + axis=-1, + ) elif orig_idx in stim_channel_idxs: ch_data = np.bitwise_and(ch_data.astype(int), 2**17 - 1) one[orig_idx] = ch_data.ravel()[r_sidx:r_eidx] @@ -411,18 +468,20 @@ def _read_header(fname, exclude, infer_types, include=None): (edf_info, orig_units) : tuple """ ext = os.path.splitext(fname)[1][1:].lower() - logger.info('%s file detected' % ext.upper()) - if ext in ('bdf', 'edf'): + logger.info("%s file detected" % ext.upper()) + if ext in ("bdf", "edf"): return _read_edf_header(fname, exclude, infer_types, include) - elif ext == 'gdf': + elif ext == "gdf": return _read_gdf_header(fname, exclude, include), None else: raise NotImplementedError( - f'Only GDF, EDF, and BDF files are supported, got {ext}.') + f"Only GDF, EDF, and BDF files are supported, got {ext}." + ) -def _get_info(fname, stim_channel, eog, misc, exclude, infer_types, preload, - include=None): +def _get_info( + fname, stim_channel, eog, misc, exclude, infer_types, preload, include=None +): """Extract information from EDF+, BDF or GDF file.""" eog = eog if eog is not None else [] misc = misc if misc is not None else [] @@ -432,35 +491,38 @@ def _get_info(fname, stim_channel, eog, misc, exclude, infer_types, preload, # XXX: `tal_ch_names` to pass to `_check_stim_channel` should be computed # from `edf_info['ch_names']` and `edf_info['tal_idx']` but 'tal_idx' # contains stim channels that are not TAL. 
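The chunked read above relies on the EDF/BDF record layout: one data record stores every channel's samples back to back, possibly with a different sample count per channel, so `ch_offsets = cumsum(n_samps)` gives the slice boundaries for channel `ci` inside one flattened record. A minimal sketch of that slicing, with hypothetical per-channel sample counts (the `n_samps` values are made up; only the indexing mirrors `_read_segment_file`):

```python
import numpy as np

# Hypothetical samples-per-record for three channels (mixed rates, as EDF allows).
n_samps = np.array([100, 200, 100])
ch_offsets = np.cumsum(np.concatenate([[0], n_samps]))  # -> [0, 100, 300, 400]

# One flattened data record, standing in for what _read_ch returns.
record = np.arange(ch_offsets[-1])

# Slice each channel's block, mirroring
# many_chunk[:, ch_offsets[ci] : ch_offsets[ci + 1]] in the hunk above.
for ci in range(len(n_samps)):
    ch_data = record[ch_offsets[ci] : ch_offsets[ci + 1]]
    assert ch_data.size == n_samps[ci]
```

The same cumulative-offset trick is what lets the reader pull an arbitrary channel subset (`read_sel`) out of each 10 MB chunk without ever materializing the excluded channels.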
- stim_channel_idxs, _ = _check_stim_channel( - stim_channel, edf_info['ch_names']) + stim_channel_idxs, _ = _check_stim_channel(stim_channel, edf_info["ch_names"]) - sel = edf_info['sel'] # selection of channels not excluded - ch_names = edf_info['ch_names'] # of length len(sel) - if 'ch_types' in edf_info: - ch_types = edf_info['ch_types'] # of length len(sel) + sel = edf_info["sel"] # selection of channels not excluded + ch_names = edf_info["ch_names"] # of length len(sel) + if "ch_types" in edf_info: + ch_types = edf_info["ch_types"] # of length len(sel) else: ch_types = [None] * len(sel) if len(sel) == 0: # only want stim channels - n_samps = edf_info['n_samps'][[0]] + n_samps = edf_info["n_samps"][[0]] else: - n_samps = edf_info['n_samps'][sel] - nchan = edf_info['nchan'] - physical_ranges = edf_info['physical_max'] - edf_info['physical_min'] - cals = edf_info['digital_max'] - edf_info['digital_min'] + n_samps = edf_info["n_samps"][sel] + nchan = edf_info["nchan"] + physical_ranges = edf_info["physical_max"] - edf_info["physical_min"] + cals = edf_info["digital_max"] - edf_info["digital_min"] bad_idx = np.where((~np.isfinite(cals)) | (cals == 0))[0] if len(bad_idx) > 0: - warn('Scaling factor is not defined in following channels:\n' + - ', '.join(ch_names[i] for i in bad_idx)) + warn( + "Scaling factor is not defined in following channels:\n" + + ", ".join(ch_names[i] for i in bad_idx) + ) cals[bad_idx] = 1 bad_idx = np.where(physical_ranges == 0)[0] if len(bad_idx) > 0: - warn('Physical range is not defined in following channels:\n' + - ', '.join(ch_names[i] for i in bad_idx)) + warn( + "Physical range is not defined in following channels:\n" + + ", ".join(ch_names[i] for i in bad_idx) + ) physical_ranges[bad_idx] = 1 # Creates a list of dicts of eeg channels for raw.info - logger.info('Setting channel info structure...') + logger.info("Setting channel info structure...") chs = list() pick_mask = np.ones(len(ch_names)) @@ -468,144 +530,155 @@ def _get_info(fname, stim_channel, eog, misc, exclude, infer_types, preload, for idx, ch_name in enumerate(ch_names): chan_info = {} - chan_info['cal'] = 1. - chan_info['logno'] = idx + 1 - chan_info['scanno'] = idx + 1 - chan_info['range'] = 1. 
- chan_info['unit_mul'] = FIFF.FIFF_UNITM_NONE - chan_info['ch_name'] = ch_name - chan_info['unit'] = FIFF.FIFF_UNIT_V - chan_info['coord_frame'] = FIFF.FIFFV_COORD_HEAD - chan_info['coil_type'] = FIFF.FIFFV_COIL_EEG - chan_info['kind'] = FIFF.FIFFV_EEG_CH + chan_info["cal"] = 1.0 + chan_info["logno"] = idx + 1 + chan_info["scanno"] = idx + 1 + chan_info["range"] = 1.0 + chan_info["unit_mul"] = FIFF.FIFF_UNITM_NONE + chan_info["ch_name"] = ch_name + chan_info["unit"] = FIFF.FIFF_UNIT_V + chan_info["coord_frame"] = FIFF.FIFFV_COORD_HEAD + chan_info["coil_type"] = FIFF.FIFFV_COIL_EEG + chan_info["kind"] = FIFF.FIFFV_EEG_CH # montage can't be stored in EDF so channel locs are unknown: - chan_info['loc'] = np.full(12, np.nan) + chan_info["loc"] = np.full(12, np.nan) # if the edf info contained channel type information # set it now ch_type = ch_types[idx] if ch_type is not None and ch_type in CH_TYPE_MAPPING: - chan_info['kind'] = CH_TYPE_MAPPING.get(ch_type) - if ch_type not in ['EEG', 'ECOG', 'SEEG', 'DBS']: - chan_info['coil_type'] = FIFF.FIFFV_COIL_NONE + chan_info["kind"] = CH_TYPE_MAPPING.get(ch_type) + if ch_type not in ["EEG", "ECOG", "SEEG", "DBS"]: + chan_info["coil_type"] = FIFF.FIFFV_COIL_NONE pick_mask[idx] = False # if user passes in explicit mapping for eog, misc and stim # channels set them here if ch_name in eog or idx in eog or idx - nchan in eog: - chan_info['coil_type'] = FIFF.FIFFV_COIL_NONE - chan_info['kind'] = FIFF.FIFFV_EOG_CH + chan_info["coil_type"] = FIFF.FIFFV_COIL_NONE + chan_info["kind"] = FIFF.FIFFV_EOG_CH pick_mask[idx] = False elif ch_name in misc or idx in misc or idx - nchan in misc: - chan_info['coil_type'] = FIFF.FIFFV_COIL_NONE - chan_info['kind'] = FIFF.FIFFV_MISC_CH + chan_info["coil_type"] = FIFF.FIFFV_COIL_NONE + chan_info["kind"] = FIFF.FIFFV_MISC_CH pick_mask[idx] = False elif idx in stim_channel_idxs: - chan_info['coil_type'] = FIFF.FIFFV_COIL_NONE - chan_info['unit'] = FIFF.FIFF_UNIT_NONE - chan_info['kind'] = FIFF.FIFFV_STIM_CH + chan_info["coil_type"] = FIFF.FIFFV_COIL_NONE + chan_info["unit"] = FIFF.FIFF_UNIT_NONE + chan_info["kind"] = FIFF.FIFFV_STIM_CH pick_mask[idx] = False - chan_info['ch_name'] = ch_name - ch_names[idx] = chan_info['ch_name'] - edf_info['units'][idx] = 1 + chan_info["ch_name"] = ch_name + ch_names[idx] = chan_info["ch_name"] + edf_info["units"][idx] = 1 elif ch_type not in CH_TYPE_MAPPING: chs_without_types.append(ch_name) chs.append(chan_info) # warn if channel type was not inferable if len(chs_without_types): - msg = ('Could not determine channel type of the following channels, ' - f'they will be set as EEG:\n{", ".join(chs_without_types)}') + msg = ( + "Could not determine channel type of the following channels, " + f'they will be set as EEG:\n{", ".join(chs_without_types)}' + ) logger.info(msg) - edf_info['stim_channel_idxs'] = stim_channel_idxs + edf_info["stim_channel_idxs"] = stim_channel_idxs if any(pick_mask): picks = [item for item, mask in zip(range(nchan), pick_mask) if mask] - edf_info['max_samp'] = max_samp = n_samps[picks].max() + edf_info["max_samp"] = max_samp = n_samps[picks].max() else: - edf_info['max_samp'] = max_samp = n_samps.max() + edf_info["max_samp"] = max_samp = n_samps.max() # Info structure # ------------------------------------------------------------------------- - not_stim_ch = [x for x in range(n_samps.shape[0]) - if x not in stim_channel_idxs] + not_stim_ch = [x for x in range(n_samps.shape[0]) if x not in stim_channel_idxs] if len(not_stim_ch) == 0: # only loading stim channels 
not_stim_ch = list(range(len(n_samps))) - sfreq = np.take(n_samps, not_stim_ch).max() * \ - edf_info['record_length'][1] / edf_info['record_length'][0] + sfreq = ( + np.take(n_samps, not_stim_ch).max() + * edf_info["record_length"][1] + / edf_info["record_length"][0] + ) del n_samps info = _empty_info(sfreq) - info['meas_date'] = edf_info['meas_date'] - info['chs'] = chs - info['ch_names'] = ch_names + info["meas_date"] = edf_info["meas_date"] + info["chs"] = chs + info["ch_names"] = ch_names # Filter settings - highpass = edf_info['highpass'] - lowpass = edf_info['lowpass'] + highpass = edf_info["highpass"] + lowpass = edf_info["lowpass"] if highpass.size == 0: pass elif all(highpass): - if highpass[0] == 'NaN': + if highpass[0] == "NaN": # Placeholder for future use. Highpass set in _empty_info. pass - elif highpass[0] == 'DC': - info['highpass'] = 0. + elif highpass[0] == "DC": + info["highpass"] = 0.0 else: hp = highpass[0] try: hp = float(hp) except Exception: - hp = 0. - info['highpass'] = hp + hp = 0.0 + info["highpass"] = hp else: - info['highpass'] = float(np.max(highpass)) - warn('Channels contain different highpass filters. Highest filter ' - 'setting will be stored.') - if np.isnan(info['highpass']): - info['highpass'] = 0. + info["highpass"] = float(np.max(highpass)) + warn( + "Channels contain different highpass filters. Highest filter " + "setting will be stored." + ) + if np.isnan(info["highpass"]): + info["highpass"] = 0.0 if lowpass.size == 0: # Placeholder for future use. Lowpass set in _empty_info. pass elif all(lowpass): - if lowpass[0] in ('NaN', '0', '0.0'): + if lowpass[0] in ("NaN", "0", "0.0"): # Placeholder for future use. Lowpass set in _empty_info. pass else: - info['lowpass'] = float(lowpass[0]) + info["lowpass"] = float(lowpass[0]) else: - info['lowpass'] = float(np.min(lowpass)) - warn('Channels contain different lowpass filters. Lowest filter ' - 'setting will be stored.') - if np.isnan(info['lowpass']): - info['lowpass'] = info['sfreq'] / 2. - - if info['highpass'] > info['lowpass']: - warn(f'Highpass cutoff frequency {info["highpass"]} is greater ' - f'than lowpass cutoff frequency {info["lowpass"]}, ' - 'setting values to 0 and Nyquist.') - info['highpass'] = 0. - info['lowpass'] = info['sfreq'] / 2. + info["lowpass"] = float(np.min(lowpass)) + warn( + "Channels contain different lowpass filters. Lowest filter " + "setting will be stored." + ) + if np.isnan(info["lowpass"]): + info["lowpass"] = info["sfreq"] / 2.0 + + if info["highpass"] > info["lowpass"]: + warn( + f'Highpass cutoff frequency {info["highpass"]} is greater ' + f'than lowpass cutoff frequency {info["lowpass"]}, ' + "setting values to 0 and Nyquist." 
+ ) + info["highpass"] = 0.0 + info["lowpass"] = info["sfreq"] / 2.0 # Some keys to be consistent with FIF measurement info - info['description'] = None - edf_info['nsamples'] = int(edf_info['n_records'] * max_samp) + info["description"] = None + edf_info["nsamples"] = int(edf_info["n_records"] * max_samp) info._unlocked = False info._update_redundant() # Later used for reading - edf_info['cal'] = physical_ranges / cals + edf_info["cal"] = physical_ranges / cals # physical dimension in µV - edf_info['offsets'] = ( - edf_info['physical_min'] - edf_info['digital_min'] * edf_info['cal']) - del edf_info['physical_min'] - del edf_info['digital_min'] + edf_info["offsets"] = ( + edf_info["physical_min"] - edf_info["digital_min"] * edf_info["cal"] + ) + del edf_info["physical_min"] + del edf_info["digital_min"] - if edf_info['subtype'] == 'bdf': - edf_info['cal'][stim_channel_idxs] = 1 - edf_info['offsets'][stim_channel_idxs] = 0 - edf_info['units'][stim_channel_idxs] = 1 + if edf_info["subtype"] == "bdf": + edf_info["cal"][stim_channel_idxs] = 1 + edf_info["offsets"][stim_channel_idxs] = 0 + edf_info["units"][stim_channel_idxs] = 1 return info, edf_info, orig_units @@ -613,18 +686,28 @@ def _get_info(fname, stim_channel, eog, misc, exclude, infer_types, preload, def _parse_prefilter_string(prefiltering): """Parse prefilter string from EDF+ and BDF headers.""" highpass = np.array( - [v for hp in [re.findall(r'HP:\s*([0-9]+[.]*[0-9]*)', filt) - for filt in prefiltering] for v in hp] + [ + v + for hp in [ + re.findall(r"HP:\s*([0-9]+[.]*[0-9]*)", filt) for filt in prefiltering + ] + for v in hp + ] ) lowpass = np.array( - [v for hp in [re.findall(r'LP:\s*([0-9]+[.]*[0-9]*)', filt) - for filt in prefiltering] for v in hp] + [ + v + for hp in [ + re.findall(r"LP:\s*([0-9]+[.]*[0-9]*)", filt) for filt in prefiltering + ] + for v in hp + ] ) return highpass, lowpass def _edf_str(x): - return x.decode('latin-1').split('\x00')[0] + return x.decode("latin-1").split("\x00")[0] def _edf_str_num(x): @@ -633,30 +716,29 @@ def _edf_str_num(x): def _read_edf_header(fname, exclude, infer_types, include=None): """Read header information from EDF+ or BDF file.""" - edf_info = {'events': []} - - with open(fname, 'rb') as fid: + edf_info = {"events": []} + with open(fname, "rb") as fid: fid.read(8) # version (unused here) # patient ID patient = {} - id_info = fid.read(80).decode('latin-1').rstrip() - id_info = id_info.split(' ') + id_info = fid.read(80).decode("latin-1").rstrip() + id_info = id_info.split(" ") if len(id_info): - patient['id'] = id_info[0] + patient["id"] = id_info[0] if len(id_info) == 4: try: birthdate = datetime.strptime(id_info[2], "%d-%b-%Y") except ValueError: birthdate = "X" - patient['sex'] = id_info[1] - patient['birthday'] = birthdate - patient['name'] = id_info[3] + patient["sex"] = id_info[1] + patient["birthday"] = birthdate + patient["name"] = id_info[3] # Recording ID meas_id = {} - rec_info = fid.read(80).decode('latin-1').rstrip().split(' ') + rec_info = fid.read(80).decode("latin-1").rstrip().split(" ") valid_startdate = False if len(rec_info) == 5: try: @@ -665,31 +747,34 @@ def _read_edf_header(fname, exclude, infer_types, include=None): startdate = "X" else: valid_startdate = True - meas_id['startdate'] = startdate - meas_id['study_id'] = rec_info[2] - meas_id['technician'] = rec_info[3] - meas_id['equipment'] = rec_info[4] + meas_id["startdate"] = startdate + meas_id["study_id"] = rec_info[2] + meas_id["technician"] = rec_info[3] + meas_id["equipment"] = rec_info[4] # If 
startdate available in recording info, use it instead of the # file's meas_date since it contains all 4 digits of the year if valid_startdate: - day = meas_id['startdate'].day - month = meas_id['startdate'].month - year = meas_id['startdate'].year + day = meas_id["startdate"].day + month = meas_id["startdate"].month + year = meas_id["startdate"].year fid.read(8) # skip file's meas_date else: - meas_date = fid.read(8).decode('latin-1') - day, month, year = [int(x) for x in meas_date.split('.')] + meas_date = fid.read(8).decode("latin-1") + day, month, year = [int(x) for x in meas_date.split(".")] year = year + 2000 if year < 85 else year + 1900 - meas_time = fid.read(8).decode('latin-1') - hour, minute, sec = [int(x) for x in meas_time.split('.')] + meas_time = fid.read(8).decode("latin-1") + hour, minute, sec = [int(x) for x in meas_time.split(".")] try: - meas_date = datetime(year, month, day, hour, minute, sec, - tzinfo=timezone.utc) + meas_date = datetime( + year, month, day, hour, minute, sec, tzinfo=timezone.utc + ) except ValueError: - warn(f'Invalid date encountered ({year:04d}-{month:02d}-' - f'{day:02d} {hour:02d}:{minute:02d}:{sec:02d}).') + warn( + f"Invalid date encountered ({year:04d}-{month:02d}-" + f"{day:02d} {hour:02d}:{minute:02d}:{sec:02d})." + ) meas_date = None header_nbytes = int(_edf_str(fid.read(8))) @@ -704,19 +789,21 @@ def _read_edf_header(fname, exclude, infer_types, include=None): n_records = int(_edf_str(fid.read(8))) record_length = float(_edf_str(fid.read(8))) - record_length = np.array([record_length, 1.]) # in seconds + record_length = np.array([record_length, 1.0]) # in seconds if record_length[0] == 0: - record_length[0] = 1. - warn('Header information is incorrect for record length. Default ' - 'record length set to 1.\nIt is possible that this file only' - ' contains annotations and no signals. In that case, please ' - 'use mne.read_annotations() to load these annotations.') + record_length[0] = 1.0 + warn( + "Header information is incorrect for record length. Default " + "record length set to 1.\nIt is possible that this file only" + " contains annotations and no signals. In that case, please " + "use mne.read_annotations() to load these annotations." + ) nchan = int(_edf_str(fid.read(4))) channels = list(range(nchan)) # read in 16 byte labels and strip any extra spaces at the end - ch_labels = [fid.read(16).strip().decode('latin-1') for _ in channels] + ch_labels = [fid.read(16).strip().decode("latin-1") for _ in channels] # get channel names and optionally channel type # EDF specification contains 16 bytes that encode channel names, @@ -725,18 +812,20 @@ def _read_edf_header(fname, exclude, infer_types, include=None): if infer_types: ch_types, ch_names = [], [] for ch_label in ch_labels: - ch_type, ch_name = 'EEG', ch_label # default to EEG - parts = ch_label.split(' ') + ch_type, ch_name = "EEG", ch_label # default to EEG + parts = ch_label.split(" ") if len(parts) > 1: if parts[0].upper() in CH_TYPE_MAPPING: ch_type = parts[0].upper() - ch_name = ' '.join(parts[1:]) - logger.info(f"Channel '{ch_label}' recognized as type " - f"{ch_type} (renamed to '{ch_name}').") + ch_name = " ".join(parts[1:]) + logger.info( + f"Channel '{ch_label}' recognized as type " + f"{ch_type} (renamed to '{ch_name}')." 
+ ) ch_types.append(ch_type) ch_names.append(ch_name) else: - ch_types, ch_names = ['EEG'] * nchan, ch_labels + ch_types, ch_names = ["EEG"] * nchan, ch_labels exclude = _find_exclude_idx(ch_names, exclude, include) tal_idx = _find_tal_idx(ch_names) @@ -744,19 +833,19 @@ def _read_edf_header(fname, exclude, infer_types, include=None): sel = np.setdiff1d(np.arange(len(ch_names)), exclude) for ch in channels: fid.read(80) # transducer - units = [fid.read(8).strip().decode('latin-1') for ch in channels] - edf_info['units'] = list() + units = [fid.read(8).strip().decode("latin-1") for ch in channels] + edf_info["units"] = list() for i, unit in enumerate(units): if i in exclude: continue # allow μ (greek mu), µ (micro symbol) and μ (sjis mu) codepoints - if unit in ('\u03BCV', '\u00B5V', '\x83\xCAV', 'uV'): - edf_info['units'].append(1e-6) - elif unit == 'mV': - edf_info['units'].append(1e-3) + if unit in ("\u03BCV", "\u00B5V", "\x83\xCAV", "uV"): + edf_info["units"].append(1e-6) + elif unit == "mV": + edf_info["units"].append(1e-3) else: - edf_info['units'].append(1) - edf_info['units'] = np.array(edf_info['units'], float) + edf_info["units"].append(1) + edf_info["units"] = np.array(edf_info["units"], float) ch_names = [ch_names[idx] for idx in sel] units = [units[idx] for idx in sel] @@ -765,14 +854,18 @@ def _read_edf_header(fname, exclude, infer_types, include=None): ch_names = _unique_channel_names(ch_names) orig_units = dict(zip(ch_names, units)) - physical_min = np.array( - [float(_edf_str_num(fid.read(8))) for ch in channels])[sel] - physical_max = np.array( - [float(_edf_str_num(fid.read(8))) for ch in channels])[sel] - digital_min = np.array( - [float(_edf_str_num(fid.read(8))) for ch in channels])[sel] - digital_max = np.array( - [float(_edf_str_num(fid.read(8))) for ch in channels])[sel] + physical_min = np.array([float(_edf_str_num(fid.read(8))) for ch in channels])[ + sel + ] + physical_max = np.array([float(_edf_str_num(fid.read(8))) for ch in channels])[ + sel + ] + digital_min = np.array([float(_edf_str_num(fid.read(8))) for ch in channels])[ + sel + ] + digital_max = np.array([float(_edf_str_num(fid.read(8))) for ch in channels])[ + sel + ] prefiltering = [_edf_str(fid.read(80)).strip() for ch in channels][:-1] highpass, lowpass = _parse_prefilter_string(prefiltering) @@ -781,13 +874,25 @@ def _read_edf_header(fname, exclude, infer_types, include=None): # Populate edf_info edf_info.update( - ch_names=ch_names, ch_types=ch_types, data_offset=header_nbytes, - digital_max=digital_max, digital_min=digital_min, - highpass=highpass, sel=sel, lowpass=lowpass, meas_date=meas_date, - n_records=n_records, n_samps=n_samps, nchan=nchan, - subject_info=patient, physical_max=physical_max, - physical_min=physical_min, record_length=record_length, - subtype=subtype, tal_idx=tal_idx) + ch_names=ch_names, + ch_types=ch_types, + data_offset=header_nbytes, + digital_max=digital_max, + digital_min=digital_min, + highpass=highpass, + sel=sel, + lowpass=lowpass, + meas_date=meas_date, + n_records=n_records, + n_samps=n_samps, + nchan=nchan, + subject_info=patient, + physical_max=physical_max, + physical_min=physical_min, + record_length=record_length, + subtype=subtype, + tal_idx=tal_idx, + ) fid.read(32 * nchan).decode() # reserved assert fid.tell() == header_nbytes @@ -795,41 +900,58 @@ def _read_edf_header(fname, exclude, infer_types, include=None): fid.seek(0, 2) n_bytes = fid.tell() n_data_bytes = n_bytes - header_nbytes - total_samps = (n_data_bytes // 3 if subtype == 'bdf' - else n_data_bytes 
// 2) + total_samps = n_data_bytes // 3 if subtype == "bdf" else n_data_bytes // 2 read_records = total_samps // np.sum(n_samps) if n_records != read_records: - warn('Number of records from the header does not match the file ' - 'size (perhaps the recording was not stopped before exiting).' - ' Inferring from the file size.') - edf_info['n_records'] = read_records + warn( + "Number of records from the header does not match the file " + "size (perhaps the recording was not stopped before exiting)." + " Inferring from the file size." + ) + edf_info["n_records"] = read_records del n_records - if subtype == 'bdf': - edf_info['dtype_byte'] = 3 # 24-bit (3 byte) integers - edf_info['dtype_np'] = UINT8 + if subtype == "bdf": + edf_info["dtype_byte"] = 3 # 24-bit (3 byte) integers + edf_info["dtype_np"] = UINT8 else: - edf_info['dtype_byte'] = 2 # 16-bit (2 byte) integers - edf_info['dtype_np'] = INT16 + edf_info["dtype_byte"] = 2 # 16-bit (2 byte) integers + edf_info["dtype_np"] = INT16 return edf_info, orig_units -INT8 = '= 2: - patient['id'] = pid[0] - patient['name'] = pid[1] + patient["id"] = pid[0] + patient["name"] = pid[1] # Recording ID meas_id = {} - meas_id['recording_id'] = _edf_str(fid.read(80)).strip() + meas_id["recording_id"] = _edf_str(fid.read(80)).strip() # date tm = _edf_str(fid.read(16)).strip() try: - if tm[14:16] == ' ': - tm = tm[:14] + '00' + tm[16:] + if tm[14:16] == " ": + tm = tm[:14] + "00" + tm[16:] meas_date = datetime( - int(tm[0:4]), int(tm[4:6]), - int(tm[6:8]), int(tm[8:10]), - int(tm[10:12]), int(tm[12:14]), + int(tm[0:4]), + int(tm[4:6]), + int(tm[6:8]), + int(tm[8:10]), + int(tm[10:12]), + int(tm[12:14]), int(tm[14:16]) * pow(10, 4), - tzinfo=timezone.utc) + tzinfo=timezone.utc, + ) except Exception: pass header_nbytes = np.fromfile(fid, INT64, 1)[0] - meas_id['equipment'] = np.fromfile(fid, UINT8, 8)[0] - meas_id['hospital'] = np.fromfile(fid, UINT8, 8)[0] - meas_id['technician'] = np.fromfile(fid, UINT8, 8)[0] - fid.seek(20, 1) # 20bytes reserved + meas_id["equipment"] = np.fromfile(fid, UINT8, 8)[0] + meas_id["hospital"] = np.fromfile(fid, UINT8, 8)[0] + meas_id["technician"] = np.fromfile(fid, UINT8, 8)[0] + fid.seek(20, 1) # 20bytes reserved n_records = np.fromfile(fid, INT64, 1)[0] # record length in seconds record_length = np.fromfile(fid, UINT32, 2) if record_length[0] == 0: - record_length[0] = 1. - warn('Header information is incorrect for record length. ' - 'Default record length set to 1.') + record_length[0] = 1.0 + warn( + "Header information is incorrect for record length. " + "Default record length set to 1." 
+ ) nchan = np.fromfile(fid, UINT32, 1)[0] channels = list(range(nchan)) ch_names = [_edf_str(fid.read(16)).strip() for ch in channels] @@ -903,15 +1029,15 @@ def _read_gdf_header(fname, exclude, include=None): sel = np.setdiff1d(np.arange(len(ch_names)), exclude) fid.seek(80 * len(channels), 1) # transducer units = [_edf_str(fid.read(8)).strip() for ch in channels] - edf_info['units'] = list() + edf_info["units"] = list() for i, unit in enumerate(units): if i in exclude: continue - if unit[:2] == 'uV': - edf_info['units'].append(1e-6) + if unit[:2] == "uV": + edf_info["units"].append(1e-6) else: - edf_info['units'].append(1) - edf_info['units'] = np.array(edf_info['units'], float) + edf_info["units"].append(1) + edf_info["units"] = np.array(edf_info["units"], float) ch_names = [ch_names[idx] for idx in sel] physical_min = np.fromfile(fid, FLOAT64, len(channels)) @@ -928,28 +1054,41 @@ def _read_gdf_header(fname, exclude, include=None): dtype = np.fromfile(fid, INT32, len(channels)) # total number of bytes for data - bytes_tot = np.sum([GDFTYPE_BYTE[t] * n_samps[i] - for i, t in enumerate(dtype)]) + bytes_tot = np.sum( + [GDFTYPE_BYTE[t] * n_samps[i] for i, t in enumerate(dtype)] + ) # Populate edf_info dtype_np, dtype_byte = _check_dtype_byte(dtype) edf_info.update( - bytes_tot=bytes_tot, ch_names=ch_names, - data_offset=header_nbytes, digital_min=digital_min, + bytes_tot=bytes_tot, + ch_names=ch_names, + data_offset=header_nbytes, + digital_min=digital_min, digital_max=digital_max, - dtype_byte=dtype_byte, dtype_np=dtype_np, exclude=exclude, - highpass=highpass, sel=sel, lowpass=lowpass, + dtype_byte=dtype_byte, + dtype_np=dtype_np, + exclude=exclude, + highpass=highpass, + sel=sel, + lowpass=lowpass, meas_date=meas_date, - meas_id=meas_id, n_records=n_records, n_samps=n_samps, - nchan=nchan, subject_info=patient, physical_max=physical_max, - physical_min=physical_min, record_length=record_length) + meas_id=meas_id, + n_records=n_records, + n_samps=n_samps, + nchan=nchan, + subject_info=patient, + physical_max=physical_max, + physical_min=physical_min, + record_length=record_length, + ) - fid.seek(32 * edf_info['nchan'], 1) # reserved + fid.seek(32 * edf_info["nchan"], 1) # reserved assert fid.tell() == header_nbytes # Event table # ----------------------------------------------------------------- - etp = header_nbytes + n_records * edf_info['bytes_tot'] + etp = header_nbytes + n_records * edf_info["bytes_tot"] # skip data to go to event table fid.seek(etp) etmode = np.fromfile(fid, UINT8, 1)[0] @@ -975,63 +1114,62 @@ def _read_gdf_header(fname, exclude, include=None): # --------------------------------------------------------------------- else: # FIXED HEADER - handedness = ('Unknown', 'Right', 'Left', 'Equal') - gender = ('Unknown', 'Male', 'Female') - scale = ('Unknown', 'No', 'Yes', 'Corrected') + handedness = ("Unknown", "Right", "Left", "Equal") + gender = ("Unknown", "Male", "Female") + scale = ("Unknown", "No", "Yes", "Corrected") # date pid = fid.read(66).decode() - pid = pid.split(' ', 2) + pid = pid.split(" ", 2) patient = {} if len(pid) >= 2: - patient['id'] = pid[0] - patient['name'] = pid[1] + patient["id"] = pid[0] + patient["name"] = pid[1] fid.seek(10, 1) # 10bytes reserved # Smoking / Alcohol abuse / drug abuse / medication sadm = np.fromfile(fid, UINT8, 1)[0] - patient['smoking'] = scale[sadm % 4] - patient['alcohol_abuse'] = scale[(sadm >> 2) % 4] - patient['drug_abuse'] = scale[(sadm >> 4) % 4] - patient['medication'] = scale[(sadm >> 6) % 4] - patient['weight'] = 
np.fromfile(fid, UINT8, 1)[0] - if patient['weight'] == 0 or patient['weight'] == 255: - patient['weight'] = None - patient['height'] = np.fromfile(fid, UINT8, 1)[0] - if patient['height'] == 0 or patient['height'] == 255: - patient['height'] = None + patient["smoking"] = scale[sadm % 4] + patient["alcohol_abuse"] = scale[(sadm >> 2) % 4] + patient["drug_abuse"] = scale[(sadm >> 4) % 4] + patient["medication"] = scale[(sadm >> 6) % 4] + patient["weight"] = np.fromfile(fid, UINT8, 1)[0] + if patient["weight"] == 0 or patient["weight"] == 255: + patient["weight"] = None + patient["height"] = np.fromfile(fid, UINT8, 1)[0] + if patient["height"] == 0 or patient["height"] == 255: + patient["height"] = None # Gender / Handedness / Visual Impairment ghi = np.fromfile(fid, UINT8, 1)[0] - patient['sex'] = gender[ghi % 4] - patient['handedness'] = handedness[(ghi >> 2) % 4] - patient['visual'] = scale[(ghi >> 4) % 4] + patient["sex"] = gender[ghi % 4] + patient["handedness"] = handedness[(ghi >> 2) % 4] + patient["visual"] = scale[(ghi >> 4) % 4] # Recording identification meas_id = {} - meas_id['recording_id'] = _edf_str(fid.read(64)).strip() + meas_id["recording_id"] = _edf_str(fid.read(64)).strip() vhsv = np.fromfile(fid, UINT8, 4) loc = {} if vhsv[3] == 0: - loc['vertpre'] = 10 * int(vhsv[0] >> 4) + int(vhsv[0] % 16) - loc['horzpre'] = 10 * int(vhsv[1] >> 4) + int(vhsv[1] % 16) - loc['size'] = 10 * int(vhsv[2] >> 4) + int(vhsv[2] % 16) + loc["vertpre"] = 10 * int(vhsv[0] >> 4) + int(vhsv[0] % 16) + loc["horzpre"] = 10 * int(vhsv[1] >> 4) + int(vhsv[1] % 16) + loc["size"] = 10 * int(vhsv[2] >> 4) + int(vhsv[2] % 16) else: - loc['vertpre'] = 29 - loc['horzpre'] = 29 - loc['size'] = 29 - loc['version'] = 0 - loc['latitude'] = \ - float(np.fromfile(fid, UINT32, 1)[0]) / 3600000 - loc['longitude'] = \ - float(np.fromfile(fid, UINT32, 1)[0]) / 3600000 - loc['altitude'] = float(np.fromfile(fid, INT32, 1)[0]) / 100 - meas_id['loc'] = loc + loc["vertpre"] = 29 + loc["horzpre"] = 29 + loc["size"] = 29 + loc["version"] = 0 + loc["latitude"] = float(np.fromfile(fid, UINT32, 1)[0]) / 3600000 + loc["longitude"] = float(np.fromfile(fid, UINT32, 1)[0]) / 3600000 + loc["altitude"] = float(np.fromfile(fid, INT32, 1)[0]) / 100 + meas_id["loc"] = loc meas_date = np.fromfile(fid, UINT64, 1)[0] if meas_date != 0: - meas_date = (datetime(1, 1, 1, tzinfo=timezone.utc) + - timedelta(meas_date * pow(2, -32) - 367)) + meas_date = datetime(1, 1, 1, tzinfo=timezone.utc) + timedelta( + meas_date * pow(2, -32) - 367 + ) else: meas_date = None @@ -1039,29 +1177,29 @@ def _read_gdf_header(fname, exclude, include=None): if birthday == 0: birthday = datetime(1, 1, 1, tzinfo=timezone.utc) else: - birthday = (datetime(1, 1, 1, tzinfo=timezone.utc) + - timedelta(birthday * pow(2, -32) - 367)) - patient['birthday'] = birthday - if patient['birthday'] != datetime(1, 1, 1, 0, 0, - tzinfo=timezone.utc): + birthday = datetime(1, 1, 1, tzinfo=timezone.utc) + timedelta( + birthday * pow(2, -32) - 367 + ) + patient["birthday"] = birthday + if patient["birthday"] != datetime(1, 1, 1, 0, 0, tzinfo=timezone.utc): today = datetime.now(tz=timezone.utc) - patient['age'] = today.year - patient['birthday'].year - today = today.replace(year=patient['birthday'].year) - if today < patient['birthday']: - patient['age'] -= 1 + patient["age"] = today.year - patient["birthday"].year + today = today.replace(year=patient["birthday"].year) + if today < patient["birthday"]: + patient["age"] -= 1 else: - patient['age'] = None + patient["age"] = None 
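Both `meas_date` and `birthday` above use the GDF v2 date encoding: a uint64 whose value times 2**-32 is a fractional day count, with the `- 367` offset bridging the GDF day-count origin and Python's `datetime(1, 1, 1)`. A small round-trip sketch of the same conversion (the helper names and the sample timestamp are illustrative only, not MNE API):

```python
from datetime import datetime, timedelta, timezone


def gdf_date_to_datetime(raw_u64):
    """Decode a GDF v2 uint64 date field (same arithmetic as the parser above)."""
    days = raw_u64 * pow(2, -32)  # upper 32 bits: whole days; lower 32: day fraction
    # -367 bridges the GDF day-count origin and datetime(1, 1, 1).
    return datetime(1, 1, 1, tzinfo=timezone.utc) + timedelta(days - 367)


def datetime_to_gdf_date(dt):
    """Inverse transform -- illustrative helper, not part of MNE."""
    delta = dt - datetime(1, 1, 1, tzinfo=timezone.utc)
    days = delta.days + delta.seconds / 86400.0 + 367
    return int(days * 2**32)


# Round trip with a hypothetical acquisition time.
when = datetime(2020, 5, 17, 12, 0, tzinfo=timezone.utc)
assert gdf_date_to_datetime(datetime_to_gdf_date(when)) == when
```

This also explains the `birthday != datetime(1, 1, 1, ...)` guard above: a raw value of 0 decodes to the origin itself, which the parser treats as "birthday unknown" rather than computing an age from it.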
header_nbytes = np.fromfile(fid, UINT16, 1)[0] * 256 fid.seek(6, 1) # 6 bytes reserved - meas_id['equipment'] = np.fromfile(fid, UINT8, 8) - meas_id['ip'] = np.fromfile(fid, UINT8, 6) - patient['headsize'] = np.fromfile(fid, UINT16, 3) - patient['headsize'] = np.asarray(patient['headsize'], np.float32) - patient['headsize'] = np.ma.masked_array( - patient['headsize'], - np.equal(patient['headsize'], 0), None).filled() + meas_id["equipment"] = np.fromfile(fid, UINT8, 8) + meas_id["ip"] = np.fromfile(fid, UINT8, 6) + patient["headsize"] = np.fromfile(fid, UINT16, 3) + patient["headsize"] = np.asarray(patient["headsize"], np.float32) + patient["headsize"] = np.ma.masked_array( + patient["headsize"], np.equal(patient["headsize"], 0), None + ).filled() ref = np.fromfile(fid, FLOAT32, 3) gnd = np.fromfile(fid, FLOAT32, 3) n_records = np.fromfile(fid, INT64, 1)[0] @@ -1069,9 +1207,11 @@ def _read_gdf_header(fname, exclude, include=None): # record length in seconds record_length = np.fromfile(fid, UINT32, 2) if record_length[0] == 0: - record_length[0] = 1. - warn('Header information is incorrect for record length. ' - 'Default record length set to 1.') + record_length[0] = 1.0 + warn( + "Header information is incorrect for record length. " + "Default record length set to 1." + ) nchan = np.fromfile(fid, UINT16, 1)[0] fid.seek(2, 1) # 2bytes reserved @@ -1093,24 +1233,26 @@ def _read_gdf_header(fname, exclude, include=None): """ # noqa units = np.fromfile(fid, UINT16, len(channels)).tolist() unitcodes = np.array(units[:]) - edf_info['units'] = list() + edf_info["units"] = list() for i, unit in enumerate(units): if i in exclude: continue if unit == 4275: # microvolts - edf_info['units'].append(1e-6) + edf_info["units"].append(1e-6) elif unit == 4274: # millivolts - edf_info['units'].append(1e-3) + edf_info["units"].append(1e-3) elif unit == 512: # dimensionless - edf_info['units'].append(1) + edf_info["units"].append(1) elif unit == 0: - edf_info['units'].append(1) # unrecognized + edf_info["units"].append(1) # unrecognized else: - warn('Unsupported physical dimension for channel %d ' - '(assuming dimensionless). Please contact the ' - 'MNE-Python developers for support.' % i) - edf_info['units'].append(1) - edf_info['units'] = np.array(edf_info['units'], float) + warn( + "Unsupported physical dimension for channel %d " + "(assuming dimensionless). Please contact the " + "MNE-Python developers for support." 
% i + ) + edf_info["units"].append(1) + edf_info["units"] = np.array(edf_info["units"], float) ch_names = [ch_names[idx] for idx in sel] physical_min = np.fromfile(fid, FLOAT64, len(channels)) @@ -1130,14 +1272,12 @@ def _read_gdf_header(fname, exclude, include=None): dtype = np.fromfile(fid, INT32, len(channels)) channel = {} - channel['xyz'] = [np.fromfile(fid, FLOAT32, 3)[0] - for ch in channels] + channel["xyz"] = [np.fromfile(fid, FLOAT32, 3)[0] for ch in channels] - if edf_info['number'] < 2.19: - impedance = np.fromfile(fid, UINT8, - len(channels)).astype(float) + if edf_info["number"] < 2.19: + impedance = np.fromfile(fid, UINT8, len(channels)).astype(float) impedance[impedance == 255] = np.nan - channel['impedance'] = pow(2, impedance / 8) + channel["impedance"] = pow(2, impedance / 8) fid.seek(19 * len(channels), 1) # reserved else: tmp = np.fromfile(fid, FLOAT32, 5 * len(channels)) @@ -1154,43 +1294,60 @@ def _read_gdf_header(fname, exclude, include=None): assert fid.tell() == header_nbytes # total number of bytes for data - bytes_tot = np.sum([GDFTYPE_BYTE[t] * n_samps[i] - for i, t in enumerate(dtype)]) + bytes_tot = np.sum( + [GDFTYPE_BYTE[t] * n_samps[i] for i, t in enumerate(dtype)] + ) # Populate edf_info dtype_np, dtype_byte = _check_dtype_byte(dtype) edf_info.update( - bytes_tot=bytes_tot, ch_names=ch_names, + bytes_tot=bytes_tot, + ch_names=ch_names, data_offset=header_nbytes, - dtype_byte=dtype_byte, dtype_np=dtype_np, - digital_min=digital_min, digital_max=digital_max, - exclude=exclude, gnd=gnd, highpass=highpass, sel=sel, - impedance=impedance, lowpass=lowpass, meas_date=meas_date, - meas_id=meas_id, n_records=n_records, n_samps=n_samps, - nchan=nchan, notch=notch, subject_info=patient, - physical_max=physical_max, physical_min=physical_min, - record_length=record_length, ref=ref) + dtype_byte=dtype_byte, + dtype_np=dtype_np, + digital_min=digital_min, + digital_max=digital_max, + exclude=exclude, + gnd=gnd, + highpass=highpass, + sel=sel, + impedance=impedance, + lowpass=lowpass, + meas_date=meas_date, + meas_id=meas_id, + n_records=n_records, + n_samps=n_samps, + nchan=nchan, + notch=notch, + subject_info=patient, + physical_max=physical_max, + physical_min=physical_min, + record_length=record_length, + ref=ref, + ) # EVENT TABLE # ----------------------------------------------------------------- - etp = edf_info['data_offset'] + edf_info['n_records'] * \ - edf_info['bytes_tot'] + etp = ( + edf_info["data_offset"] + edf_info["n_records"] * edf_info["bytes_tot"] + ) fid.seek(etp) # skip data to go to event table etmode = fid.read(1).decode() - if etmode != '': + if etmode != "": etmode = np.fromstring(etmode, UINT8).tolist()[0] - if edf_info['number'] < 1.94: + if edf_info["number"] < 1.94: sr = np.fromfile(fid, UINT8, 3) event_sr = sr[0] for i in range(1, len(sr)): - event_sr = event_sr + sr[i] * 2**(i * 8) + event_sr = event_sr + sr[i] * 2 ** (i * 8) n_events = np.fromfile(fid, UINT32, 1)[0] else: ne = np.fromfile(fid, UINT8, 3) n_events = ne[0] for i in range(1, len(ne)): - n_events = n_events + ne[i] * 2**(i * 8) + n_events = n_events + ne[i] * 2 ** (i * 8) event_sr = np.fromfile(fid, FLOAT32, 1)[0] pos = np.fromfile(fid, UINT32, n_events) - 1 # 1-based inds @@ -1204,30 +1361,34 @@ def _read_gdf_header(fname, exclude, include=None): dur = np.ones(n_events, dtype=np.uint32) np.clip(dur, 1, np.inf, out=dur) events = [n_events, pos, typ, chn, dur] - edf_info['event_sfreq'] = event_sr + edf_info["event_sfreq"] = event_sr - edf_info.update(events=events, 
sel=np.arange(len(edf_info['ch_names']))) + edf_info.update(events=events, sel=np.arange(len(edf_info["ch_names"]))) return edf_info -def _check_stim_channel(stim_channel, ch_names, - tal_ch_names=['EDF Annotations', 'BDF Annotations']): +def _check_stim_channel( + stim_channel, ch_names, tal_ch_names=["EDF Annotations", "BDF Annotations"] +): """Check that the stimulus channel exists in the current datafile.""" - DEFAULT_STIM_CH_NAMES = ['status', 'trigger'] + DEFAULT_STIM_CH_NAMES = ["status", "trigger"] if stim_channel is None or stim_channel is False: return [], [] if stim_channel is True: # convenient aliases - stim_channel = 'auto' + stim_channel = "auto" elif isinstance(stim_channel, str): - if stim_channel == 'auto': - if 'auto' in ch_names: - warn(RuntimeWarning, "Using `stim_channel='auto'` when auto" - " also corresponds to a channel name is ambiguous." - " Please use `stim_channel=['auto']`.") + if stim_channel == "auto": + if "auto" in ch_names: + warn( + RuntimeWarning, + "Using `stim_channel='auto'` when auto" + " also corresponds to a channel name is ambiguous." + " Please use `stim_channel=['auto']`.", + ) else: valid_stim_ch_names = DEFAULT_STIM_CH_NAMES else: @@ -1242,18 +1403,20 @@ def _check_stim_channel(stim_channel, ch_names, elif all([isinstance(s, int) for s in stim_channel]): valid_stim_ch_names = [ch_names[s].lower() for s in stim_channel] else: - raise ValueError('Invalid stim_channel') + raise ValueError("Invalid stim_channel") else: - raise ValueError('Invalid stim_channel') + raise ValueError("Invalid stim_channel") # Forbid the synthesis of stim channels from TAL Annotations - tal_ch_names_found = [ch for ch in valid_stim_ch_names - if ch in [t.lower() for t in tal_ch_names]] + tal_ch_names_found = [ + ch for ch in valid_stim_ch_names if ch in [t.lower() for t in tal_ch_names] + ] if len(tal_ch_names_found): - _msg = ('The synthesis of the stim channel is not supported' - ' since 0.18. Please remove {} from `stim_channel`' - ' and use `mne.events_from_annotations` instead' - ).format(tal_ch_names_found) + _msg = ( + "The synthesis of the stim channel is not supported" + " since 0.18. Please remove {} from `stim_channel`" + " and use `mne.events_from_annotations` instead" + ).format(tal_ch_names_found) raise ValueError(_msg) ch_names_low = [ch.lower() for ch in ch_names] @@ -1276,8 +1439,8 @@ def _find_exclude_idx(ch_names, exclude, include=None): if include: # find other than include channels if exclude: raise ValueError( - "'exclude' must be empty if 'include' is assigned. " - f"Got {exclude}.") + "'exclude' must be empty if 'include' is assigned. " f"Got {exclude}." 
+ ) if isinstance(include, str): # regex for channel names indices_include = [] for idx, ch in enumerate(ch_names): @@ -1300,15 +1463,26 @@ def _find_exclude_idx(ch_names, exclude, include=None): def _find_tal_idx(ch_names): # Annotations / TAL Channels - accepted_tal_ch_names = ['EDF Annotations', 'BDF Annotations'] + accepted_tal_ch_names = ["EDF Annotations", "BDF Annotations"] tal_channel_idx = np.where(np.in1d(ch_names, accepted_tal_ch_names))[0] return tal_channel_idx @fill_doc -def read_raw_edf(input_fname, eog=None, misc=None, stim_channel='auto', - exclude=(), infer_types=False, include=None, preload=False, - units=None, encoding='utf8', *, verbose=None): +def read_raw_edf( + input_fname, + eog=None, + misc=None, + stim_channel="auto", + exclude=(), + infer_types=False, + include=None, + preload=False, + units=None, + encoding="utf8", + *, + verbose=None, +): """Reader function for EDF and EDF+ files. Parameters @@ -1409,18 +1583,38 @@ def read_raw_edf(input_fname, eog=None, misc=None, stim_channel='auto', """ input_fname = os.path.abspath(input_fname) ext = os.path.splitext(input_fname)[1][1:].lower() - if ext != 'edf': - raise NotImplementedError(f'Only EDF files are supported, got {ext}.') - return RawEDF(input_fname=input_fname, eog=eog, misc=misc, - stim_channel=stim_channel, exclude=exclude, - infer_types=infer_types, preload=preload, include=include, - units=units, encoding=encoding, verbose=verbose) + if ext != "edf": + raise NotImplementedError(f"Only EDF files are supported, got {ext}.") + return RawEDF( + input_fname=input_fname, + eog=eog, + misc=misc, + stim_channel=stim_channel, + exclude=exclude, + infer_types=infer_types, + preload=preload, + include=include, + units=units, + encoding=encoding, + verbose=verbose, + ) @fill_doc -def read_raw_bdf(input_fname, eog=None, misc=None, stim_channel='auto', - exclude=(), infer_types=False, include=None, preload=False, - units=None, encoding='utf8', *, verbose=None): +def read_raw_bdf( + input_fname, + eog=None, + misc=None, + stim_channel="auto", + exclude=(), + infer_types=False, + include=None, + preload=False, + units=None, + encoding="utf8", + *, + verbose=None, +): """Reader function for BDF files. Parameters @@ -1514,17 +1708,34 @@ def read_raw_bdf(input_fname, eog=None, misc=None, stim_channel='auto', """ input_fname = os.path.abspath(input_fname) ext = os.path.splitext(input_fname)[1][1:].lower() - if ext != 'bdf': - raise NotImplementedError(f'Only BDF files are supported, got {ext}.') - return RawEDF(input_fname=input_fname, eog=eog, misc=misc, - stim_channel=stim_channel, exclude=exclude, - infer_types=infer_types, preload=preload, include=include, - units=units, encoding=encoding, verbose=verbose) + if ext != "bdf": + raise NotImplementedError(f"Only BDF files are supported, got {ext}.") + return RawEDF( + input_fname=input_fname, + eog=eog, + misc=misc, + stim_channel=stim_channel, + exclude=exclude, + infer_types=infer_types, + preload=preload, + include=include, + units=units, + encoding=encoding, + verbose=verbose, + ) @fill_doc -def read_raw_gdf(input_fname, eog=None, misc=None, stim_channel='auto', - exclude=(), include=None, preload=False, verbose=None): +def read_raw_gdf( + input_fname, + eog=None, + misc=None, + stim_channel="auto", + exclude=(), + include=None, + preload=False, + verbose=None, +): """Reader function for GDF files. 
Parameters @@ -1574,15 +1785,22 @@ def read_raw_gdf(input_fname, eog=None, misc=None, stim_channel='auto', """ input_fname = os.path.abspath(input_fname) ext = os.path.splitext(input_fname)[1][1:].lower() - if ext != 'gdf': - raise NotImplementedError(f'Only BDF files are supported, got {ext}.') - return RawGDF(input_fname=input_fname, eog=eog, misc=misc, - stim_channel=stim_channel, exclude=exclude, preload=preload, - include=include, verbose=verbose) + if ext != "gdf": + raise NotImplementedError(f"Only BDF files are supported, got {ext}.") + return RawGDF( + input_fname=input_fname, + eog=eog, + misc=misc, + stim_channel=stim_channel, + exclude=exclude, + preload=preload, + include=include, + verbose=verbose, + ) @fill_doc -def _read_annotations_edf(annotations, encoding='utf8'): +def _read_annotations_edf(annotations, encoding="utf8"): """Annotation File Reader. Parameters @@ -1602,7 +1820,7 @@ def _read_annotations_edf(annotations, encoding='utf8'): string, all the annotations are given the same description. To reject epochs, use description starting with keyword 'bad'. See example above. """ - pat = '([+-]\\d+\\.?\\d*)(\x15(\\d+\\.?\\d*))?(\x14.*?)\x14\x00' + pat = "([+-]\\d+\\.?\\d*)(\x15(\\d+\\.?\\d*))?(\x14.*?)\x14\x00" if isinstance(annotations, str): with open(annotations, "rb") as annot_file: triggers = re.findall(pat.encode(), annot_file.read()) @@ -1623,8 +1841,7 @@ def _read_annotations_edf(annotations, encoding='utf8'): else: this_chan = chan.astype(np.int64) # Exploit np vectorized processing - tals.extend(np.uint8([this_chan % 256, this_chan // 256]) - .flatten('F')) + tals.extend(np.uint8([this_chan % 256, this_chan // 256]).flatten("F")) try: triggers = re.findall(pat, tals.decode(encoding)) except UnicodeDecodeError as e: @@ -1634,11 +1851,11 @@ def _read_annotations_edf(annotations, encoding='utf8'): ) from e events = [] - offset = 0. + offset = 0.0 for k, ev in enumerate(triggers): onset = float(ev[0]) + offset duration = float(ev[2]) if ev[2] else 0 - for description in ev[3].split('\x14')[1:]: + for description in ev[3].split("\x14")[1:]: if description: events.append([onset, duration, description]) elif k == 0: @@ -1657,7 +1874,7 @@ def _read_annotations_edf(annotations, encoding='utf8'): def _get_annotations_gdf(edf_info, sfreq): onset, duration, desc = list(), list(), list() - events = edf_info.get('events', None) + events = edf_info.get("events", None) # Annotations in GDF: events are stored as the following # list: `events = [n_events, pos, typ, chn, dur]` where pos is the # latency, dur is the duration in samples. 
They both are diff --git a/mne/io/edf/tests/test_edf.py b/mne/io/edf/tests/test_edf.py index 597cae5eee1..f09c45ee419 100644 --- a/mne/io/edf/tests/test_edf.py +++ b/mne/io/edf/tests/test_edf.py @@ -12,8 +12,12 @@ from pathlib import Path import numpy as np -from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_equal, assert_allclose) +from numpy.testing import ( + assert_array_almost_equal, + assert_array_equal, + assert_equal, + assert_allclose, +) from scipy.io import loadmat import pytest @@ -24,9 +28,14 @@ from mne.utils import requires_pandas, _record_warnings from mne.io import read_raw_edf, read_raw_bdf, read_raw_fif, edf, read_raw_gdf from mne.io.tests.test_raw import _test_raw_reader -from mne.io.edf.edf import (_read_annotations_edf, - _read_ch, _parse_prefilter_string, _edf_str, - _read_edf_header, _read_header) +from mne.io.edf.edf import ( + _read_annotations_edf, + _read_ch, + _parse_prefilter_string, + _edf_str, + _read_edf_header, + _read_header, +) from mne.io.pick import channel_indices_by_type, get_channel_type_constants from mne.tests.test_annotations import _assert_annotations_equal @@ -46,23 +55,19 @@ data_path = testing.data_path(download=False) edf_stim_resamp_path = data_path / "EDF" / "test_edf_stim_resamp.edf" -edf_overlap_annot_path = ( - data_path / "EDF" / "test_edf_overlapping_annotations.edf" -) +edf_overlap_annot_path = data_path / "EDF" / "test_edf_overlapping_annotations.edf" edf_reduced = data_path / "EDF" / "test_reduced.edf" edf_annot_only = data_path / "EDF" / "SC4001EC-Hypnogram.edf" bdf_stim_channel_path = data_path / "BDF" / "test_bdf_stim_channel.bdf" -bdf_multiple_annotations_path = ( - data_path / "BDF" / "multiple_annotation_chans.bdf" -) +bdf_multiple_annotations_path = data_path / "BDF" / "multiple_annotation_chans.bdf" test_generator_bdf = data_path / "BDF" / "test_generator_2.bdf" test_generator_edf = data_path / "EDF" / "test_generator_2.edf" edf_annot_sub_s_path = data_path / "EDF" / "subsecond_starttime.edf" edf_chtypes_path = data_path / "EDF" / "chtypes_edf.edf" edf_utf8_annotations = data_path / "EDF" / "test_utf8_annotations.edf" -eog = ['REOG', 'LEOG', 'IEOG'] -misc = ['EXG1', 'EXG5', 'EXG8', 'M1', 'M2'] +eog = ["REOG", "LEOG", "IEOG"] +misc = ["EXG1", "EXG5", "EXG8", "M1", "M2"] def test_orig_units(): @@ -72,16 +77,16 @@ def test_orig_units(): # Test original units orig_units = raw._orig_units assert len(orig_units) == len(raw.ch_names) - assert orig_units['A1'] == 'µV' # formerly 'uV' edit by _check_orig_units + assert orig_units["A1"] == "µV" # formerly 'uV' edit by _check_orig_units del orig_units - raw.rename_channels(dict(A1='AA')) - assert raw._orig_units['AA'] == 'µV' - raw.rename_channels(dict(AA='A1')) + raw.rename_channels(dict(A1="AA")) + assert raw._orig_units["AA"] == "µV" + raw.rename_channels(dict(AA="A1")) raw_back = raw.copy().pick(raw.ch_names[:1]) # _pick_drop_channels - assert raw_back.ch_names == ['A1'] - assert set(raw_back._orig_units) == {'A1'} + assert raw_back.ch_names == ["A1"] + assert set(raw_back._orig_units) == {"A1"} raw_back.add_channels([raw.copy().pick(raw.ch_names[1:])]) assert raw_back.ch_names == raw.ch_names assert set(raw_back._orig_units) == set(raw.ch_names) @@ -91,41 +96,41 @@ def test_orig_units(): def test_units_params(): """Test enforcing original channel units.""" - with pytest.raises(ValueError, - match=r"Unit for channel .* is present .* cannot " - "overwrite it"): - _ = read_raw_edf(edf_path, units='V', preload=True) + with pytest.raises( + ValueError, 
match=r"Unit for channel .* is present .* cannot " "overwrite it" + ): + _ = read_raw_edf(edf_path, units="V", preload=True) def test_edf_temperature(monkeypatch): """Test that we can parse temperature channel type.""" raw = read_raw_edf(edf_path) - assert raw.get_channel_types()[0] == 'eeg' + assert raw.get_channel_types()[0] == "eeg" def _first_chan_temp(*args, **kwargs): out, orig_units = _read_edf_header(*args, **kwargs) - out['ch_types'][0] = 'TEMP' + out["ch_types"][0] = "TEMP" return out, orig_units - monkeypatch.setattr(edf.edf, '_read_edf_header', _first_chan_temp) + monkeypatch.setattr(edf.edf, "_read_edf_header", _first_chan_temp) raw = read_raw_edf(edf_path) - assert 'temperature' in raw - assert raw.get_channel_types()[0] == 'temperature' + assert "temperature" in raw + assert raw.get_channel_types()[0] == "temperature" def test_subject_info(tmp_path): """Test exposure of original channel units.""" raw = read_raw_edf(edf_path) - assert raw.info['subject_info'] is None # XXX this is arguably a bug + assert raw.info["subject_info"] is None # XXX this is arguably a bug edf_info = raw._raw_extras[0] - assert edf_info['subject_info'] is not None - want = {'id': 'X', 'sex': 'X', 'birthday': 'X', 'name': 'X'} + assert edf_info["subject_info"] is not None + want = {"id": "X", "sex": "X", "birthday": "X", "name": "X"} for key, val in want.items(): - assert edf_info['subject_info'][key] == val, key - fname = tmp_path / 'test_raw.fif' + assert edf_info["subject_info"][key] == val, key + fname = tmp_path / "test_raw.fif" raw.save(fname) raw = read_raw_fif(fname) - assert raw.info['subject_info'] is None # XXX should eventually round-trip + assert raw.info["subject_info"] is None # XXX should eventually round-trip def test_bdf_data(): @@ -133,83 +138,97 @@ def test_bdf_data(): # XXX BDF data for these is around 0.01 when it should be in the uV range, # probably some bug test_scaling = False - raw_py = _test_raw_reader(read_raw_bdf, input_fname=bdf_path, - eog=eog, misc=misc, - exclude=['M2', 'IEOG'], - test_scaling=test_scaling, - ) + raw_py = _test_raw_reader( + read_raw_bdf, + input_fname=bdf_path, + eog=eog, + misc=misc, + exclude=["M2", "IEOG"], + test_scaling=test_scaling, + ) assert len(raw_py.ch_names) == 71 - raw_py = _test_raw_reader(read_raw_bdf, input_fname=bdf_path, - montage='biosemi64', eog=eog, misc=misc, - exclude=['M2', 'IEOG'], - test_scaling=test_scaling) + raw_py = _test_raw_reader( + read_raw_bdf, + input_fname=bdf_path, + montage="biosemi64", + eog=eog, + misc=misc, + exclude=["M2", "IEOG"], + test_scaling=test_scaling, + ) assert len(raw_py.ch_names) == 71 - assert 'RawEDF' in repr(raw_py) - picks = pick_types(raw_py.info, meg=False, eeg=True, exclude='bads') + assert "RawEDF" in repr(raw_py) + picks = pick_types(raw_py.info, meg=False, eeg=True, exclude="bads") data_py, _ = raw_py[picks] # this .mat was generated using the EEG Lab Biosemi Reader raw_eeglab = loadmat(bdf_eeglab_path) - raw_eeglab = raw_eeglab['data'] * 1e-6 # data are stored in microvolts + raw_eeglab = raw_eeglab["data"] * 1e-6 # data are stored in microvolts data_eeglab = raw_eeglab[picks] # bdf saved as a single, resolution to seven decimal points in matlab assert_array_almost_equal(data_py, data_eeglab, 8) # Manually checking that float coordinates are imported - assert (raw_py.info['chs'][0]['loc']).any() - assert (raw_py.info['chs'][25]['loc']).any() - assert (raw_py.info['chs'][63]['loc']).any() + assert (raw_py.info["chs"][0]["loc"]).any() + assert (raw_py.info["chs"][25]["loc"]).any() + 
assert (raw_py.info["chs"][63]["loc"]).any() @testing.requires_testing_data def test_bdf_crop_save_stim_channel(tmp_path): """Test EDF with various sampling rates.""" raw = read_raw_bdf(bdf_stim_channel_path) - raw.save(tmp_path / 'test-raw.fif', tmin=1.2, tmax=4.0, overwrite=True) + raw.save(tmp_path / "test-raw.fif", tmin=1.2, tmax=4.0, overwrite=True) @testing.requires_testing_data -@pytest.mark.parametrize('fname', [ - edf_reduced, - edf_overlap_annot_path, -]) -@pytest.mark.parametrize('stim_channel', (None, False, 'auto')) +@pytest.mark.parametrize( + "fname", + [ + edf_reduced, + edf_overlap_annot_path, + ], +) +@pytest.mark.parametrize("stim_channel", (None, False, "auto")) def test_edf_others(fname, stim_channel): """Test EDF with various sampling rates and overlapping annotations.""" _test_raw_reader( - read_raw_edf, input_fname=fname, stim_channel=stim_channel, - verbose='error') + read_raw_edf, input_fname=fname, stim_channel=stim_channel, verbose="error" + ) def test_edf_data_broken(tmp_path): """Test edf files.""" - raw = _test_raw_reader(read_raw_edf, input_fname=edf_path, - exclude=['Ergo-Left', 'H10'], verbose='error') + raw = _test_raw_reader( + read_raw_edf, + input_fname=edf_path, + exclude=["Ergo-Left", "H10"], + verbose="error", + ) raw_py = read_raw_edf(edf_path) data = raw_py.get_data() assert_equal(len(raw.ch_names) + 2, len(raw_py.ch_names)) # Test with number of records not in header (-1). broken_fname = tmp_path / "broken.edf" - with open(edf_path, 'rb') as fid_in: + with open(edf_path, "rb") as fid_in: fid_in.seek(0, 2) n_bytes = fid_in.tell() fid_in.seek(0, 0) rbytes = fid_in.read() - with open(broken_fname, 'wb') as fid_out: + with open(broken_fname, "wb") as fid_out: fid_out.write(rbytes[:236]) - fid_out.write(b'-1 ') - fid_out.write(rbytes[244:244 + int(n_bytes * 0.4)]) - with pytest.warns(RuntimeWarning, - match='records .* not match the file size'): + fid_out.write(b"-1 ") + fid_out.write(rbytes[244 : 244 + int(n_bytes * 0.4)]) + with pytest.warns(RuntimeWarning, match="records .* not match the file size"): raw = read_raw_edf(broken_fname, preload=True) read_raw_edf(broken_fname, exclude=raw.ch_names[:132], preload=True) # Test with \x00's in the data - with open(broken_fname, 'wb') as fid_out: + with open(broken_fname, "wb") as fid_out: fid_out.write(rbytes[:184]) - assert rbytes[184:192] == b'36096 ' - fid_out.write(rbytes[184:192].replace(b' ', b'\x00')) + assert rbytes[184:192] == b"36096 " + fid_out.write(rbytes[184:192].replace(b" ", b"\x00")) fid_out.write(rbytes[192:]) raw_py = read_raw_edf(broken_fname) data_new = raw_py.get_data() @@ -218,8 +237,8 @@ def test_edf_data_broken(tmp_path): def test_duplicate_channel_labels_edf(): """Test reading edf file with duplicate channel names.""" - EXPECTED_CHANNEL_NAMES = ['EEG F1-Ref-0', 'EEG F2-Ref', 'EEG F1-Ref-1'] - with pytest.warns(RuntimeWarning, match='Channel names are not unique'): + EXPECTED_CHANNEL_NAMES = ["EEG F1-Ref-0", "EEG F2-Ref", "EEG F1-Ref-1"] + with pytest.warns(RuntimeWarning, match="Channel names are not unique"): raw = read_raw_edf(duplicate_channel_labels_path, preload=False) assert raw.ch_names == EXPECTED_CHANNEL_NAMES @@ -228,31 +247,44 @@ def test_duplicate_channel_labels_edf(): def test_parse_annotation(tmp_path): """Test parsing the tal channel.""" # test the parser - annot = (b'+180\x14Lights off\x14Close door\x14\x00\x00\x00\x00\x00' - b'+180\x14Lights off\x14\x00\x00\x00\x00\x00\x00\x00\x00' - b'+180\x14Close door\x14\x00\x00\x00\x00\x00\x00\x00\x00' - 
b'+3.14\x1504.20\x14nothing\x14\x00\x00\x00\x00' - b'+1800.2\x1525.5\x14Apnea\x14\x00\x00\x00\x00\x00\x00\x00' - b'+123\x14\x14\x00\x00\x00\x00\x00\x00\x00') - annot_file = tmp_path / 'annotations.txt' + annot = ( + b"+180\x14Lights off\x14Close door\x14\x00\x00\x00\x00\x00" + b"+180\x14Lights off\x14\x00\x00\x00\x00\x00\x00\x00\x00" + b"+180\x14Close door\x14\x00\x00\x00\x00\x00\x00\x00\x00" + b"+3.14\x1504.20\x14nothing\x14\x00\x00\x00\x00" + b"+1800.2\x1525.5\x14Apnea\x14\x00\x00\x00\x00\x00\x00\x00" + b"+123\x14\x14\x00\x00\x00\x00\x00\x00\x00" + ) + annot_file = tmp_path / "annotations.txt" with open(annot_file, "wb") as f: f.write(annot) annot = [a for a in bytes(annot)] annot[1::2] = [a * 256 for a in annot[1::2]] - tal_channel_A = np.array(list(map(sum, zip(annot[0::2], annot[1::2]))), - dtype=np.int64) + tal_channel_A = np.array( + list(map(sum, zip(annot[0::2], annot[1::2]))), dtype=np.int64 + ) - with open(annot_file, 'rb') as fid: + with open(annot_file, "rb") as fid: # ch_data = np.fromfile(fid, dtype=' 0: return ll * scale_units @@ -93,14 +104,12 @@ def _eeg_has_montage_information(eeg): if not len(eeg.chanlocs): has_pos = False else: - pos_fields = ['X', 'Y', 'Z'] + pos_fields = ["X", "Y", "Z"] if isinstance(eeg.chanlocs[0], mat_struct): - has_pos = all(hasattr(eeg.chanlocs[0], fld) - for fld in pos_fields) + has_pos = all(hasattr(eeg.chanlocs[0], fld) for fld in pos_fields) elif isinstance(eeg.chanlocs[0], np.ndarray): # Old files - has_pos = all(fld in eeg.chanlocs[0].dtype.names - for fld in pos_fields) + has_pos = all(fld in eeg.chanlocs[0].dtype.names for fld in pos_fields) elif isinstance(eeg.chanlocs[0], dict): # new files has_pos = all(fld in eeg.chanlocs[0] for fld in pos_fields) @@ -110,54 +119,60 @@ def _eeg_has_montage_information(eeg): return has_pos -def _get_montage_information(eeg, get_pos, scale_units=1.): +def _get_montage_information(eeg, get_pos, scale_units=1.0): """Get channel name, type and montage information from ['chanlocs'].""" ch_names, ch_types, pos_ch_names, pos = list(), list(), list(), list() unknown_types = dict() for chanloc in eeg.chanlocs: # channel name - ch_names.append(chanloc['labels']) + ch_names.append(chanloc["labels"]) # channel type - ch_type = 'eeg' - try_type = chanloc.get('type', None) + ch_type = "eeg" + try_type = chanloc.get("type", None) if isinstance(try_type, str): try_type = try_type.strip().lower() if try_type in _PICK_TYPES_KEYS: ch_type = try_type else: if try_type in unknown_types: - unknown_types[try_type].append(chanloc['labels']) + unknown_types[try_type].append(chanloc["labels"]) else: - unknown_types[try_type] = [chanloc['labels']] + unknown_types[try_type] = [chanloc["labels"]] ch_types.append(ch_type) # channel loc if get_pos: - loc_x = _to_loc(chanloc['X'], scale_units=scale_units) - loc_y = _to_loc(chanloc['Y'], scale_units=scale_units) - loc_z = _to_loc(chanloc['Z'], scale_units=scale_units) + loc_x = _to_loc(chanloc["X"], scale_units=scale_units) + loc_y = _to_loc(chanloc["Y"], scale_units=scale_units) + loc_z = _to_loc(chanloc["Z"], scale_units=scale_units) locs = np.r_[-loc_y, loc_x, loc_z] - pos_ch_names.append(chanloc['labels']) + pos_ch_names.append(chanloc["labels"]) pos.append(locs) # warn if unknown types were provided if len(unknown_types): - warn('Unknown types found, setting as type EEG:\n' + - '\n'.join([f'{key}: {sorted(unknown_types[key])}' - for key in sorted(unknown_types)])) + warn( + "Unknown types found, setting as type EEG:\n" + + "\n".join( + [ + f"{key}: {sorted(unknown_types[key])}" 
+ for key in sorted(unknown_types) + ] + ) + ) lpa, rpa, nasion = None, None, None - if hasattr(eeg, "chaninfo") and len(eeg.chaninfo.get('nodatchans', [])): - for item in list(zip(*eeg.chaninfo['nodatchans'].values())): - d = dict(zip(eeg.chaninfo['nodatchans'].keys(), item)) - if d.get("type", None) != 'FID': + if hasattr(eeg, "chaninfo") and len(eeg.chaninfo.get("nodatchans", [])): + for item in list(zip(*eeg.chaninfo["nodatchans"].values())): + d = dict(zip(eeg.chaninfo["nodatchans"].keys(), item)) + if d.get("type", None) != "FID": continue - elif d.get('description', None) == 'Nasion': + elif d.get("description", None) == "Nasion": nasion = np.array([d["X"], d["Y"], d["Z"]]) - elif d.get('description', None) == 'Right periauricular point': + elif d.get("description", None) == "Right periauricular point": rpa = np.array([d["X"], d["Y"], d["Z"]]) - elif d.get('description', None) == 'Left periauricular point': + elif d.get("description", None) == "Left periauricular point": lpa = np.array([d["X"], d["Y"], d["Z"]]) if pos_ch_names: @@ -166,17 +181,21 @@ def _get_montage_information(eeg, get_pos, scale_units=1.): # roughly estimate head radius and check if its reasonable is_nan_pos = np.isnan(pos).all(axis=1) if not is_nan_pos.all(): - mean_radius = np.mean(np.linalg.norm( - pos_array[~is_nan_pos], axis=1)) + mean_radius = np.mean(np.linalg.norm(pos_array[~is_nan_pos], axis=1)) additional_info = ( - ' Check if the montage_units argument is correct (the default ' + " Check if the montage_units argument is correct (the default " 'is "mm", but your channel positions may be in different units' - ').') + ")." + ) _check_head_radius(mean_radius, add_info=additional_info) montage = make_dig_montage( ch_pos=dict(zip(ch_names, pos_array)), - coord_frame='head', lpa=lpa, rpa=rpa, nasion=nasion) + coord_frame="head", + lpa=lpa, + rpa=rpa, + nasion=nasion, + ) _ensure_fiducials_head(montage.dig) else: montage = None @@ -184,7 +203,7 @@ def _get_montage_information(eeg, get_pos, scale_units=1.): return ch_names, ch_types, montage -def _get_info(eeg, eog=(), scale_units=1.): +def _get_info(eeg, eog=(), scale_units=1.0): """Get measurement info.""" # add the ch_names and info['chs'][idx]['loc'] if not isinstance(eeg.chanlocs, np.ndarray) and eeg.nbchan == 1: @@ -197,23 +216,24 @@ def _get_info(eeg, eog=(), scale_units=1.): if eeg_has_ch_names_info: has_pos = _eeg_has_montage_information(eeg) - ch_names, ch_types, eeg_montage = \ - _get_montage_information(eeg, has_pos, scale_units=scale_units) + ch_names, ch_types, eeg_montage = _get_montage_information( + eeg, has_pos, scale_units=scale_units + ) update_ch_names = False else: # if eeg.chanlocs is empty, we still need default chan names ch_names = ["EEG %03d" % ii for ii in range(eeg.nbchan)] - ch_types = 'eeg' + ch_types = "eeg" eeg_montage = None update_ch_names = True info = create_info(ch_names, sfreq=eeg.srate, ch_types=ch_types) - eog = _find_channels(ch_names, ch_type='EOG') if eog == 'auto' else eog - for idx, ch in enumerate(info['chs']): - ch['cal'] = CAL - if ch['ch_name'] in eog or idx in eog: - ch['coil_type'] = FIFF.FIFFV_COIL_NONE - ch['kind'] = FIFF.FIFFV_EOG_CH + eog = _find_channels(ch_names, ch_type="EOG") if eog == "auto" else eog + for idx, ch in enumerate(info["chs"]): + ch["cal"] = CAL + if ch["ch_name"] in eog or idx in eog: + ch["coil_type"] = FIFF.FIFFV_COIL_NONE + ch["kind"] = FIFF.FIFFV_EOG_CH return info, eeg_montage, update_ch_names @@ -228,29 +248,34 @@ def _set_dig_montage_in_init(self, montage): self.set_montage(None) 
else: missing_channels = set(self.ch_names) - set(montage.ch_names) - ch_pos = dict(zip( - list(missing_channels), - np.full((len(missing_channels), 3), np.nan) - )) - self.set_montage( - montage + make_dig_montage(ch_pos=ch_pos, coord_frame='head') + ch_pos = dict( + zip(list(missing_channels), np.full((len(missing_channels), 3), np.nan)) ) + self.set_montage(montage + make_dig_montage(ch_pos=ch_pos, coord_frame="head")) def _handle_montage_units(montage_units): n_char_unit = len(montage_units) - if montage_units[-1:] != 'm' or n_char_unit > 2: - raise ValueError('``montage_units`` has to be in prefix + "m" format' - f', got "{montage_units}"') + if montage_units[-1:] != "m" or n_char_unit > 2: + raise ValueError( + '``montage_units`` has to be in prefix + "m" format' + f', got "{montage_units}"' + ) prefix = montage_units[:-1] - scale_units = 1 / DEFAULTS['prefixes'][prefix] + scale_units = 1 / DEFAULTS["prefixes"][prefix] return scale_units @fill_doc -def read_raw_eeglab(input_fname, eog=(), preload=False, - uint16_codec=None, montage_units='mm', verbose=None): +def read_raw_eeglab( + input_fname, + eog=(), + preload=False, + uint16_codec=None, + montage_units="mm", + verbose=None, +): r"""Read an EEGLAB .set file. Parameters @@ -283,15 +308,27 @@ def read_raw_eeglab(input_fname, eog=(), preload=False, ----- .. versionadded:: 0.11.0 """ - return RawEEGLAB(input_fname=input_fname, preload=preload, - eog=eog, uint16_codec=uint16_codec, - montage_units=montage_units, verbose=verbose) + return RawEEGLAB( + input_fname=input_fname, + preload=preload, + eog=eog, + uint16_codec=uint16_codec, + montage_units=montage_units, + verbose=verbose, + ) @fill_doc -def read_epochs_eeglab(input_fname, events=None, event_id=None, - eog=(), *, uint16_codec=None, montage_units='mm', - verbose=None): +def read_epochs_eeglab( + input_fname, + events=None, + event_id=None, + eog=(), + *, + uint16_codec=None, + montage_units="mm", + verbose=None, +): r"""Reader function for EEGLAB epochs files. Parameters @@ -337,9 +374,15 @@ def read_epochs_eeglab(input_fname, events=None, event_id=None, ----- .. versionadded:: 0.11.0 """ - epochs = EpochsEEGLAB(input_fname=input_fname, events=events, eog=eog, - event_id=event_id, uint16_codec=uint16_codec, - montage_units=montage_units, verbose=verbose) + epochs = EpochsEEGLAB( + input_fname=input_fname, + events=events, + eog=eog, + event_id=event_id, + uint16_codec=uint16_codec, + montage_units=montage_units, + verbose=verbose, + ) return epochs @@ -373,17 +416,24 @@ class RawEEGLAB(BaseRaw): """ @verbose - def __init__(self, input_fname, eog=(), - preload=False, *, uint16_codec=None, montage_units='mm', - verbose=None): # noqa: D102 - input_fname = str( - _check_fname(input_fname, "read", True, "input_fname") - ) + def __init__( + self, + input_fname, + eog=(), + preload=False, + *, + uint16_codec=None, + montage_units="mm", + verbose=None, + ): # noqa: D102 + input_fname = str(_check_fname(input_fname, "read", True, "input_fname")) eeg = _check_load_mat(input_fname, uint16_codec) if eeg.trials != 1: - raise TypeError('The number of trials is %d. It must be 1 for raw' - ' files. Please use `mne.io.read_epochs_eeglab` if' - ' the .set file contains epochs.' % eeg.trials) + raise TypeError( + "The number of trials is %d. It must be 1 for raw" + " files. Please use `mne.io.read_epochs_eeglab` if" + " the .set file contains epochs." 
% eeg.trials + ) last_samps = [eeg.pnts - 1] scale_units = _handle_montage_units(montage_units) @@ -392,16 +442,23 @@ def __init__(self, input_fname, eog=(), # read the data if isinstance(eeg.data, str): data_fname = _check_eeglab_fname(input_fname, eeg.data) - logger.info('Reading %s' % data_fname) + logger.info("Reading %s" % data_fname) super(RawEEGLAB, self).__init__( - info, preload, filenames=[data_fname], last_samps=last_samps, - orig_format='double', verbose=verbose) + info, + preload, + filenames=[data_fname], + last_samps=last_samps, + orig_format="double", + verbose=verbose, + ) else: if preload is False or isinstance(preload, str): - warn('Data will be preloaded. preload=False or a string ' - 'preload is not supported when the data is stored in ' - 'the .set file') + warn( + "Data will be preloaded. preload=False or a string " + "preload is not supported when the data is stored in " + "the .set file" + ) # can't be done in standard way with preload=True because of # different reading path (.set file) if eeg.nbchan == 1 and len(eeg.data.shape) == 1: @@ -412,8 +469,13 @@ def __init__(self, input_fname, eog=(), data[:n_chan] = eeg.data data *= CAL super(RawEEGLAB, self).__init__( - info, data, filenames=[input_fname], last_samps=last_samps, - orig_format='double', verbose=verbose) + info, + data, + filenames=[input_fname], + last_samps=last_samps, + orig_format="double", + verbose=verbose, + ) # create event_ch from annotations annot = read_annotations(input_fname) @@ -422,13 +484,12 @@ def __init__(self, input_fname, eog=(), _set_dig_montage_in_init(self, eeg_montage) - latencies = np.round(annot.onset * self.info['sfreq']) + latencies = np.round(annot.onset * self.info["sfreq"]) _check_latencies(latencies) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of raw data.""" - _read_segments_file( - self, data, idx, fi, start, stop, cals, mult, dtype=' 1: # first extract the events and construct an event_id dict @@ -533,7 +609,7 @@ def __init__(self, input_fname, events=None, event_id=None, tmin=0, if isinstance(ep.eventtype, (int, float)): ep.eventtype = str(ep.eventtype) if not isinstance(ep.eventtype, str): - event_type = '/'.join([str(et) for et in ep.eventtype]) + event_type = "/".join([str(et) for et in ep.eventtype]) event_name.append(event_type) # store latency of only first event event_latencies.append(events[ev_idx].latency) @@ -553,16 +629,17 @@ def __init__(self, input_fname, events=None, event_id=None, tmin=0, # warn about multiple events in epoch if necessary if warn_multiple_events: - warn('At least one epoch has multiple events. Only the latency' - ' of the first event will be retained.') + warn( + "At least one epoch has multiple events. Only the latency" + " of the first event will be retained." + ) # now fill up the event array events = np.zeros((eeg.trials, 3), dtype=int) for idx in range(0, eeg.trials): if idx == 0: prev_stim = 0 - elif (idx > 0 and - event_latencies[idx] - event_latencies[idx - 1] == 1): + elif idx > 0 and event_latencies[idx] - event_latencies[idx - 1] == 1: prev_stim = event_id[event_name[idx - 1]] events[idx, 0] = event_latencies[idx] events[idx, 1] = prev_stim @@ -570,64 +647,81 @@ def __init__(self, input_fname, events=None, event_id=None, tmin=0, elif isinstance(events, (str, Path, PathLike)): events = read_events(events) - logger.info('Extracting parameters from %s...' % input_fname) + logger.info("Extracting parameters from %s..." 
% input_fname) scale_units = _handle_montage_units(montage_units) info, eeg_montage, _ = _get_info(eeg, eog=eog, scale_units=scale_units) for key, val in event_id.items(): if val not in events[:, 2]: - raise ValueError('No matching events found for %s ' - '(event id %i)' % (key, val)) + raise ValueError( + "No matching events found for %s " "(event id %i)" % (key, val) + ) if isinstance(eeg.data, str): data_fname = _check_eeglab_fname(input_fname, eeg.data) - with open(data_fname, 'rb') as data_fid: + with open(data_fname, "rb") as data_fid: data = np.fromfile(data_fid, dtype=np.float32) - data = data.reshape((eeg.nbchan, eeg.pnts, eeg.trials), - order="F") + data = data.reshape((eeg.nbchan, eeg.pnts, eeg.trials), order="F") else: data = eeg.data if eeg.nbchan == 1 and len(data.shape) == 2: data = data[np.newaxis, :] - data = data.transpose((2, 0, 1)).astype('double') + data = data.transpose((2, 0, 1)).astype("double") data *= CAL assert data.shape == (eeg.trials, eeg.nbchan, eeg.pnts) tmin, tmax = eeg.xmin, eeg.xmax super(EpochsEEGLAB, self).__init__( - info, data, events, event_id, tmin, tmax, baseline, - reject=reject, flat=flat, reject_tmin=reject_tmin, - reject_tmax=reject_tmax, filename=input_fname, verbose=verbose) + info, + data, + events, + event_id, + tmin, + tmax, + baseline, + reject=reject, + flat=flat, + reject_tmin=reject_tmin, + reject_tmax=reject_tmax, + filename=input_fname, + verbose=verbose, + ) # data are preloaded but _bad_dropped is not set so we do it here: self._bad_dropped = True _set_dig_montage_in_init(self, eeg_montage) - logger.info('Ready.') + logger.info("Ready.") def _check_boundary(annot, event_id): if event_id is None: event_id = dict() if "boundary" in annot.description and "boundary" not in event_id: - warn("The data contains 'boundary' events, indicating data " - "discontinuities. Be cautious of filtering and epoching around " - "these events.") + warn( + "The data contains 'boundary' events, indicating data " + "discontinuities. Be cautious of filtering and epoching around " + "these events." + ) def _check_latencies(latencies): if (latencies < -1).any(): - raise ValueError('At least one event sample index is negative. Please' - ' check if EEG.event.sample values are correct.') + raise ValueError( + "At least one event sample index is negative. Please" + " check if EEG.event.sample values are correct." + ) if (latencies == -1).any(): - warn("At least one event has a sample index of -1. This usually is " - "a consequence of how eeglab handles event latency after " - "resampling - especially when you had a boundary event at the " - "beginning of the file. Please make sure that the events at " - "the very beginning of your EEGLAB file can be safely dropped " - "(e.g., because they are boundary events).") + warn( + "At least one event has a sample index of -1. This usually is " + "a consequence of how eeglab handles event latency after " + "resampling - especially when you had a boundary event at the " + "beginning of the file. Please make sure that the events at " + "the very beginning of your EEGLAB file can be safely dropped " + "(e.g., because they are boundary events)." 
+ ) def _bunchify(items): @@ -663,10 +757,9 @@ def _read_annotations_eeglab(eeg, uint16_codec=None): if isinstance(eeg, str): eeg = _check_load_mat(eeg, uint16_codec=uint16_codec) - if not hasattr(eeg, 'event'): + if not hasattr(eeg, "event"): events = [] - elif isinstance(eeg.event, dict) and \ - np.array(eeg.event['latency']).ndim > 0: + elif isinstance(eeg.event, dict) and np.array(eeg.event["latency"]).ndim > 0: events = _dol_to_lod(eeg.event) elif not isinstance(eeg.event, (np.ndarray, list)): events = [eeg.event] @@ -676,20 +769,25 @@ def _read_annotations_eeglab(eeg, uint16_codec=None): description = [str(event.type) for event in events] onset = [event.latency - 1 for event in events] duration = np.zeros(len(onset)) - if len(events) > 0 and hasattr(events[0], 'duration'): + if len(events) > 0 and hasattr(events[0], "duration"): for idx, event in enumerate(events): # empty duration fields are read as empty arrays - is_empty_array = (isinstance(event.duration, np.ndarray) - and len(event.duration) == 0) + is_empty_array = ( + isinstance(event.duration, np.ndarray) and len(event.duration) == 0 + ) duration[idx] = np.nan if is_empty_array else event.duration - return Annotations(onset=np.array(onset) / eeg.srate, - duration=duration / eeg.srate, - description=description, - orig_time=None) + return Annotations( + onset=np.array(onset) / eeg.srate, + duration=duration / eeg.srate, + description=description, + orig_time=None, + ) def _dol_to_lod(dol): """Convert a dict of lists to a list of dicts.""" - return [{key: dol[key][ii] for key in dol.keys()} - for ii in range(len(dol[list(dol.keys())[0]]))] + return [ + {key: dol[key][ii] for key in dol.keys()} + for ii in range(len(dol[list(dol.keys())[0]])) + ] diff --git a/mne/io/eeglab/tests/test_eeglab.py b/mne/io/eeglab/tests/test_eeglab.py index 8c2a966253e..590cc4873aa 100644 --- a/mne/io/eeglab/tests/test_eeglab.py +++ b/mne/io/eeglab/tests/test_eeglab.py @@ -9,8 +9,12 @@ from copy import deepcopy import numpy as np -from numpy.testing import (assert_array_equal, assert_array_almost_equal, - assert_equal, assert_allclose) +from numpy.testing import ( + assert_array_equal, + assert_array_almost_equal, + assert_equal, + assert_allclose, +) import pytest from scipy import io @@ -46,54 +50,55 @@ @testing.requires_testing_data -@pytest.mark.parametrize('fname', [ - raw_fname_mat, - pytest.param( - raw_fname_h5, - marks=[ - pytest.mark.skipif( - not _check_pymatreader_installed(strict=False), - reason='pymatreader not installed' - ) - ] - ), - raw_fname_chanloc, -], ids=os.path.basename) +@pytest.mark.parametrize( + "fname", + [ + raw_fname_mat, + pytest.param( + raw_fname_h5, + marks=[ + pytest.mark.skipif( + not _check_pymatreader_installed(strict=False), + reason="pymatreader not installed", + ) + ], + ), + raw_fname_chanloc, + ], + ids=os.path.basename, +) def test_io_set_raw(fname): """Test importing EEGLAB .set files.""" montage = read_custom_montage(montage_path) - montage.ch_names = [ - 'EEG {0:03d}'.format(ii) for ii in range(len(montage.ch_names)) - ] + montage.ch_names = ["EEG {0:03d}".format(ii) for ii in range(len(montage.ch_names))] kws = dict(reader=read_raw_eeglab, input_fname=fname) if fname.name == "test_raw_chanloc.set": - with pytest.warns(RuntimeWarning, - match="The data contains 'boundary' events"): + with pytest.warns(RuntimeWarning, match="The data contains 'boundary' events"): raw0 = _test_raw_reader(**kws) - elif '_h5' in fname.name: # should be safe enough, and much faster + elif "_h5" in fname.name: # should be 
safe enough, and much faster raw0 = read_raw_eeglab(fname, preload=True) else: raw0 = _test_raw_reader(**kws) # test that preloading works if fname.name == "test_raw_chanloc.set": - raw0.set_montage(montage, on_missing='ignore') + raw0.set_montage(montage, on_missing="ignore") # crop to check if the data has been properly preloaded; we cannot # filter as the snippet of raw data is very short raw0.crop(0, 1) else: raw0.set_montage(montage) - raw0.filter(1, None, l_trans_bandwidth='auto', filter_length='auto', - phase='zero') + raw0.filter( + 1, None, l_trans_bandwidth="auto", filter_length="auto", phase="zero" + ) # test that using uint16_codec does not break stuff - read_raw_kws = dict(input_fname=fname, preload=False, uint16_codec='ascii') + read_raw_kws = dict(input_fname=fname, preload=False, uint16_codec="ascii") if fname.name == "test_raw_chanloc.set": - with pytest.warns(RuntimeWarning, - match="The data contains 'boundary' events"): + with pytest.warns(RuntimeWarning, match="The data contains 'boundary' events"): raw0 = read_raw_eeglab(**read_raw_kws) - raw0.set_montage(montage, on_missing='ignore') + raw0.set_montage(montage, on_missing="ignore") else: raw0 = read_raw_eeglab(**read_raw_kws) raw0.set_montage(montage) @@ -101,27 +106,36 @@ def test_io_set_raw(fname): # Annotations if fname != raw_fname_chanloc: assert len(raw0.annotations) == 154 - assert set(raw0.annotations.description) == {'rt', 'square'} - assert_array_equal(raw0.annotations.duration, 0.) + assert set(raw0.annotations.description) == {"rt", "square"} + assert_array_equal(raw0.annotations.duration, 0.0) @testing.requires_testing_data def test_io_set_raw_more(tmp_path): """Test importing EEGLAB .set files.""" - eeg = io.loadmat(raw_fname_mat, struct_as_record=False, - squeeze_me=True)['EEG'] + eeg = io.loadmat(raw_fname_mat, struct_as_record=False, squeeze_me=True)["EEG"] # test reading file with one event (read old version) negative_latency_fname = tmp_path / "test_negative_latency.set" events = deepcopy(eeg.event[0]) events.latency = 0 - io.savemat(negative_latency_fname, - {'EEG': {'trials': eeg.trials, 'srate': eeg.srate, - 'nbchan': eeg.nbchan, - 'data': 'test_negative_latency.fdt', - 'epoch': eeg.epoch, 'event': events, - 'chanlocs': eeg.chanlocs, 'pnts': eeg.pnts}}, - appendmat=False, oned_as='row') + io.savemat( + negative_latency_fname, + { + "EEG": { + "trials": eeg.trials, + "srate": eeg.srate, + "nbchan": eeg.nbchan, + "data": "test_negative_latency.fdt", + "epoch": eeg.epoch, + "event": events, + "chanlocs": eeg.chanlocs, + "pnts": eeg.pnts, + } + }, + appendmat=False, + oned_as="row", + ) shutil.copyfile( base_dir / "test_raw.fdt", negative_latency_fname.with_suffix(".fdt") ) @@ -130,92 +144,147 @@ def test_io_set_raw_more(tmp_path): # test negative event latencies events.latency = -1 - io.savemat(negative_latency_fname, - {'EEG': {'trials': eeg.trials, 'srate': eeg.srate, - 'nbchan': eeg.nbchan, - 'data': 'test_negative_latency.fdt', - 'epoch': eeg.epoch, 'event': events, - 'chanlocs': eeg.chanlocs, 'pnts': eeg.pnts}}, - appendmat=False, oned_as='row') - with pytest.raises(ValueError, match='event sample index is negative'): + io.savemat( + negative_latency_fname, + { + "EEG": { + "trials": eeg.trials, + "srate": eeg.srate, + "nbchan": eeg.nbchan, + "data": "test_negative_latency.fdt", + "epoch": eeg.epoch, + "event": events, + "chanlocs": eeg.chanlocs, + "pnts": eeg.pnts, + } + }, + appendmat=False, + oned_as="row", + ) + with pytest.raises(ValueError, match="event sample index is negative"): with 
pytest.warns(RuntimeWarning, match="has a sample index of -1."): read_raw_eeglab(input_fname=negative_latency_fname, preload=True) # test overlapping events overlap_fname = tmp_path / "test_overlap_event.set" - io.savemat(overlap_fname, - {'EEG': {'trials': eeg.trials, 'srate': eeg.srate, - 'nbchan': eeg.nbchan, 'data': 'test_overlap_event.fdt', - 'epoch': eeg.epoch, - 'event': [eeg.event[0], eeg.event[0]], - 'chanlocs': eeg.chanlocs, 'pnts': eeg.pnts}}, - appendmat=False, oned_as='row') - shutil.copyfile( - base_dir / "test_raw.fdt", overlap_fname.with_suffix(".fdt") + io.savemat( + overlap_fname, + { + "EEG": { + "trials": eeg.trials, + "srate": eeg.srate, + "nbchan": eeg.nbchan, + "data": "test_overlap_event.fdt", + "epoch": eeg.epoch, + "event": [eeg.event[0], eeg.event[0]], + "chanlocs": eeg.chanlocs, + "pnts": eeg.pnts, + } + }, + appendmat=False, + oned_as="row", ) + shutil.copyfile(base_dir / "test_raw.fdt", overlap_fname.with_suffix(".fdt")) read_raw_eeglab(input_fname=overlap_fname, preload=True) # test reading file with empty event durations empty_dur_fname = tmp_path / "test_empty_durations.set" events = deepcopy(eeg.event) for ev in events: - ev.duration = np.array([], dtype='float') - - io.savemat(empty_dur_fname, - {'EEG': {'trials': eeg.trials, 'srate': eeg.srate, - 'nbchan': eeg.nbchan, - 'data': 'test_negative_latency.fdt', - 'epoch': eeg.epoch, 'event': events, - 'chanlocs': eeg.chanlocs, 'pnts': eeg.pnts}}, - appendmat=False, oned_as='row') - shutil.copyfile( - base_dir / "test_raw.fdt", empty_dur_fname.with_suffix(".fdt") + ev.duration = np.array([], dtype="float") + + io.savemat( + empty_dur_fname, + { + "EEG": { + "trials": eeg.trials, + "srate": eeg.srate, + "nbchan": eeg.nbchan, + "data": "test_negative_latency.fdt", + "epoch": eeg.epoch, + "event": events, + "chanlocs": eeg.chanlocs, + "pnts": eeg.pnts, + } + }, + appendmat=False, + oned_as="row", ) + shutil.copyfile(base_dir / "test_raw.fdt", empty_dur_fname.with_suffix(".fdt")) raw = read_raw_eeglab(input_fname=empty_dur_fname, preload=True) assert (raw.annotations.duration == 0).all() # test reading file when the EEG.data name is wrong - io.savemat(overlap_fname, - {'EEG': {'trials': eeg.trials, 'srate': eeg.srate, - 'nbchan': eeg.nbchan, 'data': 'test_overla_event.fdt', - 'epoch': eeg.epoch, - 'event': [eeg.event[0], eeg.event[0]], - 'chanlocs': eeg.chanlocs, 'pnts': eeg.pnts}}, - appendmat=False, oned_as='row') + io.savemat( + overlap_fname, + { + "EEG": { + "trials": eeg.trials, + "srate": eeg.srate, + "nbchan": eeg.nbchan, + "data": "test_overla_event.fdt", + "epoch": eeg.epoch, + "event": [eeg.event[0], eeg.event[0]], + "chanlocs": eeg.chanlocs, + "pnts": eeg.pnts, + } + }, + appendmat=False, + oned_as="row", + ) with pytest.warns(RuntimeWarning, match="must have changed on disk"): read_raw_eeglab(input_fname=overlap_fname, preload=True) # raise error when both EEG.data and fdt name from set are wrong overlap_fname = tmp_path / "test_ovrlap_event.set" - io.savemat(overlap_fname, - {'EEG': {'trials': eeg.trials, 'srate': eeg.srate, - 'nbchan': eeg.nbchan, 'data': 'test_overla_event.fdt', - 'epoch': eeg.epoch, - 'event': [eeg.event[0], eeg.event[0]], - 'chanlocs': eeg.chanlocs, 'pnts': eeg.pnts}}, - appendmat=False, oned_as='row') + io.savemat( + overlap_fname, + { + "EEG": { + "trials": eeg.trials, + "srate": eeg.srate, + "nbchan": eeg.nbchan, + "data": "test_overla_event.fdt", + "epoch": eeg.epoch, + "event": [eeg.event[0], eeg.event[0]], + "chanlocs": eeg.chanlocs, + "pnts": eeg.pnts, + } + }, + 
appendmat=False, + oned_as="row", + ) with pytest.raises(FileNotFoundError, match="not find the .fdt data file"): read_raw_eeglab(input_fname=overlap_fname, preload=True) # test reading file with one channel one_chan_fname = tmp_path / "test_one_channel.set" - io.savemat(one_chan_fname, - {'EEG': {'trials': eeg.trials, 'srate': eeg.srate, - 'nbchan': 1, 'data': np.random.random((1, 3)), - 'epoch': eeg.epoch, 'event': eeg.epoch, - 'chanlocs': {'labels': 'E1', 'Y': -6.6069, - 'X': 6.3023, 'Z': -2.9423}, - 'times': eeg.times[:3], 'pnts': 3}}, - appendmat=False, oned_as='row') - read_raw_eeglab(input_fname=one_chan_fname, preload=True, - montage_units='cm') + io.savemat( + one_chan_fname, + { + "EEG": { + "trials": eeg.trials, + "srate": eeg.srate, + "nbchan": 1, + "data": np.random.random((1, 3)), + "epoch": eeg.epoch, + "event": eeg.epoch, + "chanlocs": {"labels": "E1", "Y": -6.6069, "X": 6.3023, "Z": -2.9423}, + "times": eeg.times[:3], + "pnts": 3, + } + }, + appendmat=False, + oned_as="row", + ) + read_raw_eeglab(input_fname=one_chan_fname, preload=True, montage_units="cm") # test reading file with 3 channels - one without position information # first, create chanlocs structured array - ch_names = ['F3', 'unknown', 'FPz'] - x, y, z = [1., 2., np.nan], [4., 5., np.nan], [7., 8., np.nan] - dt = [('labels', 'S10'), ('X', 'f8'), ('Y', 'f8'), ('Z', 'f8')] - nopos_dt = [('labels', 'S10'), ('Z', 'f8')] + ch_names = ["F3", "unknown", "FPz"] + x, y, z = [1.0, 2.0, np.nan], [4.0, 5.0, np.nan], [7.0, 8.0, np.nan] + dt = [("labels", "S10"), ("X", "f8"), ("Y", "f8"), ("Z", "f8")] + nopos_dt = [("labels", "S10"), ("Z", "f8")] chanlocs = np.zeros((3,), dtype=dt) nopos_chanlocs = np.zeros((3,), dtype=nopos_dt) for ind, vals in enumerate(zip(ch_names, x, y, z)): @@ -230,44 +299,59 @@ def test_io_set_raw_more(tmp_path): # test reading channel names but not positions when there is no X (only Z) # field in the EEG.chanlocs structure nopos_fname = tmp_path / "test_no_chanpos.set" - io.savemat(nopos_fname, - {'EEG': {'trials': eeg.trials, 'srate': eeg.srate, 'nbchan': 3, - 'data': np.random.random((3, 2)), 'epoch': eeg.epoch, - 'event': eeg.epoch, 'chanlocs': nopos_chanlocs, - 'times': eeg.times[:2], 'pnts': 2}}, - appendmat=False, oned_as='row') + io.savemat( + nopos_fname, + { + "EEG": { + "trials": eeg.trials, + "srate": eeg.srate, + "nbchan": 3, + "data": np.random.random((3, 2)), + "epoch": eeg.epoch, + "event": eeg.epoch, + "chanlocs": nopos_chanlocs, + "times": eeg.times[:2], + "pnts": 2, + } + }, + appendmat=False, + oned_as="row", + ) # load the file - raw = read_raw_eeglab(input_fname=nopos_fname, preload=True, - montage_units='cm') + raw = read_raw_eeglab(input_fname=nopos_fname, preload=True, montage_units="cm") # test that channel names have been loaded but not channel positions for i in range(3): - assert_equal(raw.info['chs'][i]['ch_name'], ch_names[i]) - assert_array_equal(raw.info['chs'][i]['loc'][:3], - np.array([np.nan, np.nan, np.nan])) + assert_equal(raw.info["chs"][i]["ch_name"], ch_names[i]) + assert_array_equal( + raw.info["chs"][i]["loc"][:3], np.array([np.nan, np.nan, np.nan]) + ) @pytest.mark.timeout(60) # ~60 s on Travis OSX @testing.requires_testing_data -@pytest.mark.parametrize('fnames', [ - epochs_mat_fnames, - pytest.param( - epochs_h5_fnames, - marks=[ - pytest.mark.slowtest, - pytest.mark.skipif( - not _check_pymatreader_installed(strict=False), - reason='pymatreader not installed' - ) - ] - ) -]) +@pytest.mark.parametrize( + "fnames", + [ + epochs_mat_fnames, + pytest.param( 
+ epochs_h5_fnames, + marks=[ + pytest.mark.slowtest, + pytest.mark.skipif( + not _check_pymatreader_installed(strict=False), + reason="pymatreader not installed", + ), + ], + ), + ], +) def test_io_set_epochs(fnames): """Test importing EEGLAB .set epochs files.""" epochs_fname, epochs_fname_onefile = fnames - with pytest.warns(RuntimeWarning, match='multiple events'): + with pytest.warns(RuntimeWarning, match="multiple events"): epochs = read_epochs_eeglab(epochs_fname) - with pytest.warns(RuntimeWarning, match='multiple events'): + with pytest.warns(RuntimeWarning, match="multiple events"): epochs2 = read_epochs_eeglab(epochs_fname_onefile) # one warning for each read_epochs_eeglab because both files have epochs # associated with multiple events @@ -280,67 +364,75 @@ def test_io_set_epochs_events(tmp_path): out_fname = tmp_path / "test-eve.fif" events = np.array([[4, 0, 1], [12, 0, 2], [20, 0, 3], [26, 0, 3]]) write_events(out_fname, events) - event_id = {'S255/S8': 1, 'S8': 2, 'S255/S9': 3} + event_id = {"S255/S8": 1, "S8": 2, "S255/S9": 3} epochs = read_epochs_eeglab(epochs_fname_mat, events, event_id) assert_equal(len(epochs.events), 4) assert epochs.preload assert epochs._bad_dropped epochs = read_epochs_eeglab(epochs_fname_mat, out_fname, event_id) - pytest.raises(ValueError, read_epochs_eeglab, epochs_fname_mat, - None, event_id) - pytest.raises(ValueError, read_epochs_eeglab, epochs_fname_mat, - epochs.events, None) + pytest.raises(ValueError, read_epochs_eeglab, epochs_fname_mat, None, event_id) + pytest.raises(ValueError, read_epochs_eeglab, epochs_fname_mat, epochs.events, None) @testing.requires_testing_data -@pytest.mark.filterwarnings('ignore:At least one epoch has multiple events') +@pytest.mark.filterwarnings("ignore:At least one epoch has multiple events") @pytest.mark.filterwarnings("ignore:The data contains 'boundary' events") def test_degenerate(tmp_path): """Test some degenerate conditions.""" # test if .dat file raises an error - eeg = io.loadmat(epochs_fname_mat, struct_as_record=False, - squeeze_me=True)['EEG'] - eeg.data = 'epochs_fname.dat' + eeg = io.loadmat(epochs_fname_mat, struct_as_record=False, squeeze_me=True)["EEG"] + eeg.data = "epochs_fname.dat" bad_epochs_fname = tmp_path / "test_epochs.set" - io.savemat(bad_epochs_fname, - {'EEG': {'trials': eeg.trials, 'srate': eeg.srate, - 'nbchan': eeg.nbchan, 'data': eeg.data, - 'epoch': eeg.epoch, 'event': eeg.event, - 'chanlocs': eeg.chanlocs, 'pnts': eeg.pnts}}, - appendmat=False, oned_as='row') - shutil.copyfile( - base_dir / "test_epochs.fdt", tmp_path / "test_epochs.dat" + io.savemat( + bad_epochs_fname, + { + "EEG": { + "trials": eeg.trials, + "srate": eeg.srate, + "nbchan": eeg.nbchan, + "data": eeg.data, + "epoch": eeg.epoch, + "event": eeg.event, + "chanlocs": eeg.chanlocs, + "pnts": eeg.pnts, + } + }, + appendmat=False, + oned_as="row", ) - pytest.raises(NotImplementedError, read_epochs_eeglab, - bad_epochs_fname) + shutil.copyfile(base_dir / "test_epochs.fdt", tmp_path / "test_epochs.dat") + pytest.raises(NotImplementedError, read_epochs_eeglab, bad_epochs_fname) # error when montage units incorrect with pytest.raises(ValueError, match=r'prefix \+ "m" format'): - read_epochs_eeglab(epochs_fname_mat, montage_units='mV') + read_epochs_eeglab(epochs_fname_mat, montage_units="mV") # warning when head radius too small - with pytest.warns(RuntimeWarning, match='is above'): - read_raw_eeglab(raw_fname_chanloc, montage_units='km') + with pytest.warns(RuntimeWarning, match="is above"): + 
read_raw_eeglab(raw_fname_chanloc, montage_units="km") # warning when head radius too large - with pytest.warns(RuntimeWarning, match='is below'): - read_raw_eeglab(raw_fname_chanloc, montage_units='µm') - - -@pytest.mark.parametrize("fname", [ - raw_fname_mat, - raw_fname_onefile_mat, - # We don't test the h5 variants here because they are implicitly tested - # in test_io_set_raw -]) -@pytest.mark.filterwarnings('ignore: Complex objects') + with pytest.warns(RuntimeWarning, match="is below"): + read_raw_eeglab(raw_fname_chanloc, montage_units="µm") + + +@pytest.mark.parametrize( + "fname", + [ + raw_fname_mat, + raw_fname_onefile_mat, + # We don't test the h5 variants here because they are implicitly tested + # in test_io_set_raw + ], +) +@pytest.mark.filterwarnings("ignore: Complex objects") @testing.requires_testing_data def test_eeglab_annotations(fname): """Test reading annotations in EEGLAB files.""" annotations = read_annotations(fname) assert len(annotations) == 154 - assert set(annotations.description) == {'rt', 'square'} - assert np.all(annotations.duration == 0.) + assert set(annotations.description) == {"rt", "square"} + assert np.all(annotations.duration == 0.0) @testing.requires_testing_data @@ -348,15 +440,30 @@ def test_eeglab_read_annotations(): """Test annotations onsets are timestamps (+ validate some).""" annotations = read_annotations(raw_fname_mat) validation_samples = [0, 1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31] - expected_onset = np.array([1.00, 1.69, 2.08, 4.70, 7.71, 11.30, 17.18, - 20.20, 26.12, 29.14, 35.25, 44.30, 47.15]) + expected_onset = np.array( + [ + 1.00, + 1.69, + 2.08, + 4.70, + 7.71, + 11.30, + 17.18, + 20.20, + 26.12, + 29.14, + 35.25, + 44.30, + 47.15, + ] + ) assert annotations.orig_time is None - assert_array_almost_equal(annotations.onset[validation_samples], - expected_onset, decimal=2) + assert_array_almost_equal( + annotations.onset[validation_samples], expected_onset, decimal=2 + ) # test if event durations are imported correctly - raw = read_raw_eeglab(raw_fname_event_duration, preload=True, - montage_units='dm') + raw = read_raw_eeglab(raw_fname_event_duration, preload=True, montage_units="dm") # file contains 3 annotations with 0.5 s (64 samples) duration each assert_allclose(raw.annotations.duration, np.ones(3) * 0.5) @@ -366,7 +473,7 @@ def test_eeglab_event_from_annot(): """Test all forms of obtaining annotations.""" raw_fname_mat = base_dir / "test_raw.set" raw_fname = raw_fname_mat - event_id = {'rt': 1, 'square': 2} + event_id = {"rt": 1, "square": 2} raw1 = read_raw_eeglab(input_fname=raw_fname, preload=False) annotations = read_annotations(raw_fname) @@ -381,7 +488,7 @@ def _assert_array_allclose_nan(left, right): assert_allclose(left[~np.isnan(left)], right[~np.isnan(left)], atol=1e-8) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def three_chanpos_fname(tmp_path_factory): """Test file with 3 channels to exercise EEGLAB reader. 
@@ -393,21 +500,29 @@ def three_chanpos_fname(tmp_path_factory): Notes from when this code was factorized: # test reading file with one event (read old version) """ - fname = str(tmp_path_factory.mktemp('data') / 'test_chanpos.set') - file_conent = dict(EEG={ - 'trials': 1, 'nbchan': 3, 'pnts': 3, 'epoch': [], 'event': [], - 'srate': 128, 'times': np.array([0., 0.1, 0.2]), - 'data': np.empty([3, 3]), - 'chanlocs': np.array( - [(b'F3', 1., 4., 7.), - (b'unknown', np.nan, np.nan, np.nan), - (b'FPz', 2., 5., 8.)], - dtype=[('labels', 'S10'), ('X', 'f8'), ('Y', 'f8'), ('Z', 'f8')] - ) - }) + fname = str(tmp_path_factory.mktemp("data") / "test_chanpos.set") + file_conent = dict( + EEG={ + "trials": 1, + "nbchan": 3, + "pnts": 3, + "epoch": [], + "event": [], + "srate": 128, + "times": np.array([0.0, 0.1, 0.2]), + "data": np.empty([3, 3]), + "chanlocs": np.array( + [ + (b"F3", 1.0, 4.0, 7.0), + (b"unknown", np.nan, np.nan, np.nan), + (b"FPz", 2.0, 5.0, 8.0), + ], + dtype=[("labels", "S10"), ("X", "f8"), ("Y", "f8"), ("Z", "f8")], + ), + } + ) - io.savemat(file_name=fname, mdict=file_conent, appendmat=False, - oned_as='row') + io.savemat(file_name=fname, mdict=file_conent, appendmat=False, oned_as="row") return fname @@ -416,22 +531,31 @@ def three_chanpos_fname(tmp_path_factory): def test_position_information(three_chanpos_fname): """Test reading file with 3 channels - one without position information.""" nan = np.nan - EXPECTED_LOCATIONS_FROM_FILE = np.array([ - [-4., 1., 7., 0., 0., 0., nan, nan, nan, nan, nan, nan], - [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], - [-5., 2., 8., 0., 0., 0., nan, nan, nan, nan, nan, nan], - ]) * 0.01 # 0.01 is to scale cm to meters - - EXPECTED_LOCATIONS_FROM_MONTAGE = np.array([ - [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], - [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], - [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], - ]) - - raw = read_raw_eeglab(input_fname=three_chanpos_fname, preload=True, - montage_units='cm') - assert_array_equal(np.array([ch['loc'] for ch in raw.info['chs']]), - EXPECTED_LOCATIONS_FROM_FILE) + EXPECTED_LOCATIONS_FROM_FILE = ( + np.array( + [ + [-4.0, 1.0, 7.0, 0.0, 0.0, 0.0, nan, nan, nan, nan, nan, nan], + [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], + [-5.0, 2.0, 8.0, 0.0, 0.0, 0.0, nan, nan, nan, nan, nan, nan], + ] + ) + * 0.01 + ) # 0.01 is to scale cm to meters + + EXPECTED_LOCATIONS_FROM_MONTAGE = np.array( + [ + [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], + [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], + [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan], + ] + ) + + raw = read_raw_eeglab( + input_fname=three_chanpos_fname, preload=True, montage_units="cm" + ) + assert_array_equal( + np.array([ch["loc"] for ch in raw.info["chs"]]), EXPECTED_LOCATIONS_FROM_FILE + ) # To accommodate the new behavior so that: # read_raw_eeglab(.. 
montage=montage) and raw.set_montage(montage) @@ -439,25 +563,33 @@ def test_position_information(three_chanpos_fname): # a mix of what is in montage and in the file raw = read_raw_eeglab( input_fname=three_chanpos_fname, - preload=True, montage_units='cm', - ).set_montage(None) # Flush the montage builtin within input_fname - - _assert_array_allclose_nan(np.array([ch['loc'] for ch in raw.info['chs']]), - EXPECTED_LOCATIONS_FROM_MONTAGE) + preload=True, + montage_units="cm", + ).set_montage( + None + ) # Flush the montage builtin within input_fname + + _assert_array_allclose_nan( + np.array([ch["loc"] for ch in raw.info["chs"]]), EXPECTED_LOCATIONS_FROM_MONTAGE + ) @testing.requires_testing_data def test_io_set_raw_2021(): """Test reading new default file format (no EEG struct).""" assert "EEG" not in io.loadmat(raw_fname_2021) - _test_raw_reader(reader=read_raw_eeglab, input_fname=raw_fname_2021, - test_preloading=False, preload=True) + _test_raw_reader( + reader=read_raw_eeglab, + input_fname=raw_fname_2021, + test_preloading=False, + preload=True, + ) @testing.requires_testing_data def test_read_single_epoch(): """Test reading raw set file as an Epochs instance.""" - with pytest.raises(ValueError, match='trials less than 2'): + with pytest.raises(ValueError, match="trials less than 2"): read_epochs_eeglab(raw_fname_mat) @@ -465,51 +597,51 @@ def test_read_single_epoch(): def test_get_montage_info_with_ch_type(): """Test that the channel types are properly returned.""" mat = _readmat(raw_fname_onefile_mat) - n = len(mat['EEG']['chanlocs']['labels']) - mat['EEG']['chanlocs']['type'] = ['eeg'] * (n - 2) + ['eog'] + ['stim'] - mat['EEG']['chanlocs'] = _dol_to_lod(mat['EEG']['chanlocs']) - mat['EEG'] = Bunch(**mat['EEG']) - ch_names, ch_types, montage = _get_montage_information(mat['EEG'], False) + n = len(mat["EEG"]["chanlocs"]["labels"]) + mat["EEG"]["chanlocs"]["type"] = ["eeg"] * (n - 2) + ["eog"] + ["stim"] + mat["EEG"]["chanlocs"] = _dol_to_lod(mat["EEG"]["chanlocs"]) + mat["EEG"] = Bunch(**mat["EEG"]) + ch_names, ch_types, montage = _get_montage_information(mat["EEG"], False) assert len(ch_names) == len(ch_types) == n - assert ch_types == ['eeg'] * (n - 2) + ['eog'] + ['stim'] + assert ch_types == ["eeg"] * (n - 2) + ["eog"] + ["stim"] assert montage is None # test unknown type warning mat = _readmat(raw_fname_onefile_mat) - n = len(mat['EEG']['chanlocs']['labels']) - mat['EEG']['chanlocs']['type'] = ['eeg'] * (n - 2) + ['eog'] + ['unknown'] - mat['EEG']['chanlocs'] = _dol_to_lod(mat['EEG']['chanlocs']) - mat['EEG'] = Bunch(**mat['EEG']) - with pytest.warns(RuntimeWarning, match='Unknown types found'): - ch_names, ch_types, montage = \ - _get_montage_information(mat['EEG'], False) + n = len(mat["EEG"]["chanlocs"]["labels"]) + mat["EEG"]["chanlocs"]["type"] = ["eeg"] * (n - 2) + ["eog"] + ["unknown"] + mat["EEG"]["chanlocs"] = _dol_to_lod(mat["EEG"]["chanlocs"]) + mat["EEG"] = Bunch(**mat["EEG"]) + with pytest.warns(RuntimeWarning, match="Unknown types found"): + ch_names, ch_types, montage = _get_montage_information(mat["EEG"], False) @testing.requires_testing_data -@pytest.mark.parametrize('has_type', (True, False)) +@pytest.mark.parametrize("has_type", (True, False)) def test_fidsposition_information(monkeypatch, has_type): """Test reading file with 3 fiducial locations.""" if not has_type: - def get_bad_information(eeg, get_pos, scale_units=1.): - del eeg.chaninfo['nodatchans']['type'] - return _get_montage_information(eeg, get_pos, - scale_units=scale_units) - - 
monkeypatch.setattr(mne.io.eeglab.eeglab, '_get_montage_information', - get_bad_information) - raw = read_raw_eeglab(raw_fname_chanloc_fids, montage_units='cm') + + def get_bad_information(eeg, get_pos, scale_units=1.0): + del eeg.chaninfo["nodatchans"]["type"] + return _get_montage_information(eeg, get_pos, scale_units=scale_units) + + monkeypatch.setattr( + mne.io.eeglab.eeglab, "_get_montage_information", get_bad_information + ) + raw = read_raw_eeglab(raw_fname_chanloc_fids, montage_units="cm") montage = raw.get_montage() pos = montage.get_positions() n_eeg = 129 if not has_type: # These should now be estimated from the data - assert_allclose(pos['nasion'], [0, 0.0997, 0], atol=1e-4) - assert_allclose(pos['lpa'], -pos['nasion'][[1, 0, 0]]) - assert_allclose(pos['rpa'], pos['nasion'][[1, 0, 0]]) - assert pos['nasion'] is not None - assert pos['lpa'] is not None - assert pos['rpa'] is not None - assert len(pos['nasion']) == 3 - assert len(pos['lpa']) == 3 - assert len(pos['rpa']) == 3 - assert len(raw.info['dig']) == n_eeg + 3 + assert_allclose(pos["nasion"], [0, 0.0997, 0], atol=1e-4) + assert_allclose(pos["lpa"], -pos["nasion"][[1, 0, 0]]) + assert_allclose(pos["rpa"], pos["nasion"][[1, 0, 0]]) + assert pos["nasion"] is not None + assert pos["lpa"] is not None + assert pos["rpa"] is not None + assert len(pos["nasion"]) == 3 + assert len(pos["lpa"]) == 3 + assert len(pos["rpa"]) == 3 + assert len(raw.info["dig"]) == n_eeg + 3 diff --git a/mne/io/egi/egi.py b/mne/io/egi/egi.py index 4e5321fb95f..530dcb9179c 100644 --- a/mne/io/egi/egi.py +++ b/mne/io/egi/egi.py @@ -19,77 +19,90 @@ def _read_header(fid): """Read EGI binary header.""" - version = np.fromfile(fid, '<i4', 1)[0] + version = np.fromfile(fid, "<i4", 1)[0] if version > 6 & ~np.bitwise_and(version, 6): version = version.byteswap().astype(np.uint32) else: - raise ValueError('Watchout. This does not seem to be a simple ' - 'binary EGI file.') + raise ValueError( + "Watchout. This does not seem to be a simple " "binary EGI file."
+ ) def my_fread(*x, **y): return np.fromfile(*x, **y)[0] info = dict( version=version, - year=my_fread(fid, '>i2', 1), - month=my_fread(fid, '>i2', 1), - day=my_fread(fid, '>i2', 1), - hour=my_fread(fid, '>i2', 1), - minute=my_fread(fid, '>i2', 1), - second=my_fread(fid, '>i2', 1), - millisecond=my_fread(fid, '>i4', 1), - samp_rate=my_fread(fid, '>i2', 1), - n_channels=my_fread(fid, '>i2', 1), - gain=my_fread(fid, '>i2', 1), - bits=my_fread(fid, '>i2', 1), - value_range=my_fread(fid, '>i2', 1) + year=my_fread(fid, ">i2", 1), + month=my_fread(fid, ">i2", 1), + day=my_fread(fid, ">i2", 1), + hour=my_fread(fid, ">i2", 1), + minute=my_fread(fid, ">i2", 1), + second=my_fread(fid, ">i2", 1), + millisecond=my_fread(fid, ">i4", 1), + samp_rate=my_fread(fid, ">i2", 1), + n_channels=my_fread(fid, ">i2", 1), + gain=my_fread(fid, ">i2", 1), + bits=my_fread(fid, ">i2", 1), + value_range=my_fread(fid, ">i2", 1), ) unsegmented = 1 if np.bitwise_and(version, 1) == 0 else 0 precision = np.bitwise_and(version, 6) if precision == 0: - raise RuntimeError('Floating point precision is undefined.') + raise RuntimeError("Floating point precision is undefined.") if unsegmented: - info.update(dict(n_categories=0, - n_segments=1, - n_samples=np.fromfile(fid, '>i4', 1)[0], - n_events=np.fromfile(fid, '>i2', 1)[0], - event_codes=[], - category_names=[], - category_lengths=[], - pre_baseline=0)) - for event in range(info['n_events']): - event_codes = ''.join(np.fromfile(fid, 'S1', 4).astype('U1')) - info['event_codes'].append(event_codes) + info.update( + dict( + n_categories=0, + n_segments=1, + n_samples=np.fromfile(fid, ">i4", 1)[0], + n_events=np.fromfile(fid, ">i2", 1)[0], + event_codes=[], + category_names=[], + category_lengths=[], + pre_baseline=0, + ) + ) + for event in range(info["n_events"]): + event_codes = "".join(np.fromfile(fid, "S1", 4).astype("U1")) + info["event_codes"].append(event_codes) else: - raise NotImplementedError('Only continuous files are supported') - info['unsegmented'] = unsegmented - info['dtype'], info['orig_format'] = {2: ('>i2', 'short'), - 4: ('>f4', 'float'), - 6: ('>f8', 'double')}[precision] - info['dtype'] = np.dtype(info['dtype']) + raise NotImplementedError("Only continuous files are supported") + info["unsegmented"] = unsegmented + info["dtype"], info["orig_format"] = { + 2: (">i2", "short"), + 4: (">f4", "float"), + 6: (">f8", "double"), + }[precision] + info["dtype"] = np.dtype(info["dtype"]) return info def _read_events(fid, info): """Read events.""" - events = np.zeros([info['n_events'], - info['n_segments'] * info['n_samples']]) - fid.seek(36 + info['n_events'] * 4, 0) # skip header - for si in range(info['n_samples']): + events = np.zeros([info["n_events"], info["n_segments"] * info["n_samples"]]) + fid.seek(36 + info["n_events"] * 4, 0) # skip header + for si in range(info["n_samples"]): # skip data channels - fid.seek(info['n_channels'] * info['dtype'].itemsize, 1) + fid.seek(info["n_channels"] * info["dtype"].itemsize, 1) # read event channels - events[:, si] = np.fromfile(fid, info['dtype'], info['n_events']) + events[:, si] = np.fromfile(fid, info["dtype"], info["n_events"]) return events @verbose -def read_raw_egi(input_fname, eog=None, misc=None, - include=None, exclude=None, preload=False, - channel_naming='E%d', verbose=None): +def read_raw_egi( + input_fname, + eog=None, + misc=None, + include=None, + exclude=None, + preload=False, + channel_naming="E%d", + verbose=None, +): """Read EGI simple binary as raw object. .. 
note:: This function attempts to create a synthetic trigger channel. @@ -151,131 +164,173 @@ def read_raw_egi(input_fname, eog=None, misc=None, This step will fail if events are not mutually exclusive. """ - _validate_type(input_fname, 'path-like', 'input_fname') + _validate_type(input_fname, "path-like", "input_fname") input_fname = str(input_fname) - if input_fname.rstrip('/\\').endswith('.mff'): # allows .mff or .mff/ - return _read_raw_egi_mff(input_fname, eog, misc, include, - exclude, preload, channel_naming, verbose) - return RawEGI(input_fname, eog, misc, include, exclude, preload, - channel_naming, verbose) + if input_fname.rstrip("/\\").endswith(".mff"): # allows .mff or .mff/ + return _read_raw_egi_mff( + input_fname, eog, misc, include, exclude, preload, channel_naming, verbose + ) + return RawEGI( + input_fname, eog, misc, include, exclude, preload, channel_naming, verbose + ) class RawEGI(BaseRaw): """Raw object from EGI simple binary file.""" @verbose - def __init__(self, input_fname, eog=None, misc=None, - include=None, exclude=None, preload=False, - channel_naming='E%d', verbose=None): # noqa: D102 - input_fname = str( - _check_fname(input_fname, "read", True, "input_fname") - ) + def __init__( + self, + input_fname, + eog=None, + misc=None, + include=None, + exclude=None, + preload=False, + channel_naming="E%d", + verbose=None, + ): # noqa: D102 + input_fname = str(_check_fname(input_fname, "read", True, "input_fname")) if eog is None: eog = [] if misc is None: misc = [] - with open(input_fname, 'rb') as fid: # 'rb' important for py3k - logger.info('Reading EGI header from %s...' % input_fname) + with open(input_fname, "rb") as fid: # 'rb' important for py3k + logger.info("Reading EGI header from %s..." % input_fname) egi_info = _read_header(fid) - logger.info(' Reading events ...') + logger.info(" Reading events ...") egi_events = _read_events(fid, egi_info) # update info + jump - if egi_info['value_range'] != 0 and egi_info['bits'] != 0: - cal = egi_info['value_range'] / 2. 
** egi_info['bits'] + if egi_info["value_range"] != 0 and egi_info["bits"] != 0: + cal = egi_info["value_range"] / 2.0 ** egi_info["bits"] else: cal = 1e-6 - logger.info(' Assembling measurement info ...') + logger.info(" Assembling measurement info ...") event_codes = [] - if egi_info['n_events'] > 0: - event_codes = list(egi_info['event_codes']) + if egi_info["n_events"] > 0: + event_codes = list(egi_info["event_codes"]) if include is None: - exclude_list = ['sync', 'TREV'] if exclude is None else exclude - exclude_inds = [i for i, k in enumerate(event_codes) if k in - exclude_list] + exclude_list = ["sync", "TREV"] if exclude is None else exclude + exclude_inds = [ + i for i, k in enumerate(event_codes) if k in exclude_list + ] more_excludes = [] if exclude is None: for ii, event in enumerate(egi_events): if event.sum() <= 1 and event_codes[ii]: more_excludes.append(ii) if len(exclude_inds) + len(more_excludes) == len(event_codes): - warn('Did not find any event code with more than one ' - 'event.', RuntimeWarning) + warn( + "Did not find any event code with more than one " "event.", + RuntimeWarning, + ) else: exclude_inds.extend(more_excludes) exclude_inds.sort() - include_ = [i for i in np.arange(egi_info['n_events']) if - i not in exclude_inds] - include_names = [k for i, k in enumerate(event_codes) - if i in include_] + include_ = [ + i for i in np.arange(egi_info["n_events"]) if i not in exclude_inds + ] + include_names = [k for i, k in enumerate(event_codes) if i in include_] else: - include_ = [i for i, k in enumerate(event_codes) - if k in include] + include_ = [i for i, k in enumerate(event_codes) if k in include] include_names = include - for kk, v in [('include', include_names), ('exclude', exclude)]: + for kk, v in [("include", include_names), ("exclude", exclude)]: if isinstance(v, list): for k in v: if k not in event_codes: raise ValueError('Could find event named "%s"' % k) elif v is not None: - raise ValueError('`%s` must be None or of type list' % kk) + raise ValueError("`%s` must be None or of type list" % kk) event_ids = np.arange(len(include_)) + 1 logger.info(' Synthesizing trigger channel "STI 014" ...') - logger.info(' Excluding events {%s} ...' % - ", ".join([k for i, k in enumerate(event_codes) - if i not in include_])) - egi_info['new_trigger'] = _combine_triggers( - egi_events[include_], remapping=event_ids) - self.event_id = dict(zip([e for e in event_codes if e in - include_names], event_ids)) + logger.info( + " Excluding events {%s} ..." 
+ % ", ".join([k for i, k in enumerate(event_codes) if i not in include_]) + ) + egi_info["new_trigger"] = _combine_triggers( + egi_events[include_], remapping=event_ids + ) + self.event_id = dict( + zip([e for e in event_codes if e in include_names], event_ids) + ) else: # No events self.event_id = None - egi_info['new_trigger'] = None - info = _empty_info(egi_info['samp_rate']) + egi_info["new_trigger"] = None + info = _empty_info(egi_info["samp_rate"]) my_time = datetime.datetime( - egi_info['year'], egi_info['month'], egi_info['day'], - egi_info['hour'], egi_info['minute'], egi_info['second']) + egi_info["year"], + egi_info["month"], + egi_info["day"], + egi_info["hour"], + egi_info["minute"], + egi_info["second"], + ) my_timestamp = time.mktime(my_time.timetuple()) - info['meas_date'] = (my_timestamp, 0) - ch_names = [channel_naming % (i + 1) for i in - range(egi_info['n_channels'])] - ch_names.extend(list(egi_info['event_codes'])) - if egi_info['new_trigger'] is not None: - ch_names.append('STI 014') # our new_trigger + info["meas_date"] = (my_timestamp, 0) + ch_names = [channel_naming % (i + 1) for i in range(egi_info["n_channels"])] + ch_names.extend(list(egi_info["event_codes"])) + if egi_info["new_trigger"] is not None: + ch_names.append("STI 014") # our new_trigger nchan = len(ch_names) cals = np.repeat(cal, nchan) ch_coil = FIFF.FIFFV_COIL_EEG ch_kind = FIFF.FIFFV_EEG_CH chs = _create_chs(ch_names, cals, ch_coil, ch_kind, eog, (), (), misc) - sti_ch_idx = [i for i, name in enumerate(ch_names) if - name.startswith('STI') or name in event_codes] + sti_ch_idx = [ + i + for i, name in enumerate(ch_names) + if name.startswith("STI") or name in event_codes + ] for idx in sti_ch_idx: - chs[idx].update({'unit_mul': FIFF.FIFF_UNITM_NONE, 'cal': 1., - 'kind': FIFF.FIFFV_STIM_CH, - 'coil_type': FIFF.FIFFV_COIL_NONE, - 'unit': FIFF.FIFF_UNIT_NONE, - 'loc': np.zeros(12)}) - info['chs'] = chs + chs[idx].update( + { + "unit_mul": FIFF.FIFF_UNITM_NONE, + "cal": 1.0, + "kind": FIFF.FIFFV_STIM_CH, + "coil_type": FIFF.FIFFV_COIL_NONE, + "unit": FIFF.FIFF_UNIT_NONE, + "loc": np.zeros(12), + } + ) + info["chs"] = chs info._unlocked = False info._update_redundant() - orig_format = egi_info["orig_format"] \ - if egi_info["orig_format"] != "float" else "single" + orig_format = ( + egi_info["orig_format"] if egi_info["orig_format"] != "float" else "single" + ) super(RawEGI, self).__init__( - info, preload, orig_format=orig_format, - filenames=[input_fname], last_samps=[egi_info['n_samples'] - 1], - raw_extras=[egi_info], verbose=verbose) + info, + preload, + orig_format=orig_format, + filenames=[input_fname], + last_samps=[egi_info["n_samples"] - 1], + raw_extras=[egi_info], + verbose=verbose, + ) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a segment of data from a file.""" egi_info = self._raw_extras[fi] - dtype = egi_info['dtype'] - n_chan_read = egi_info['n_channels'] + egi_info['n_events'] - offset = 36 + egi_info['n_events'] * 4 - trigger_ch = egi_info['new_trigger'] - _read_segments_file(self, data, idx, fi, start, stop, cals, mult, - dtype=dtype, n_channels=n_chan_read, offset=offset, - trigger_ch=trigger_ch) + dtype = egi_info["dtype"] + n_chan_read = egi_info["n_channels"] + egi_info["n_events"] + offset = 36 + egi_info["n_events"] * 4 + trigger_ch = egi_info["new_trigger"] + _read_segments_file( + self, + data, + idx, + fi, + start, + stop, + cals, + mult, + dtype=dtype, + n_channels=n_chan_read, + offset=offset, + trigger_ch=trigger_ch, + ) diff --git 
a/mne/io/egi/egimff.py b/mne/io/egi/egimff.py index 1c745d8b10e..db7247730f8 100644 --- a/mne/io/egi/egimff.py +++ b/mne/io/egi/egimff.py @@ -11,8 +11,14 @@ import numpy as np from .events import _read_events, _combine_triggers -from .general import (_get_signalfname, _get_ep_info, _extract, _get_blocks, - _get_gains, _block_r) +from .general import ( + _get_signalfname, + _get_ep_info, + _extract, + _get_blocks, + _get_gains, + _block_r, +) from ..base import BaseRaw from ..constants import FIFF from ..meas_info import _empty_info, create_info, _ensure_meas_date_none_or_dt @@ -22,46 +28,46 @@ from ...utils import verbose, logger, warn, _check_option, _check_fname from ...evoked import EvokedArray -REFERENCE_NAMES = ('VREF', 'Vertex Reference') +REFERENCE_NAMES = ("VREF", "Vertex Reference") def _read_mff_header(filepath): """Read mff header.""" all_files = _get_signalfname(filepath) - eeg_file = all_files['EEG']['signal'] - eeg_info_file = all_files['EEG']['info'] + eeg_file = all_files["EEG"]["signal"] + eeg_info_file = all_files["EEG"]["info"] - info_filepath = op.join(filepath, 'info.xml') # add with filepath - tags = ['mffVersion', 'recordTime'] + info_filepath = op.join(filepath, "info.xml") # add with filepath + tags = ["mffVersion", "recordTime"] version_and_date = _extract(tags, filepath=info_filepath) version = "" - if len(version_and_date['mffVersion']): - version = version_and_date['mffVersion'][0] + if len(version_and_date["mffVersion"]): + version = version_and_date["mffVersion"][0] fname = op.join(filepath, eeg_file) signal_blocks = _get_blocks(fname) epochs = _get_ep_info(filepath) - summaryinfo = dict(eeg_fname=eeg_file, - info_fname=eeg_info_file) + summaryinfo = dict(eeg_fname=eeg_file, info_fname=eeg_info_file) summaryinfo.update(signal_blocks) # sanity check and update relevant values - record_time = version_and_date['recordTime'][0] + record_time = version_and_date["recordTime"][0] # e.g., # 2018-07-30T10:47:01.021673-04:00 # 2017-09-20T09:55:44.072000000+01:00 g = re.match( - r'\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.(\d{6}(?:\d{3})?)[+-]\d{2}:\d{2}', # noqa: E501 - record_time) + r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.(\d{6}(?:\d{3})?)[+-]\d{2}:\d{2}", # noqa: E501 + record_time, + ) if g is None: - raise RuntimeError('Could not parse recordTime %r' % (record_time,)) + raise RuntimeError("Could not parse recordTime %r" % (record_time,)) frac = g.groups()[0] assert len(frac) in (6, 9) and all(f.isnumeric() for f in frac) # regex div = 1000 if len(frac) == 6 else 1000000 - for key in ('last_samps', 'first_samps'): + for key in ("last_samps", "first_samps"): # convert from times in µS to samples for ei, e in enumerate(epochs[key]): if e % div != 0: - raise RuntimeError('Could not parse epoch time %s' % (e,)) + raise RuntimeError("Could not parse epoch time %s" % (e,)) epochs[key][ei] = e // div epochs[key] = np.array(epochs[key], np.uint64) # I guess they refer to times in milliseconds? 
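The epoch-time hunk above first divides the microsecond boundaries down to milliseconds, then multiplies by the sampling rate reduced to lowest terms so everything stays in exact integer arithmetic. A minimal sketch of that conversion, assuming a hypothetical 500 Hz recording whose recordTime carries six fractional digits (so the divisor is 1000):

import math

import numpy as np

sfreq = 500  # Hz; hypothetical rate, the "multiply by 1, divide by 2" case
first_samps = np.array([0, 2_000_000], np.uint64)  # made-up epoch starts in µs

assert (first_samps % 1000 == 0).all()  # µs values must sit on ms boundaries
first_samps //= 1000  # µs -> ms

# ms -> samples: reduce sfreq/1000 (here 500/1000 -> 1/2) before multiplying
g = math.gcd(sfreq, 1000)
numerator, denominator = sfreq // g, 1000 // g
with np.errstate(over="raise"):  # same overflow guard the reader uses
    first_samps *= numerator
    first_samps //= denominator
print(first_samps.astype(np.int64))  # [   0 1000]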
@@ -72,108 +78,120 @@ def _read_mff_header(filepath): # by what we need to (e.g., a sample rate of 500 means we can multiply # by 1 and divide by 2 rather than multiplying by 500 and dividing by # 1000) - numerator = signal_blocks['sfreq'] + numerator = signal_blocks["sfreq"] denominator = 1000 this_gcd = math.gcd(numerator, denominator) numerator = numerator // this_gcd denominator = denominator // this_gcd - with np.errstate(over='raise'): + with np.errstate(over="raise"): epochs[key] *= numerator epochs[key] //= denominator # Should be safe to cast to int now, which makes things later not # upbroadcast to float epochs[key] = epochs[key].astype(np.int64) - n_samps_block = signal_blocks['samples_block'].sum() - n_samps_epochs = (epochs['last_samps'] - epochs['first_samps']).sum() - bad = (n_samps_epochs != n_samps_block or - not (epochs['first_samps'] < epochs['last_samps']).all() or - not (epochs['first_samps'][1:] >= epochs['last_samps'][:-1]).all()) + n_samps_block = signal_blocks["samples_block"].sum() + n_samps_epochs = (epochs["last_samps"] - epochs["first_samps"]).sum() + bad = ( + n_samps_epochs != n_samps_block + or not (epochs["first_samps"] < epochs["last_samps"]).all() + or not (epochs["first_samps"][1:] >= epochs["last_samps"][:-1]).all() + ) if bad: - raise RuntimeError('EGI epoch first/last samps could not be parsed:\n' - '%s\n%s' % (list(epochs['first_samps']), - list(epochs['last_samps']))) + raise RuntimeError( + "EGI epoch first/last samps could not be parsed:\n" + "%s\n%s" % (list(epochs["first_samps"]), list(epochs["last_samps"])) + ) summaryinfo.update(epochs) # index which samples in raw are actually readable from disk (i.e., not # in a skip) - disk_samps = np.full(epochs['last_samps'][-1], -1) + disk_samps = np.full(epochs["last_samps"][-1], -1) offset = 0 - for first, last in zip(epochs['first_samps'], epochs['last_samps']): + for first, last in zip(epochs["first_samps"], epochs["last_samps"]): n_this = last - first disk_samps[first:last] = np.arange(offset, offset + n_this) offset += n_this - summaryinfo['disk_samps'] = disk_samps + summaryinfo["disk_samps"] = disk_samps # Add the sensor info. 
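The disk_samps array assembled just above is the lookup that later lets the segment reader map a requested raw sample onto its position in the .bin file, with -1 marking samples that fall inside an acquisition skip. A toy version of the same loop, with two hypothetical epochs around a pause:

import numpy as np

# hypothetical epoch bounds: samples 0-4 on disk, 5-9 skipped, 10-14 on disk
first_samps = [0, 10]
last_samps = [5, 15]

disk_samps = np.full(last_samps[-1], -1)  # -1 = not readable from disk (skip)
offset = 0
for first, last in zip(first_samps, last_samps):
    n_this = last - first
    disk_samps[first:last] = np.arange(offset, offset + n_this)
    offset += n_this

print(disk_samps)  # [ 0  1  2  3  4 -1 -1 -1 -1 -1  5  6  7  8  9]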
- sensor_layout_file = op.join(filepath, 'sensorLayout.xml') + sensor_layout_file = op.join(filepath, "sensorLayout.xml") sensor_layout_obj = parse(sensor_layout_file) - summaryinfo['device'] = (sensor_layout_obj.getElementsByTagName('name') - [0].firstChild.data) - sensors = sensor_layout_obj.getElementsByTagName('sensor') + summaryinfo["device"] = sensor_layout_obj.getElementsByTagName("name")[ + 0 + ].firstChild.data + sensors = sensor_layout_obj.getElementsByTagName("sensor") chan_type = list() chan_unit = list() n_chans = 0 numbers = list() # used for identification for sensor in sensors: - sensortype = int(sensor.getElementsByTagName('type')[0] - .firstChild.data) + sensortype = int(sensor.getElementsByTagName("type")[0].firstChild.data) if sensortype in [0, 1]: - sn = sensor.getElementsByTagName('number')[0].firstChild.data + sn = sensor.getElementsByTagName("number")[0].firstChild.data sn = sn.encode() numbers.append(sn) - chan_type.append('eeg') - chan_unit.append('uV') + chan_type.append("eeg") + chan_unit.append("uV") n_chans = n_chans + 1 - if n_chans != summaryinfo['n_channels']: - raise RuntimeError('Number of defined channels (%d) did not match the ' - 'expected channels (%d)' - % (n_chans, summaryinfo['n_channels'])) + if n_chans != summaryinfo["n_channels"]: + raise RuntimeError( + "Number of defined channels (%d) did not match the " + "expected channels (%d)" % (n_chans, summaryinfo["n_channels"]) + ) # Check presence of PNS data pns_names = [] - if 'PNS' in all_files: - pns_fpath = op.join(filepath, all_files['PNS']['signal']) + if "PNS" in all_files: + pns_fpath = op.join(filepath, all_files["PNS"]["signal"]) pns_blocks = _get_blocks(pns_fpath) - pns_samples = pns_blocks['samples_block'] - signal_samples = signal_blocks['samples_block'] - same_blocks = (np.array_equal(pns_samples[:-1], - signal_samples[:-1]) and - pns_samples[-1] in (signal_samples[-1] - np.arange(2))) + pns_samples = pns_blocks["samples_block"] + signal_samples = signal_blocks["samples_block"] + same_blocks = np.array_equal( + pns_samples[:-1], signal_samples[:-1] + ) and pns_samples[-1] in (signal_samples[-1] - np.arange(2)) if not same_blocks: - raise RuntimeError('PNS and signals samples did not match:\n' - '%s\nvs\n%s' - % (list(pns_samples), list(signal_samples))) + raise RuntimeError( + "PNS and signals samples did not match:\n" + "%s\nvs\n%s" % (list(pns_samples), list(signal_samples)) + ) - pns_file = op.join(filepath, 'pnsSet.xml') + pns_file = op.join(filepath, "pnsSet.xml") pns_obj = parse(pns_file) - sensors = pns_obj.getElementsByTagName('sensor') + sensors = pns_obj.getElementsByTagName("sensor") pns_types = [] pns_units = [] for sensor in sensors: # sensor number: # sensor.getElementsByTagName('number')[0].firstChild.data - name = sensor.getElementsByTagName('name')[0].firstChild.data - unit_elem = sensor.getElementsByTagName('unit')[0].firstChild - unit = '' + name = sensor.getElementsByTagName("name")[0].firstChild.data + unit_elem = sensor.getElementsByTagName("unit")[0].firstChild + unit = "" if unit_elem is not None: unit = unit_elem.data - if name == 'ECG': - ch_type = 'ecg' - elif 'EMG' in name: - ch_type = 'emg' + if name == "ECG": + ch_type = "ecg" + elif "EMG" in name: + ch_type = "emg" else: - ch_type = 'bio' + ch_type = "bio" pns_types.append(ch_type) pns_units.append(unit) pns_names.append(name) - summaryinfo.update(pns_types=pns_types, pns_units=pns_units, - pns_fname=all_files['PNS']['signal'], - pns_sample_blocks=pns_blocks) - summaryinfo.update(pns_names=pns_names, 
version=version, - date=version_and_date['recordTime'][0], - chan_type=chan_type, chan_unit=chan_unit, - numbers=numbers) + summaryinfo.update( + pns_types=pns_types, + pns_units=pns_units, + pns_fname=all_files["PNS"]["signal"], + pns_sample_blocks=pns_blocks, + ) + summaryinfo.update( + pns_names=pns_names, + version=version, + date=version_and_date["recordTime"][0], + chan_type=chan_type, + chan_unit=chan_unit, + numbers=numbers, + ) return summaryinfo @@ -191,7 +209,7 @@ def utcoffset(self, dt): return self._offset def tzname(self, dt): - return 'MFF' + return "MFF" def dst(self, dt): return datetime.timedelta(0) @@ -214,7 +232,7 @@ def _read_header(input_fname): mff_hdr = _read_mff_header(input_fname) with open(input_fname + "/signal1.bin", "rb") as fid: version = np.fromfile(fid, np.int32, 1)[0] - ''' + """ the datetime.strptime .f directive (milleseconds) will only accept up to 6 digits. if there are more than six millesecond digits in the provided timestamp string @@ -223,72 +241,84 @@ def _read_header(input_fname): elements of the timestamp string to truncate the milleseconds to 6 digits and extract the timezone, and then piece these together and assign back to mff_hdr['date'] - ''' - if len(mff_hdr['date']) > 32: - dt, tz = [mff_hdr['date'][:26], mff_hdr['date'][-6:]] - mff_hdr['date'] = dt + tz + """ + if len(mff_hdr["date"]) > 32: + dt, tz = [mff_hdr["date"][:26], mff_hdr["date"][-6:]] + mff_hdr["date"] = dt + tz - time_n = (datetime.datetime.strptime( - mff_hdr['date'], '%Y-%m-%dT%H:%M:%S.%f%z')) + time_n = datetime.datetime.strptime(mff_hdr["date"], "%Y-%m-%dT%H:%M:%S.%f%z") info = dict( version=version, meas_dt_local=time_n, - utc_offset=time_n.strftime('%z'), + utc_offset=time_n.strftime("%z"), gain=0, bits=0, - value_range=0) - info.update(n_categories=0, n_segments=1, n_events=0, event_codes=[], - category_names=[], category_lengths=[], pre_baseline=0) + value_range=0, + ) + info.update( + n_categories=0, + n_segments=1, + n_events=0, + event_codes=[], + category_names=[], + category_lengths=[], + pre_baseline=0, + ) info.update(mff_hdr) return info def _get_eeg_calibration_info(filepath, egi_info): """Calculate calibration info for EEG channels.""" - gains = _get_gains(op.join(filepath, egi_info['info_fname'])) - if egi_info['value_range'] != 0 and egi_info['bits'] != 0: - cals = [egi_info['value_range'] / 2 ** egi_info['bits']] * \ - len(egi_info['chan_type']) + gains = _get_gains(op.join(filepath, egi_info["info_fname"])) + if egi_info["value_range"] != 0 and egi_info["bits"] != 0: + cals = [egi_info["value_range"] / 2 ** egi_info["bits"]] * len( + egi_info["chan_type"] + ) else: - cal_scales = {'uV': 1e-6, 'V': 1} - cals = [cal_scales[t] for t in egi_info['chan_unit']] - if 'gcal' in gains: - cals *= gains['gcal'] + cal_scales = {"uV": 1e-6, "V": 1} + cals = [cal_scales[t] for t in egi_info["chan_unit"]] + if "gcal" in gains: + cals *= gains["gcal"] return cals def _read_locs(filepath, egi_info, channel_naming): """Read channel locations.""" from ...channels.montage import make_dig_montage - fname = op.join(filepath, 'coordinates.xml') + + fname = op.join(filepath, "coordinates.xml") if not op.exists(fname): - logger.warn( - 'File coordinates.xml not found, not setting channel locations') - ch_names = [channel_naming % (i + 1) for i in - range(egi_info['n_channels'])] + logger.warn("File coordinates.xml not found, not setting channel locations") + ch_names = [channel_naming % (i + 1) for i in range(egi_info["n_channels"])] return ch_names, None dig_ident_map = { - 
'Left periauricular point': 'lpa', - 'Right periauricular point': 'rpa', - 'Nasion': 'nasion', + "Left periauricular point": "lpa", + "Right periauricular point": "rpa", + "Nasion": "nasion", } - numbers = np.array(egi_info['numbers']) + numbers = np.array(egi_info["numbers"]) coordinates = parse(fname) - sensors = coordinates.getElementsByTagName('sensor') + sensors = coordinates.getElementsByTagName("sensor") ch_pos = OrderedDict() hsp = list() nlr = dict() ch_names = list() for sensor in sensors: - name_element = sensor.getElementsByTagName('name')[0].firstChild - num_element = sensor.getElementsByTagName('number')[0].firstChild - name = (channel_naming % int(num_element.data) if name_element is None - else name_element.data) + name_element = sensor.getElementsByTagName("name")[0].firstChild + num_element = sensor.getElementsByTagName("number")[0].firstChild + name = ( + channel_naming % int(num_element.data) + if name_element is None + else name_element.data + ) nr = num_element.data.encode() - coords = [float(sensor.getElementsByTagName(coord)[0].firstChild.data) - for coord in 'xyz'] + coords = [ + float(sensor.getElementsByTagName(coord)[0].firstChild.data) + for coord in "xyz" + ] loc = np.array(coords) / 100 # cm -> m # create dig entry if name in dig_ident_map: @@ -309,28 +339,33 @@ def _read_locs(filepath, egi_info, channel_naming): def _add_pns_channel_info(chs, egi_info, ch_names): """Add info for PNS channels to channel info dict.""" - for i_ch, ch_name in enumerate(egi_info['pns_names']): + for i_ch, ch_name in enumerate(egi_info["pns_names"]): idx = ch_names.index(ch_name) - ch_type = egi_info['pns_types'][i_ch] - type_to_kind_map = {'ecg': FIFF.FIFFV_ECG_CH, - 'emg': FIFF.FIFFV_EMG_CH - } + ch_type = egi_info["pns_types"][i_ch] + type_to_kind_map = {"ecg": FIFF.FIFFV_ECG_CH, "emg": FIFF.FIFFV_EMG_CH} ch_kind = type_to_kind_map.get(ch_type, FIFF.FIFFV_BIO_CH) ch_unit = FIFF.FIFF_UNIT_V ch_cal = 1e-6 - if egi_info['pns_units'][i_ch] != 'uV': + if egi_info["pns_units"][i_ch] != "uV": ch_unit = FIFF.FIFF_UNIT_NONE ch_cal = 1.0 chs[idx].update( - cal=ch_cal, kind=ch_kind, coil_type=FIFF.FIFFV_COIL_NONE, - unit=ch_unit) + cal=ch_cal, kind=ch_kind, coil_type=FIFF.FIFFV_COIL_NONE, unit=ch_unit + ) return chs @verbose -def _read_raw_egi_mff(input_fname, eog=None, misc=None, - include=None, exclude=None, preload=False, - channel_naming='E%d', verbose=None): +def _read_raw_egi_mff( + input_fname, + eog=None, + misc=None, + include=None, + exclude=None, + preload=False, + channel_naming="E%d", + verbose=None, +): """Read EGI mff binary as raw object. .. note:: This function attempts to create a synthetic trigger channel. @@ -389,17 +424,26 @@ def _read_raw_egi_mff(input_fname, eog=None, misc=None, .. versionadded:: 0.15.0 """ - return RawMff(input_fname, eog, misc, include, exclude, - preload, channel_naming, verbose) + return RawMff( + input_fname, eog, misc, include, exclude, preload, channel_naming, verbose + ) class RawMff(BaseRaw): """RawMff class.""" @verbose - def __init__(self, input_fname, eog=None, misc=None, - include=None, exclude=None, preload=False, - channel_naming='E%d', verbose=None): + def __init__( + self, + input_fname, + eog=None, + misc=None, + include=None, + exclude=None, + preload=False, + channel_naming="E%d", + verbose=None, + ): """Init the RawMff class.""" input_fname = str( _check_fname( @@ -410,208 +454,225 @@ def __init__(self, input_fname, eog=None, misc=None, need_dir=True, ) ) - logger.info('Reading EGI MFF Header from %s...' 
% input_fname) + logger.info("Reading EGI MFF Header from %s..." % input_fname) egi_info = _read_header(input_fname) if eog is None: eog = [] if misc is None: - misc = np.where(np.array( - egi_info['chan_type']) != 'eeg')[0].tolist() + misc = np.where(np.array(egi_info["chan_type"]) != "eeg")[0].tolist() - logger.info(' Reading events ...') + logger.info(" Reading events ...") egi_events, egi_info = _read_events(input_fname, egi_info) cals = _get_eeg_calibration_info(input_fname, egi_info) - logger.info(' Assembling measurement info ...') - if egi_info['n_events'] > 0: - event_codes = list(egi_info['event_codes']) + logger.info(" Assembling measurement info ...") + if egi_info["n_events"] > 0: + event_codes = list(egi_info["event_codes"]) if include is None: - exclude_list = ['sync', 'TREV'] if exclude is None else exclude - exclude_inds = [i for i, k in enumerate(event_codes) if k in - exclude_list] + exclude_list = ["sync", "TREV"] if exclude is None else exclude + exclude_inds = [ + i for i, k in enumerate(event_codes) if k in exclude_list + ] more_excludes = [] if exclude is None: for ii, event in enumerate(egi_events): if event.sum() <= 1 and event_codes[ii]: more_excludes.append(ii) if len(exclude_inds) + len(more_excludes) == len(event_codes): - warn('Did not find any event code with more than one ' - 'event.', RuntimeWarning) + warn( + "Did not find any event code with more than one " "event.", + RuntimeWarning, + ) else: exclude_inds.extend(more_excludes) exclude_inds.sort() - include_ = [i for i in np.arange(egi_info['n_events']) if - i not in exclude_inds] - include_names = [k for i, k in enumerate(event_codes) - if i in include_] + include_ = [ + i for i in np.arange(egi_info["n_events"]) if i not in exclude_inds + ] + include_names = [k for i, k in enumerate(event_codes) if i in include_] else: - include_ = [i for i, k in enumerate(event_codes) - if k in include] + include_ = [i for i, k in enumerate(event_codes) if k in include] include_names = include - for kk, v in [('include', include_names), ('exclude', exclude)]: + for kk, v in [("include", include_names), ("exclude", exclude)]: if isinstance(v, list): for k in v: if k not in event_codes: - raise ValueError( - f'Could not find event named {repr(k)}') + raise ValueError(f"Could not find event named {repr(k)}") elif v is not None: - raise ValueError('`%s` must be None or of type list' % kk) + raise ValueError("`%s` must be None or of type list" % kk) logger.info(' Synthesizing trigger channel "STI 014" ...') - logger.info(' Excluding events {%s} ...' % - ", ".join([k for i, k in enumerate(event_codes) - if i not in include_])) + logger.info( + " Excluding events {%s} ..." 
+ % ", ".join([k for i, k in enumerate(event_codes) if i not in include_]) + ) events_ids = np.arange(len(include_)) + 1 - egi_info['new_trigger'] = _combine_triggers( - egi_events[include_], remapping=events_ids) - self.event_id = dict(zip([e for e in event_codes if e in - include_names], events_ids)) - if egi_info['new_trigger'] is not None: - egi_events = np.vstack([egi_events, egi_info['new_trigger']]) - assert egi_events.shape[1] == egi_info['last_samps'][-1] + egi_info["new_trigger"] = _combine_triggers( + egi_events[include_], remapping=events_ids + ) + self.event_id = dict( + zip([e for e in event_codes if e in include_names], events_ids) + ) + if egi_info["new_trigger"] is not None: + egi_events = np.vstack([egi_events, egi_info["new_trigger"]]) + assert egi_events.shape[1] == egi_info["last_samps"][-1] else: # No events self.event_id = None - egi_info['new_trigger'] = None + egi_info["new_trigger"] = None event_codes = [] - meas_dt_utc = (egi_info['meas_dt_local'] - .astimezone(datetime.timezone.utc)) - info = _empty_info(egi_info['sfreq']) - info['meas_date'] = _ensure_meas_date_none_or_dt(meas_dt_utc) - info['utc_offset'] = egi_info['utc_offset'] - info['device_info'] = dict(type=egi_info['device']) + meas_dt_utc = egi_info["meas_dt_local"].astimezone(datetime.timezone.utc) + info = _empty_info(egi_info["sfreq"]) + info["meas_date"] = _ensure_meas_date_none_or_dt(meas_dt_utc) + info["utc_offset"] = egi_info["utc_offset"] + info["device_info"] = dict(type=egi_info["device"]) # read in the montage, if it exists ch_names, mon = _read_locs(input_fname, egi_info, channel_naming) # Second: Stim - ch_names.extend(list(egi_info['event_codes'])) - if egi_info['new_trigger'] is not None: - ch_names.append('STI 014') # channel for combined events + ch_names.extend(list(egi_info["event_codes"])) + if egi_info["new_trigger"] is not None: + ch_names.append("STI 014") # channel for combined events cals = np.concatenate( - [cals, np.repeat(1, len(event_codes) + 1 + len(misc) + len(eog))]) + [cals, np.repeat(1, len(event_codes) + 1 + len(misc) + len(eog))] + ) # Third: PNS - ch_names.extend(egi_info['pns_names']) - cals = np.concatenate( - [cals, np.repeat(1, len(egi_info['pns_names']))]) + ch_names.extend(egi_info["pns_names"]) + cals = np.concatenate([cals, np.repeat(1, len(egi_info["pns_names"]))]) # Actually create channels as EEG, then update stim and PNS ch_coil = FIFF.FIFFV_COIL_EEG ch_kind = FIFF.FIFFV_EEG_CH chs = _create_chs(ch_names, cals, ch_coil, ch_kind, eog, (), (), misc) - sti_ch_idx = [i for i, name in enumerate(ch_names) if - name.startswith('STI') or name in event_codes] + sti_ch_idx = [ + i + for i, name in enumerate(ch_names) + if name.startswith("STI") or name in event_codes + ] for idx in sti_ch_idx: - chs[idx].update({'unit_mul': FIFF.FIFF_UNITM_NONE, - 'cal': cals[idx], - 'kind': FIFF.FIFFV_STIM_CH, - 'coil_type': FIFF.FIFFV_COIL_NONE, - 'unit': FIFF.FIFF_UNIT_NONE}) + chs[idx].update( + { + "unit_mul": FIFF.FIFF_UNITM_NONE, + "cal": cals[idx], + "kind": FIFF.FIFFV_STIM_CH, + "coil_type": FIFF.FIFFV_COIL_NONE, + "unit": FIFF.FIFF_UNIT_NONE, + } + ) chs = _add_pns_channel_info(chs, egi_info, ch_names) - info['chs'] = chs + info["chs"] = chs info._unlocked = False info._update_redundant() if mon is not None: - info.set_montage(mon, on_missing='ignore') + info.set_montage(mon, on_missing="ignore") ref_idx = np.flatnonzero(np.in1d(mon.ch_names, REFERENCE_NAMES)) if len(ref_idx): ref_idx = ref_idx.item() - ref_coords = info['chs'][int(ref_idx)]['loc'][:3] - for chan in 
info['chs']: - is_eeg = chan['kind'] == FIFF.FIFFV_EEG_CH - is_not_ref = chan['ch_name'] not in REFERENCE_NAMES + ref_coords = info["chs"][int(ref_idx)]["loc"][:3] + for chan in info["chs"]: + is_eeg = chan["kind"] == FIFF.FIFFV_EEG_CH + is_not_ref = chan["ch_name"] not in REFERENCE_NAMES if is_eeg and is_not_ref: - chan['loc'][3:6] = ref_coords + chan["loc"][3:6] = ref_coords # Cz ref was applied during acquisition, so mark as already set. with info._unlock(): - info['custom_ref_applied'] = FIFF.FIFFV_MNE_CUSTOM_REF_ON - file_bin = op.join(input_fname, egi_info['eeg_fname']) - egi_info['egi_events'] = egi_events + info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_ON + file_bin = op.join(input_fname, egi_info["eeg_fname"]) + egi_info["egi_events"] = egi_events # Check how many channels to read are from EEG - keys = ('eeg', 'sti', 'pns') + keys = ("eeg", "sti", "pns") idx = dict() - idx['eeg'] = np.where( - [ch['kind'] == FIFF.FIFFV_EEG_CH for ch in chs])[0] - idx['sti'] = np.where( - [ch['kind'] == FIFF.FIFFV_STIM_CH for ch in chs])[0] - idx['pns'] = np.where( - [ch['kind'] in (FIFF.FIFFV_ECG_CH, FIFF.FIFFV_EMG_CH, - FIFF.FIFFV_BIO_CH) for ch in chs])[0] + idx["eeg"] = np.where([ch["kind"] == FIFF.FIFFV_EEG_CH for ch in chs])[0] + idx["sti"] = np.where([ch["kind"] == FIFF.FIFFV_STIM_CH for ch in chs])[0] + idx["pns"] = np.where( + [ + ch["kind"] in (FIFF.FIFFV_ECG_CH, FIFF.FIFFV_EMG_CH, FIFF.FIFFV_BIO_CH) + for ch in chs + ] + )[0] # By construction this should always be true, but check anyway if not np.array_equal( - np.concatenate([idx[key] for key in keys]), - np.arange(len(chs))): - raise ValueError('Currently interlacing EEG and PNS channels' - 'is not supported') - egi_info['kind_bounds'] = [0] + np.concatenate([idx[key] for key in keys]), np.arange(len(chs)) + ): + raise ValueError( + "Currently interlacing EEG and PNS channels" "is not supported" + ) + egi_info["kind_bounds"] = [0] for key in keys: - egi_info['kind_bounds'].append(len(idx[key])) - egi_info['kind_bounds'] = np.cumsum(egi_info['kind_bounds']) - assert egi_info['kind_bounds'][0] == 0 - assert egi_info['kind_bounds'][-1] == info['nchan'] + egi_info["kind_bounds"].append(len(idx[key])) + egi_info["kind_bounds"] = np.cumsum(egi_info["kind_bounds"]) + assert egi_info["kind_bounds"][0] == 0 + assert egi_info["kind_bounds"][-1] == info["nchan"] first_samps = [0] - last_samps = [egi_info['last_samps'][-1] - 1] + last_samps = [egi_info["last_samps"][-1] - 1] annot = dict(onset=list(), duration=list(), description=list()) - if len(idx['pns']): + if len(idx["pns"]): # PNS Data is present and should be read: - egi_info['pns_filepath'] = op.join( - input_fname, egi_info['pns_fname']) + egi_info["pns_filepath"] = op.join(input_fname, egi_info["pns_fname"]) # Check for PNS bug immediately - pns_samples = np.sum( - egi_info['pns_sample_blocks']['samples_block']) - eeg_samples = np.sum(egi_info['samples_block']) + pns_samples = np.sum(egi_info["pns_sample_blocks"]["samples_block"]) + eeg_samples = np.sum(egi_info["samples_block"]) if pns_samples == eeg_samples - 1: - warn('This file has the EGI PSG sample bug') - annot['onset'].append(last_samps[-1] / egi_info['sfreq']) - annot['duration'].append(1 / egi_info['sfreq']) - annot['description'].append('BAD_EGI_PSG') + warn("This file has the EGI PSG sample bug") + annot["onset"].append(last_samps[-1] / egi_info["sfreq"]) + annot["duration"].append(1 / egi_info["sfreq"]) + annot["description"].append("BAD_EGI_PSG") elif pns_samples != eeg_samples: raise RuntimeError( - 'PNS samples 
(%d) did not match EEG samples (%d)' - % (pns_samples, eeg_samples)) + "PNS samples (%d) did not match EEG samples (%d)" + % (pns_samples, eeg_samples) + ) self._filenames = [file_bin] self._raw_extras = [egi_info] super(RawMff, self).__init__( - info, preload=preload, orig_format="single", filenames=[file_bin], - first_samps=first_samps, last_samps=last_samps, - raw_extras=[egi_info], verbose=verbose) + info, + preload=preload, + orig_format="single", + filenames=[file_bin], + first_samps=first_samps, + last_samps=last_samps, + raw_extras=[egi_info], + verbose=verbose, + ) # Annotate acquisition skips - for first, prev_last in zip(egi_info['first_samps'][1:], - egi_info['last_samps'][:-1]): + for first, prev_last in zip( + egi_info["first_samps"][1:], egi_info["last_samps"][:-1] + ): gap = first - prev_last assert gap >= 0 if gap: - annot['onset'].append((prev_last - 0.5) / egi_info['sfreq']) - annot['duration'].append(gap / egi_info['sfreq']) - annot['description'].append('BAD_ACQ_SKIP') + annot["onset"].append((prev_last - 0.5) / egi_info["sfreq"]) + annot["duration"].append(gap / egi_info["sfreq"]) + annot["description"].append("BAD_ACQ_SKIP") - if len(annot['onset']): + if len(annot["onset"]): self.set_annotations(Annotations(**annot, orig_time=None)) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of data.""" - logger.debug(f'Reading MFF {start:6d} ... {stop:6d} ...') - dtype = ' -1)[0] # short circuit in case we don't need any samples if not len(disk_use_idx): @@ -645,13 +706,14 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): # Get starting/stopping block/samples block_samples_offset = np.cumsum(samples_block) offset_blocks = np.sum(block_samples_offset <= start) - offset_samples = start - (block_samples_offset[offset_blocks - 1] - if offset_blocks > 0 else 0) + offset_samples = start - ( + block_samples_offset[offset_blocks - 1] if offset_blocks > 0 else 0 + ) # TODO: Refactor this reading with the PNS reading in a single function # (DRY) samples_to_read = stop - start - with open(self._filenames[fi], 'rb', buffering=0) as fid: + with open(self._filenames[fi], "rb", buffering=0) as fid: # Go to starting block current_block = 0 current_block_info = None @@ -660,26 +722,25 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): this_block_info = _block_r(fid) if this_block_info is not None: current_block_info = this_block_info - fid.seek(current_block_info['block_size'], 1) + fid.seek(current_block_info["block_size"], 1) current_block += 1 # Start reading samples while samples_to_read > 0: - logger.debug(f' Reading from block {current_block}') + logger.debug(f" Reading from block {current_block}") this_block_info = _block_r(fid) current_block += 1 if this_block_info is not None: current_block_info = this_block_info - to_read = (current_block_info['nsamples'] * - current_block_info['nc']) + to_read = current_block_info["nsamples"] * current_block_info["nc"] block_data = np.fromfile(fid, dtype, to_read) - block_data = block_data.reshape(n_channels, -1, order='C') + block_data = block_data.reshape(n_channels, -1, order="C") # Compute indexes samples_read = block_data.shape[1] - logger.debug(f' Read {samples_read} samples') - logger.debug(f' Offset {offset_samples} samples') + logger.debug(f" Read {samples_read} samples") + logger.debug(f" Offset {offset_samples} samples") if offset_samples > 0: # First block read, skip to the offset: block_data = block_data[:, offset_samples:] @@ -689,7 +750,7 @@ def 
_read_segment_file(self, data, idx, fi, start, stop, cals, mult): # Last block to read, skip the last samples block_data = block_data[:, :samples_to_read] samples_read = samples_to_read - logger.debug(f' Keep {samples_read} samples') + logger.debug(f" Keep {samples_read} samples") s_start = current_data_sample s_end = s_start + samples_read @@ -700,19 +761,20 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): if len(pns_one) > 0: # PNS Data is present and should be read: - pns_filepath = egi_info['pns_filepath'] - pns_info = egi_info['pns_sample_blocks'] - n_channels = pns_info['n_channels'] - samples_block = pns_info['samples_block'] + pns_filepath = egi_info["pns_filepath"] + pns_info = egi_info["pns_sample_blocks"] + n_channels = pns_info["n_channels"] + samples_block = pns_info["samples_block"] # Get starting/stopping block/samples block_samples_offset = np.cumsum(samples_block) offset_blocks = np.sum(block_samples_offset < start) - offset_samples = start - (block_samples_offset[offset_blocks - 1] - if offset_blocks > 0 else 0) + offset_samples = start - ( + block_samples_offset[offset_blocks - 1] if offset_blocks > 0 else 0 + ) samples_to_read = stop - start - with open(pns_filepath, 'rb', buffering=0) as fid: + with open(pns_filepath, "rb", buffering=0) as fid: # Check file size fid.seek(0, 2) file_size = fid.tell() @@ -725,7 +787,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): this_block_info = _block_r(fid) if this_block_info is not None: current_block_info = this_block_info - fid.seek(current_block_info['block_size'], 1) + fid.seek(current_block_info["block_size"], 1) current_block += 1 # Start reading samples @@ -740,10 +802,9 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): if this_block_info is not None: current_block_info = this_block_info - to_read = (current_block_info['nsamples'] * - current_block_info['nc']) + to_read = current_block_info["nsamples"] * current_block_info["nc"] block_data = np.fromfile(fid, dtype, to_read) - block_data = block_data.reshape(n_channels, -1, order='C') + block_data = block_data.reshape(n_channels, -1, order="C") # Compute indexes samples_read = block_data.shape[1] @@ -761,8 +822,7 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): s_start = current_data_sample s_end = s_start + samples_read - one[pns_one, disk_use_idx[s_start:s_end]] = \ - block_data[pns_in] + one[pns_one, disk_use_idx[s_start:s_end]] = block_data[pns_in] samples_to_read = samples_to_read - samples_read current_data_sample = current_data_sample + samples_read @@ -771,8 +831,9 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): @verbose -def read_evokeds_mff(fname, condition=None, channel_naming='E%d', - baseline=None, verbose=None): +def read_evokeds_mff( + fname, condition=None, channel_naming="E%d", baseline=None, verbose=None +): """Read averaged MFF file as EvokedArray or list of EvokedArray. Parameters @@ -833,13 +894,17 @@ def read_evokeds_mff(fname, condition=None, channel_naming='E%d', flavor = mff.mff_flavor except AttributeError: # < 6.3 flavor = mff.flavor - if flavor not in ('averaged', 'segmented'): # old, new names - raise ValueError(f'{fname} is a {flavor} MFF file. ' - 'fname must be the path to an averaged MFF file.') + if flavor not in ("averaged", "segmented"): # old, new names + raise ValueError( + f"{fname} is a {flavor} MFF file. " + "fname must be the path to an averaged MFF file." 
+ ) # Check for categories.xml file - if 'categories.xml' not in mff.directory.listdir(): - raise ValueError('categories.xml not found in MFF directory. ' - f'{fname} may not be an averaged MFF file.') + if "categories.xml" not in mff.directory.listdir(): + raise ValueError( + "categories.xml not found in MFF directory. " + f"{fname} may not be an averaged MFF file." + ) return_list = True if condition is None: categories = mff.categories.categories @@ -847,135 +912,149 @@ def read_evokeds_mff(fname, condition=None, channel_naming='E%d', elif not isinstance(condition, list): condition = [condition] return_list = False - logger.info(f'Reading {len(condition)} evoked datasets from {fname} ...') - output = [_read_evoked_mff(fname, c, channel_naming=channel_naming, - verbose=verbose).apply_baseline(baseline) - for c in condition] + logger.info(f"Reading {len(condition)} evoked datasets from {fname} ...") + output = [ + _read_evoked_mff( + fname, c, channel_naming=channel_naming, verbose=verbose + ).apply_baseline(baseline) + for c in condition + ] return output if return_list else output[0] -def _read_evoked_mff(fname, condition, channel_naming='E%d', verbose=None): +def _read_evoked_mff(fname, condition, channel_naming="E%d", verbose=None): """Read evoked data from MFF file.""" import mffpy + egi_info = _read_header(fname) mff = mffpy.Reader(fname) categories = mff.categories.categories if isinstance(condition, str): # Condition is interpreted as category name - category = _check_option('condition', condition, categories, - extra='provided as category name') + category = _check_option( + "condition", condition, categories, extra="provided as category name" + ) epoch = mff.epochs[category] elif isinstance(condition, int): # Condition is interpreted as epoch index try: epoch = mff.epochs[condition] except IndexError: - raise ValueError(f'"condition" parameter ({condition}), provided ' - 'as epoch index, is out of range for available ' - f'epochs ({len(mff.epochs)}).') + raise ValueError( + f'"condition" parameter ({condition}), provided ' + "as epoch index, is out of range for available " + f"epochs ({len(mff.epochs)})." + ) category = epoch.name else: raise TypeError('"condition" parameter must be either int or str.') # Read in signals from the target epoch data = mff.get_physical_samples_from_epoch(epoch) - eeg_data, t0 = data['EEG'] - if 'PNSData' in data: - pns_data, t0 = data['PNSData'] + eeg_data, t0 = data["EEG"] + if "PNSData" in data: + pns_data, t0 = data["PNSData"] all_data = np.vstack((eeg_data, pns_data)) - ch_types = egi_info['chan_type'] + egi_info['pns_types'] + ch_types = egi_info["chan_type"] + egi_info["pns_types"] else: all_data = eeg_data - ch_types = egi_info['chan_type'] + ch_types = egi_info["chan_type"] all_data *= 1e-6 # convert to volts # Load metadata into info object # Exclude info['meas_date'] because record time info in # averaged MFF is the time of the averaging, not true record time. 
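Per the hunks above, read_evokeds_mff returns a list of EvokedArray objects when condition is None or a list, and a single EvokedArray for a scalar condition, applying the requested baseline to each. A hedged usage sketch (the path and category name are hypothetical, and mffpy must be installed):

from mne.io import read_evokeds_mff

fname = "subject01_averaged.mff"  # hypothetical averaged MFF directory

# every category in the file, baseline-corrected over the pre-stimulus window
evokeds = read_evokeds_mff(fname, baseline=(None, 0))

# one category, selected by name or by epoch index
by_name = read_evokeds_mff(fname, condition="Category 1")  # hypothetical name
by_index = read_evokeds_mff(fname, condition=0)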
ch_names, mon = _read_locs(fname, egi_info, channel_naming) - ch_names.extend(egi_info['pns_names']) - info = create_info(ch_names, mff.sampling_rates['EEG'], ch_types) + ch_names.extend(egi_info["pns_names"]) + info = create_info(ch_names, mff.sampling_rates["EEG"], ch_types) with info._unlock(): - info['device_info'] = dict(type=egi_info['device']) - info['nchan'] = sum(mff.num_channels.values()) + info["device_info"] = dict(type=egi_info["device"]) + info["nchan"] = sum(mff.num_channels.values()) # Add individual channel info # Get calibration info for EEG channels cals = _get_eeg_calibration_info(fname, egi_info) # Initialize calibration for PNS channels, will be updated later - cals = np.concatenate([cals, np.repeat(1, len(egi_info['pns_names']))]) + cals = np.concatenate([cals, np.repeat(1, len(egi_info["pns_names"]))]) ch_coil = FIFF.FIFFV_COIL_EEG ch_kind = FIFF.FIFFV_EEG_CH chs = _create_chs(ch_names, cals, ch_coil, ch_kind, (), (), (), ()) # Update PNS channel info chs = _add_pns_channel_info(chs, egi_info, ch_names) with info._unlock(): - info['chs'] = chs + info["chs"] = chs if mon is not None: - info.set_montage(mon, on_missing='ignore') + info.set_montage(mon, on_missing="ignore") # Add bad channels to info - info['description'] = category + info["description"] = category try: - channel_status = categories[category][0]['channelStatus'] + channel_status = categories[category][0]["channelStatus"] except KeyError: - warn(f'Channel status data not found for condition {category}. ' - 'No channels will be marked as bad.', category=UserWarning) + warn( + f"Channel status data not found for condition {category}. " + "No channels will be marked as bad.", + category=UserWarning, + ) channel_status = None bads = [] if channel_status: for entry in channel_status: - if entry['exclusion'] == 'badChannels': - if entry['signalBin'] == 1: + if entry["exclusion"] == "badChannels": + if entry["signalBin"] == 1: # Add bad EEG channels - for ch in entry['channels']: + for ch in entry["channels"]: bads.append(ch_names[ch - 1]) - elif entry['signalBin'] == 2: + elif entry["signalBin"] == 2: # Add bad PNS channels - for ch in entry['channels']: - bads.append(egi_info['pns_names'][ch - 1]) - info['bads'] = bads + for ch in entry["channels"]: + bads.append(egi_info["pns_names"][ch - 1]) + info["bads"] = bads # Add EEG reference to info # Initialize 'custom_ref_applied' to False with info._unlock(): - info['custom_ref_applied'] = False + info["custom_ref_applied"] = False try: - fp = mff.directory.filepointer('history') + fp = mff.directory.filepointer("history") except (ValueError, FileNotFoundError): # old (<=0.6.3) vs new mffpy pass else: with fp: history = mffpy.XML.from_file(fp) for entry in history.entries: - if entry['method'] == 'Montage Operations Tool': - if 'Average Reference' in entry['settings']: + if entry["method"] == "Montage Operations Tool": + if "Average Reference" in entry["settings"]: # Average reference has been applied projector, info = setup_proj(info) else: # Custom reference has been applied that is not an average - info['custom_ref_applied'] = True + info["custom_ref_applied"] = True # Get nave from categories.xml try: - nave = categories[category][0]['keys']['#seg']['data'] + nave = categories[category][0]["keys"]["#seg"]["data"] except KeyError: - warn(f'Number of averaged epochs not found for condition {category}. ' - 'nave will default to 1.', category=UserWarning) + warn( + f"Number of averaged epochs not found for condition {category}. 
" + "nave will default to 1.", + category=UserWarning, + ) nave = 1 # Let tmin default to 0 - return EvokedArray(all_data, info, tmin=0., comment=category, - nave=nave, verbose=verbose) + return EvokedArray( + all_data, info, tmin=0.0, comment=category, nave=nave, verbose=verbose + ) -def _import_mffpy(why='read averaged .mff files'): +def _import_mffpy(why="read averaged .mff files"): """Import and return module mffpy.""" try: import mffpy except ImportError as exp: - msg = f'mffpy is required to {why}, got:\n{exp}' + msg = f"mffpy is required to {why}, got:\n{exp}" raise ImportError(msg) return mffpy diff --git a/mne/io/egi/events.py b/mne/io/egi/events.py index 196a6ea717a..d9450913aa6 100644 --- a/mne/io/egi/events.py +++ b/mne/io/egi/events.py @@ -21,11 +21,11 @@ def _read_events(input_fname, info): info : dict Header info array. """ - n_samples = info['last_samps'][-1] - mff_events, event_codes = _read_mff_events(input_fname, info['sfreq']) - info['n_events'] = len(event_codes) - info['event_codes'] = event_codes - events = np.zeros([info['n_events'], info['n_segments'] * n_samples]) + n_samples = info["last_samps"][-1] + mff_events, event_codes = _read_mff_events(input_fname, info["sfreq"]) + info["n_events"] = len(event_codes) + info["event_codes"] = event_codes + events = np.zeros([info["n_events"], info["n_segments"] * n_samples]) for n, event in enumerate(event_codes): for i in mff_events[event]: if (i < 0) or (i >= events.shape[1]): @@ -45,34 +45,36 @@ def _read_mff_events(filename, sfreq): The sampling frequency """ orig = {} - for xml_file in glob(join(filename, '*.xml')): + for xml_file in glob(join(filename, "*.xml")): xml_type = splitext(basename(xml_file))[0] orig[xml_type] = _parse_xml(xml_file) xml_files = orig.keys() - xml_events = [x for x in xml_files if x[:7] == 'Events_'] - for item in orig['info']: - if 'recordTime' in item: - start_time = _ns2py_time(item['recordTime']) + xml_events = [x for x in xml_files if x[:7] == "Events_"] + for item in orig["info"]: + if "recordTime" in item: + start_time = _ns2py_time(item["recordTime"]) break markers = [] code = [] for xml in xml_events: for event in orig[xml][2:]: - event_start = _ns2py_time(event['beginTime']) + event_start = _ns2py_time(event["beginTime"]) start = (event_start - start_time).total_seconds() - if event['code'] not in code: - code.append(event['code']) - marker = {'name': event['code'], - 'start': start, - 'start_sample': int(np.fix(start * sfreq)), - 'end': start + float(event['duration']) / 1e9, - 'chan': None, - } + if event["code"] not in code: + code.append(event["code"]) + marker = { + "name": event["code"], + "start": start, + "start_sample": int(np.fix(start * sfreq)), + "end": start + float(event["duration"]) / 1e9, + "chan": None, + } markers.append(marker) events_tims = dict() for ev in code: - trig_samp = list(c['start_sample'] for n, - c in enumerate(markers) if c['name'] == ev) + trig_samp = list( + c["start_sample"] for n, c in enumerate(markers) if c["name"] == ev + ) events_tims.update({ev: trig_samp}) return events_tims, code @@ -88,7 +90,6 @@ def _xml2list(root): """Parse XML item.""" output = [] for element in root: - if len(element) > 0: if element[0].tag != element[-1].tag: output.append(_xml2dict(element)) @@ -106,8 +107,8 @@ def _xml2list(root): def _ns(s): """Remove namespace, but only if there is a namespace to begin with.""" - if '}' in s: - return '}'.join(s.split('}')[1:]) + if "}" in s: + return "}".join(s.split("}")[1:]) else: return s @@ -146,7 +147,7 @@ def 
_ns2py_time(nstime): nsdate = nstime[0:10] nstime0 = nstime[11:26] nstime00 = nsdate + " " + nstime0 - pytime = datetime.strptime(nstime00, '%Y-%m-%d %H:%M:%S.%f') + pytime = datetime.strptime(nstime00, "%Y-%m-%d %H:%M:%S.%f") return pytime @@ -154,8 +155,10 @@ def _combine_triggers(data, remapping=None): """Combine binary triggers.""" new_trigger = np.zeros(data.shape[1]) if data.astype(bool).sum(axis=0).max() > 1: # ensure no overlaps - logger.info(' Found multiple events at the same time ' - 'sample. Cannot create trigger channel.') + logger.info( + " Found multiple events at the same time " + "sample. Cannot create trigger channel." + ) return if remapping is None: remapping = np.arange(data) + 1 diff --git a/mne/io/egi/general.py b/mne/io/egi/general.py index c364e0eb9c7..6b8829ca4e6 100644 --- a/mne/io/egi/general.py +++ b/mne/io/egi/general.py @@ -17,7 +17,7 @@ def _extract(tags, filepath=None, obj=None): elif filepath is not None: fileobj = parse(filepath) else: - raise ValueError('There is not object or file to extract data') + raise ValueError("There is not object or file to extract data") infoxml = dict() for tag in tags: value = fileobj.getElementsByTagName(tag) @@ -30,38 +30,35 @@ def _extract(tags, filepath=None, obj=None): def _get_gains(filepath): """Parse gains.""" file_obj = parse(filepath) - objects = file_obj.getElementsByTagName('calibration') + objects = file_obj.getElementsByTagName("calibration") gains = dict() for ob in objects: - value = ob.getElementsByTagName('type') - if value[0].firstChild.data == 'GCAL': - data_g = _extract(['ch'], obj=ob)['ch'] + value = ob.getElementsByTagName("type") + if value[0].firstChild.data == "GCAL": + data_g = _extract(["ch"], obj=ob)["ch"] gains.update(gcal=np.asarray(data_g, dtype=np.float64)) - elif value[0].firstChild.data == 'ICAL': - data_g = _extract(['ch'], obj=ob)['ch'] + elif value[0].firstChild.data == "ICAL": + data_g = _extract(["ch"], obj=ob)["ch"] gains.update(ical=np.asarray(data_g, dtype=np.float64)) return gains def _get_ep_info(filepath): """Get epoch info.""" - epochfile = filepath + '/epochs.xml' + epochfile = filepath + "/epochs.xml" epochlist = parse(epochfile) - epochs = epochlist.getElementsByTagName('epoch') - keys = ('first_samps', 'last_samps', 'first_blocks', 'last_blocks') + epochs = epochlist.getElementsByTagName("epoch") + keys = ("first_samps", "last_samps", "first_blocks", "last_blocks") epoch_info = {key: list() for key in keys} for epoch in epochs: - ep_begin = int(epoch.getElementsByTagName('beginTime')[0] - .firstChild.data) - ep_end = int(epoch.getElementsByTagName('endTime')[0].firstChild.data) - first_block = int(epoch.getElementsByTagName('firstBlock')[0] - .firstChild.data) - last_block = int(epoch.getElementsByTagName('lastBlock')[0] - .firstChild.data) - epoch_info['first_samps'].append(ep_begin) - epoch_info['last_samps'].append(ep_end) - epoch_info['first_blocks'].append(first_block) - epoch_info['last_blocks'].append(last_block) + ep_begin = int(epoch.getElementsByTagName("beginTime")[0].firstChild.data) + ep_end = int(epoch.getElementsByTagName("endTime")[0].firstChild.data) + first_block = int(epoch.getElementsByTagName("firstBlock")[0].firstChild.data) + last_block = int(epoch.getElementsByTagName("lastBlock")[0].firstChild.data) + epoch_info["first_samps"].append(ep_begin) + epoch_info["last_samps"].append(ep_end) + epoch_info["first_blocks"].append(first_block) + epoch_info["last_blocks"].append(last_block) # Don't turn into ndarray here, keep native int because it can deal with 
# huge numbers (could use np.uint64 but it's more work) return epoch_info @@ -82,7 +79,7 @@ def _get_blocks(filepath): # * 1 byte of n_channels # * n_channels bytes of offsets # * n_channels bytes of sigfreqs? - with open(binfile, 'rb') as fid: + with open(binfile, "rb") as fid: fid.seek(0, 2) # go to end of file file_length = fid.tell() block_size = file_length @@ -96,79 +93,86 @@ def _get_blocks(filepath): fid.seek(block_size, 1) position = fid.tell() continue - block_size = block['block_size'] - header_size = block['header_size'] + block_size = block["block_size"] + header_size = block["header_size"] header_sizes.append(header_size) - samples_block.append(block['nsamples']) + samples_block.append(block["nsamples"]) n_blocks += 1 fid.seek(block_size, 1) - sfreq.append(block['sfreq']) - n_channels.append(block['nc']) + sfreq.append(block["sfreq"]) + n_channels.append(block["nc"]) position = fid.tell() if any([n != n_channels[0] for n in n_channels]): - raise RuntimeError("All the blocks don't have the same amount of " - "channels.") + raise RuntimeError("All the blocks don't have the same amount of " "channels.") if any([f != sfreq[0] for f in sfreq]): - raise RuntimeError("All the blocks don't have the same sampling " - "frequency.") + raise RuntimeError("All the blocks don't have the same sampling " "frequency.") if len(samples_block) < 1: raise RuntimeError("There seems to be no data") samples_block = np.array(samples_block) - signal_blocks = dict(n_channels=n_channels[0], sfreq=sfreq[0], - n_blocks=n_blocks, samples_block=samples_block, - header_sizes=header_sizes) + signal_blocks = dict( + n_channels=n_channels[0], + sfreq=sfreq[0], + n_blocks=n_blocks, + samples_block=samples_block, + header_sizes=header_sizes, + ) return signal_blocks def _get_signalfname(filepath): """Get filenames.""" listfiles = os.listdir(filepath) - binfiles = list(f for f in listfiles if 'signal' in f and - f[-4:] == '.bin' and f[0] != '.') + binfiles = list( + f for f in listfiles if "signal" in f and f[-4:] == ".bin" and f[0] != "." 
+ ) all_files = {} infofiles = list() for binfile in binfiles: - bin_num_str = re.search(r'\d+', binfile).group() - infofile = 'info' + bin_num_str + '.xml' + bin_num_str = re.search(r"\d+", binfile).group() + infofile = "info" + bin_num_str + ".xml" infofiles.append(infofile) infobjfile = os.path.join(filepath, infofile) infobj = parse(infobjfile) - if len(infobj.getElementsByTagName('EEG')): - signal_type = 'EEG' - elif len(infobj.getElementsByTagName('PNSData')): - signal_type = 'PNS' + if len(infobj.getElementsByTagName("EEG")): + signal_type = "EEG" + elif len(infobj.getElementsByTagName("PNSData")): + signal_type = "PNS" all_files[signal_type] = { - 'signal': 'signal{}.bin'.format(bin_num_str), - 'info': infofile} - if 'EEG' not in all_files: + "signal": "signal{}.bin".format(bin_num_str), + "info": infofile, + } + if "EEG" not in all_files: raise FileNotFoundError( - 'Could not find any EEG data in the %d file%s found in %s:\n%s' - % (len(infofiles), _pl(infofiles), filepath, '\n'.join(infofiles))) + "Could not find any EEG data in the %d file%s found in %s:\n%s" + % (len(infofiles), _pl(infofiles), filepath, "\n".join(infofiles)) + ) return all_files def _block_r(fid): """Read meta data.""" - if np.fromfile(fid, dtype=np.dtype('i4'), count=1).item() != 1: # not meta + if np.fromfile(fid, dtype=np.dtype("i4"), count=1).item() != 1: # not meta return None - header_size = np.fromfile(fid, dtype=np.dtype('i4'), count=1).item() - block_size = np.fromfile(fid, dtype=np.dtype('i4'), count=1).item() + header_size = np.fromfile(fid, dtype=np.dtype("i4"), count=1).item() + block_size = np.fromfile(fid, dtype=np.dtype("i4"), count=1).item() hl = int(block_size / 4) - nc = np.fromfile(fid, dtype=np.dtype('i4'), count=1).item() + nc = np.fromfile(fid, dtype=np.dtype("i4"), count=1).item() nsamples = int(hl / nc) - np.fromfile(fid, dtype=np.dtype('i4'), count=nc) # sigoffset - sigfreq = np.fromfile(fid, dtype=np.dtype('i4'), count=nc) + np.fromfile(fid, dtype=np.dtype("i4"), count=nc) # sigoffset + sigfreq = np.fromfile(fid, dtype=np.dtype("i4"), count=nc) depth = sigfreq[0] & 0xFF if depth != 32: - raise ValueError('I do not know how to read this MFF (depth != 32)') + raise ValueError("I do not know how to read this MFF (depth != 32)") sfreq = sigfreq[0] >> 8 count = int(header_size / 4 - (4 + 2 * nc)) - np.fromfile(fid, dtype=np.dtype('i4'), count=count) # sigoffset - block = dict(nc=nc, - hl=hl, - nsamples=nsamples, - block_size=block_size, - header_size=header_size, - sfreq=sfreq) + np.fromfile(fid, dtype=np.dtype("i4"), count=count) # sigoffset + block = dict( + nc=nc, + hl=hl, + nsamples=nsamples, + block_size=block_size, + header_size=header_size, + sfreq=sfreq, + ) return block diff --git a/mne/io/egi/tests/test_egi.py b/mne/io/egi/tests/test_egi.py index 45b0ca1109e..9086253cad5 100644 --- a/mne/io/egi/tests/test_egi.py +++ b/mne/io/egi/tests/test_egi.py @@ -35,17 +35,21 @@ egi_txt_evoked_cat2_fname = egi_path / "test_egi_evoked_cat2.txt" # absolute event times from NetStation -egi_pause_events = {'AM40': [7.224, 11.928, 14.413, 16.848], - 'bgin': [6.121, 8.434, 13.369, 15.815, 18.094], - 'FIX+': [6.225, 10.929, 13.414, 15.849], - 'ITI+': [8.293, 12.997, 15.482, 17.918]} +egi_pause_events = { + "AM40": [7.224, 11.928, 14.413, 16.848], + "bgin": [6.121, 8.434, 13.369, 15.815, 18.094], + "FIX+": [6.225, 10.929, 13.414, 15.849], + "ITI+": [8.293, 12.997, 15.482, 17.918], +} # absolute epoch times egi_pause_skips = [(1304000.0, 1772000.0), (8660000.0, 12296000.0)] -egi_eprime_pause_events = 
{'AM40': [6.049, 8.434, 10.936, 13.321], - 'bgin': [4.902, 7.381, 9.901, 12.268, 14.619], - 'FIX+': [5.050, 7.435, 9.937, 12.322], - 'ITI+': [7.185, 9.503, 12.005, 14.391]} +egi_eprime_pause_events = { + "AM40": [6.049, 8.434, 10.936, 13.321], + "bgin": [4.902, 7.381, 9.901, 12.268, 14.619], + "FIX+": [5.050, 7.435, 9.937, 12.322], + "ITI+": [7.185, 9.503, 12.005, 14.391], +} egi_eprime_pause_skips = [(1344000.0, 1804000.0)] egi_pause_w1337_events = None @@ -53,69 +57,79 @@ @requires_testing_data -@pytest.mark.parametrize('fname, skip_times, event_times', [ - (egi_pause_fname, egi_pause_skips, egi_pause_events), - (egi_eprime_pause_fname, egi_eprime_pause_skips, egi_eprime_pause_events), - (egi_pause_w1337_fname, egi_pause_w1337_skips, egi_pause_w1337_events), -]) +@pytest.mark.parametrize( + "fname, skip_times, event_times", + [ + (egi_pause_fname, egi_pause_skips, egi_pause_events), + (egi_eprime_pause_fname, egi_eprime_pause_skips, egi_eprime_pause_events), + (egi_pause_w1337_fname, egi_pause_w1337_skips, egi_pause_w1337_events), + ], +) def test_egi_mff_pause(fname, skip_times, event_times): """Test EGI MFF with pauses.""" if fname == egi_pause_w1337_fname: # too slow to _test_raw_reader raw = read_raw_egi(fname).load_data() else: - with pytest.warns(RuntimeWarning, match='Acquisition skips detected'): - raw = _test_raw_reader(read_raw_egi, input_fname=fname, - test_scaling=False, # XXX probably some bug - test_rank='less', - ) - assert raw.info['sfreq'] == 250. # true for all of these files + with pytest.warns(RuntimeWarning, match="Acquisition skips detected"): + raw = _test_raw_reader( + read_raw_egi, + input_fname=fname, + test_scaling=False, # XXX probably some bug + test_rank="less", + ) + assert raw.info["sfreq"] == 250.0 # true for all of these files assert len(raw.annotations) == len(skip_times) # assert event onsets match expected times if event_times is None: - with pytest.raises(ValueError, match='Consider using .*events_from'): + with pytest.raises(ValueError, match="Consider using .*events_from"): find_events(raw) else: events = find_events(raw) for event_type in event_times.keys(): - ns_samples = np.floor(np.array(event_times[event_type]) * - raw.info['sfreq']) + ns_samples = np.floor(np.array(event_times[event_type]) * raw.info["sfreq"]) assert_array_equal( - events[events[:, 2] == raw.event_id[event_type], 0], - ns_samples) + events[events[:, 2] == raw.event_id[event_type], 0], ns_samples + ) # read some data from the middle of the skip, assert it's all zeros stim_picks = pick_types(raw.info, meg=False, stim=True, exclude=()) other_picks = np.setdiff1d(np.arange(len(raw.ch_names)), stim_picks) for ii, annot in enumerate(raw.annotations): - assert annot['description'] == 'BAD_ACQ_SKIP' + assert annot["description"] == "BAD_ACQ_SKIP" start, stop = raw.time_as_index( - [annot['onset'], annot['onset'] + annot['duration']]) + [annot["onset"], annot["onset"] + annot["duration"]] + ) data, _ = raw[:, start:stop] - assert_array_equal(data[other_picks], 0.) + assert_array_equal(data[other_picks], 0.0) if event_times is not None: - assert raw.ch_names[-1] == 'STI 014' - assert not np.array_equal(data[stim_picks], 0.) 
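The pause tests here rely on the BAD_ACQ_SKIP annotations the reader adds for gaps between epochs; non-stim data inside a skip reads back as zeros. A short sketch of the same check on a hypothetical paused recording:

import mne

raw = mne.io.read_raw_egi("paused_recording.mff")  # hypothetical file with pauses

eeg_picks = mne.pick_types(raw.info, meg=False, eeg=True)  # stim may carry events
for annot in raw.annotations:
    if annot["description"] == "BAD_ACQ_SKIP":
        start, stop = raw.time_as_index(
            [annot["onset"], annot["onset"] + annot["duration"]]
        )
        data, _ = raw[eeg_picks, start:stop]
        print(f"{annot['duration']:.3f} s skip, EEG zeroed: {bool((data == 0).all())}")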
+ assert raw.ch_names[-1] == "STI 014" + assert not np.array_equal(data[stim_picks], 0.0) # assert skips match expected onset and duration - skip = ((start + 1) / raw.info['sfreq'] * 1e6, - (stop + 1) / raw.info['sfreq'] * 1e6) + skip = ( + (start + 1) / raw.info["sfreq"] * 1e6, + (stop + 1) / raw.info["sfreq"] * 1e6, + ) assert skip == skip_times[ii] @requires_testing_data -@pytest.mark.parametrize('fname', [ - egi_pause_fname, - egi_eprime_pause_fname, - egi_pause_w1337_fname, -]) +@pytest.mark.parametrize( + "fname", + [ + egi_pause_fname, + egi_eprime_pause_fname, + egi_pause_w1337_fname, + ], +) def test_egi_mff_pause_chunks(fname, tmp_path): """Test that on-demand of all short segments works (via I/O).""" - fname_temp = tmp_path / 'test_raw.fif' + fname_temp = tmp_path / "test_raw.fif" raw_data = read_raw_egi(fname, preload=True).get_data() raw = read_raw_egi(fname) - with pytest.warns(RuntimeWarning, match='Acquisition skips detected'): + with pytest.warns(RuntimeWarning, match="Acquisition skips detected"): raw.save(fname_temp) del raw raw_data_2 = read_raw_fif(fname_temp).get_data() @@ -131,49 +145,52 @@ def test_io_egi_mff(): n_card = 3 raw = read_raw_egi(egi_mff_fname, include=None) - assert ('RawMff' in repr(raw)) + assert "RawMff" in repr(raw) assert raw.orig_format == "single" - include = ['DIN1', 'DIN2', 'DIN3', 'DIN4', 'DIN5', 'DIN7'] - raw = _test_raw_reader(read_raw_egi, input_fname=egi_mff_fname, - include=include, channel_naming='EEG %03d', - test_scaling=False, # XXX probably some bug - ) - assert raw.info['sfreq'] == 1000. - assert len(raw.info['dig']) == n_card + n_eeg + n_ref - assert raw.info['dig'][0]['ident'] == FIFF.FIFFV_POINT_LPA - assert raw.info['dig'][0]['kind'] == FIFF.FIFFV_POINT_CARDINAL - assert raw.info['dig'][3]['kind'] == FIFF.FIFFV_POINT_EEG - assert raw.info['dig'][-1]['ident'] == 129 - assert raw.info['custom_ref_applied'] == FIFF.FIFFV_MNE_CUSTOM_REF_ON - ref_loc = raw.info['dig'][-1]['r'] + include = ["DIN1", "DIN2", "DIN3", "DIN4", "DIN5", "DIN7"] + raw = _test_raw_reader( + read_raw_egi, + input_fname=egi_mff_fname, + include=include, + channel_naming="EEG %03d", + test_scaling=False, # XXX probably some bug + ) + assert raw.info["sfreq"] == 1000.0 + assert len(raw.info["dig"]) == n_card + n_eeg + n_ref + assert raw.info["dig"][0]["ident"] == FIFF.FIFFV_POINT_LPA + assert raw.info["dig"][0]["kind"] == FIFF.FIFFV_POINT_CARDINAL + assert raw.info["dig"][3]["kind"] == FIFF.FIFFV_POINT_EEG + assert raw.info["dig"][-1]["ident"] == 129 + assert raw.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_ON + ref_loc = raw.info["dig"][-1]["r"] eeg_picks = pick_types(raw.info, eeg=True) assert len(eeg_picks) == n_eeg + n_ref # 129 # ref channel doesn't store its own loc as ref location # so don't test it - ref_pick = pick_channels(raw.info['ch_names'], ['VREF']) + ref_pick = pick_channels(raw.info["ch_names"], ["VREF"]) eeg_picks = np.setdiff1d(eeg_picks, ref_pick) for i in eeg_picks: - loc = raw.info['chs'][i]['loc'] + loc = raw.info["chs"][i]["loc"] assert loc[:3].any(), loc[:3] - assert_array_equal(loc[3:6], ref_loc, err_msg=f'{i}') - assert raw.info['device_info']['type'] == 'HydroCel GSN 128 1.0' + assert_array_equal(loc[3:6], ref_loc, err_msg=f"{i}") + assert raw.info["device_info"]["type"] == "HydroCel GSN 128 1.0" - assert 'eeg' in raw + assert "eeg" in raw # test our custom channel naming logic functionality - eeg_chan = [c for c in raw.ch_names if 'EEG' in c] + eeg_chan = [c for c in raw.ch_names if "EEG" in c] assert len(eeg_chan) == n_eeg 
# 128: VREF will not match in comprehension - assert 'STI 014' in raw.ch_names + assert "STI 014" in raw.ch_names - events = find_events(raw, stim_channel='STI 014') + events = find_events(raw, stim_channel="STI 014") assert len(events) == 8 assert np.unique(events[:, 1])[0] == 0 assert np.unique(events[:, 0])[0] != 0 assert np.unique(events[:, 2])[0] != 0 - with pytest.raises(ValueError, match='Could not find event'): - read_raw_egi(egi_mff_fname, include=['Foo']) - with pytest.raises(ValueError, match='Could not find event'): - read_raw_egi(egi_mff_fname, exclude=['Bar']) + with pytest.raises(ValueError, match="Could not find event"): + read_raw_egi(egi_mff_fname, include=["Foo"]) + with pytest.raises(ValueError, match="Could not find event"): + read_raw_egi(egi_mff_fname, exclude=["Bar"]) for ii, k in enumerate(include, 1): assert k in raw.event_id assert raw.event_id[k] == ii @@ -188,33 +205,36 @@ def test_io_egi(): data = data[1:] data *= 1e-6 # µV - with pytest.warns(RuntimeWarning, match='Did not find any event code'): + with pytest.warns(RuntimeWarning, match="Did not find any event code"): raw = read_raw_egi(egi_fname, include=None) # The reader should accept a Path, too. - with pytest.warns(RuntimeWarning, match='Did not find any event code'): + with pytest.warns(RuntimeWarning, match="Did not find any event code"): raw = read_raw_egi(Path(egi_fname), include=None) - assert 'RawEGI' in repr(raw) + assert "RawEGI" in repr(raw) data_read, t_read = raw[:256] assert_allclose(t_read, t) assert_allclose(data_read, data, atol=1e-10) - include = ['TRSP', 'XXX1'] - raw = _test_raw_reader(read_raw_egi, input_fname=egi_fname, - include=include, test_rank='less', - test_scaling=False, # XXX probably some bug - ) + include = ["TRSP", "XXX1"] + raw = _test_raw_reader( + read_raw_egi, + input_fname=egi_fname, + include=include, + test_rank="less", + test_scaling=False, # XXX probably some bug + ) - assert 'eeg' in raw + assert "eeg" in raw assert raw.orig_format == "single" - eeg_chan = [c for c in raw.ch_names if c.startswith('E')] + eeg_chan = [c for c in raw.ch_names if c.startswith("E")] assert len(eeg_chan) == 256 picks = pick_types(raw.info, eeg=True) assert len(picks) == 256 - assert 'STI 014' in raw.ch_names + assert "STI 014" in raw.ch_names - events = find_events(raw, stim_channel='STI 014') + events = find_events(raw, stim_channel="STI 014") assert len(events) == 2 # ground truth assert np.unique(events[:, 1])[0] == 0 assert np.unique(events[:, 0])[0] != 0 @@ -227,52 +247,54 @@ def test_io_egi(): new_trigger = _combine_triggers(triggers, events_ids) assert_array_equal(np.unique(new_trigger), np.unique([0, 12, 24])) - pytest.raises(ValueError, read_raw_egi, egi_fname, include=['Foo'], - preload=False) - pytest.raises(ValueError, read_raw_egi, egi_fname, exclude=['Bar'], - preload=False) + pytest.raises(ValueError, read_raw_egi, egi_fname, include=["Foo"], preload=False) + pytest.raises(ValueError, read_raw_egi, egi_fname, exclude=["Bar"], preload=False) for ii, k in enumerate(include, 1): - assert (k in raw.event_id) - assert (raw.event_id[k] == ii) + assert k in raw.event_id + assert raw.event_id[k] == ii @requires_testing_data def test_io_egi_pns_mff(tmp_path): """Test importing EGI MFF with PNS data.""" - raw = read_raw_egi(egi_mff_pns_fname, include=None, preload=True, - verbose='error') - assert ('RawMff' in repr(raw)) + raw = read_raw_egi(egi_mff_pns_fname, include=None, preload=True, verbose="error") + assert "RawMff" in repr(raw) pns_chans = pick_types(raw.info, ecg=True, 
bio=True, emg=True) assert len(pns_chans) == 7 names = [raw.ch_names[x] for x in pns_chans] - pns_names = ['Resp. Temperature', - 'Resp. Pressure', - 'ECG', - 'Body Position', - 'Resp. Effort Chest', - 'Resp. Effort Abdomen', - 'EMG-Leg'] - _test_raw_reader(read_raw_egi, input_fname=egi_mff_pns_fname, - channel_naming='EEG %03d', verbose='error', - test_rank='less', - test_scaling=False, # XXX probably some bug - ) + pns_names = [ + "Resp. Temperature", + "Resp. Pressure", + "ECG", + "Body Position", + "Resp. Effort Chest", + "Resp. Effort Abdomen", + "EMG-Leg", + ] + _test_raw_reader( + read_raw_egi, + input_fname=egi_mff_pns_fname, + channel_naming="EEG %03d", + verbose="error", + test_rank="less", + test_scaling=False, # XXX probably some bug + ) assert names == pns_names mat_names = [ - 'Resp_Temperature', - 'Resp_Pressure', - 'ECG', - 'Body_Position', - 'Resp_Effort_Chest', - 'Resp_Effort_Abdomen', - 'EMGLeg' + "Resp_Temperature", + "Resp_Pressure", + "ECG", + "Body_Position", + "Resp_Effort_Chest", + "Resp_Effort_Abdomen", + "EMGLeg", ] egi_fname_mat = testing_path / "EGI" / "test_egi_pns.mat" mc = sio.loadmat(egi_fname_mat) for ch_name, ch_idx, mat_name in zip(pns_names, pns_chans, mat_names): - print('Testing {}'.format(ch_name)) + print("Testing {}".format(ch_name)) mc_key = [x for x in mc.keys() if mat_name in x][0] - cal = raw.info['chs'][ch_idx]['cal'] + cal = raw.info["chs"][ch_idx]["cal"] mat_data = mc[mc_key] * cal raw_data = raw[ch_idx][0] assert_array_equal(mat_data, raw_data) @@ -280,48 +302,50 @@ def test_io_egi_pns_mff(tmp_path): # EEG missing new_mff = tmp_path / "temp.mff" shutil.copytree(egi_mff_pns_fname, new_mff) - read_raw_egi(new_mff, verbose='error') + read_raw_egi(new_mff, verbose="error") os.remove(new_mff / "info1.xml") os.remove(new_mff / "signal1.bin") - with pytest.raises(FileNotFoundError, match='Could not find any EEG'): - read_raw_egi(new_mff, verbose='error') + with pytest.raises(FileNotFoundError, match="Could not find any EEG"): + read_raw_egi(new_mff, verbose="error") @requires_testing_data -@pytest.mark.parametrize('preload', (True, False)) +@pytest.mark.parametrize("preload", (True, False)) def test_io_egi_pns_mff_bug(preload): """Test importing EGI MFF with PNS data (BUG).""" egi_fname_mff = testing_path / "EGI" / "test_egi_pns_bug.mff" - with pytest.warns(RuntimeWarning, match='EGI PSG sample bug'): - raw = read_raw_egi(egi_fname_mff, include=None, preload=preload, - verbose='warning') + with pytest.warns(RuntimeWarning, match="EGI PSG sample bug"): + raw = read_raw_egi( + egi_fname_mff, include=None, preload=preload, verbose="warning" + ) assert len(raw.annotations) == 1 assert_allclose(raw.annotations.duration, [0.004]) assert_allclose(raw.annotations.onset, [13.948]) egi_fname_mat = testing_path / "EGI" / "test_egi_pns.mat" mc = sio.loadmat(egi_fname_mat) pns_chans = pick_types(raw.info, ecg=True, bio=True, emg=True) - pns_names = ['Resp. Temperature'[:15], - 'Resp. Pressure', - 'ECG', - 'Body Position', - 'Resp. Effort Chest'[:15], - 'Resp. Effort Abdomen'[:15], - 'EMG-Leg'] + pns_names = [ + "Resp. Temperature"[:15], + "Resp. Pressure", + "ECG", + "Body Position", + "Resp. Effort Chest"[:15], + "Resp. 
Effort Abdomen"[:15], + "EMG-Leg", + ] mat_names = [ - 'Resp_Temperature'[:15], - 'Resp_Pressure', - 'ECG', - 'Body_Position', - 'Resp_Effort_Chest'[:15], - 'Resp_Effort_Abdomen'[:15], - 'EMGLeg' - + "Resp_Temperature"[:15], + "Resp_Pressure", + "ECG", + "Body_Position", + "Resp_Effort_Chest"[:15], + "Resp_Effort_Abdomen"[:15], + "EMGLeg", ] for ch_name, ch_idx, mat_name in zip(pns_names, pns_chans, mat_names): - print('Testing {}'.format(ch_name)) + print("Testing {}".format(ch_name)) mc_key = [x for x in mc.keys() if mat_name in x][0] - cal = raw.info['chs'][ch_idx]['cal'] + cal = raw.info["chs"][ch_idx]["cal"] mat_data = mc[mc_key] * cal mat_data[:, -1] = 0 # The MFF has one less sample, the last one raw_data = raw[ch_idx][0] @@ -340,15 +364,22 @@ def test_io_egi_crop_no_preload(): assert_allclose(raw._data, raw_preload._data) -@pytest.mark.filterwarnings('ignore::FutureWarning') -@requires_version('mffpy', '0.5.7') +@pytest.mark.filterwarnings("ignore::FutureWarning") +@requires_version("mffpy", "0.5.7") @requires_testing_data -@pytest.mark.parametrize('idx, cond, tmax, signals, bads', [ - (0, 'Category 1', 0.016, egi_txt_evoked_cat1_fname, - ['E8', 'E11', 'E17', 'E28', 'ECG']), - (1, 'Category 2', 0.0, egi_txt_evoked_cat2_fname, - ['VREF', 'EMG']) -]) +@pytest.mark.parametrize( + "idx, cond, tmax, signals, bads", + [ + ( + 0, + "Category 1", + 0.016, + egi_txt_evoked_cat1_fname, + ["E8", "E11", "E17", "E28", "ECG"], + ), + (1, "Category 2", 0.0, egi_txt_evoked_cat2_fname, ["VREF", "EMG"]), + ], +) def test_io_egi_evokeds_mff(idx, cond, tmax, signals, bads): """Test reading evoked MFF file.""" # expected n channels @@ -365,15 +396,19 @@ def test_io_egi_evokeds_mff(idx, cond, tmax, signals, bads): assert len(evokeds) == 2 # Test invalid condition with pytest.raises(ValueError) as exc_info: - read_evokeds_mff(egi_mff_evoked_fname, condition='Invalid Condition') - message = "Invalid value for the 'condition' parameter provided as " \ - "category name. Allowed values are 'Category 1' and " \ - "'Category 2', but got 'Invalid Condition' instead." + read_evokeds_mff(egi_mff_evoked_fname, condition="Invalid Condition") + message = ( + "Invalid value for the 'condition' parameter provided as " + "category name. Allowed values are 'Category 1' and " + "'Category 2', but got 'Invalid Condition' instead." + ) assert str(exc_info.value) == message with pytest.raises(ValueError) as exc_info: read_evokeds_mff(egi_mff_evoked_fname, condition=2) - message = '"condition" parameter (2), provided as epoch index, ' \ - 'is out of range for available epochs (2).' + message = ( + '"condition" parameter (2), provided as epoch index, ' + "is out of range for available epochs (2)." 
+ ) assert str(exc_info.value) == message with pytest.raises(TypeError) as exc_info: read_evokeds_mff(egi_mff_evoked_fname, condition=1.2) @@ -392,25 +427,25 @@ def test_io_egi_evokeds_mff(idx, cond, tmax, signals, bads): assert_allclose(evoked_cond.data, data, atol=1e-12) assert_allclose(evoked_idx.data, data, atol=1e-12) # Check info - assert object_diff(evoked_cond.info, evoked_idx.info) == '' - assert evoked_cond.info['description'] == cond - assert evoked_cond.info['bads'] == bads - assert len(evoked_cond.info['ch_names']) == n_eeg + n_ref + n_pns # 259 - assert 'ECG' in evoked_cond.info['ch_names'] - assert 'EMG' in evoked_cond.info['ch_names'] - assert 'ecg' in evoked_cond - assert 'emg' in evoked_cond + assert object_diff(evoked_cond.info, evoked_idx.info) == "" + assert evoked_cond.info["description"] == cond + assert evoked_cond.info["bads"] == bads + assert len(evoked_cond.info["ch_names"]) == n_eeg + n_ref + n_pns # 259 + assert "ECG" in evoked_cond.info["ch_names"] + assert "EMG" in evoked_cond.info["ch_names"] + assert "ecg" in evoked_cond + assert "emg" in evoked_cond pick_eeg = pick_types(evoked_cond.info, eeg=True, exclude=[]) assert len(pick_eeg) == n_eeg + n_ref # 257 - assert evoked_cond.info['nchan'] == n_eeg + n_ref + n_pns # 259 - assert evoked_cond.info['sfreq'] == 250.0 - assert not evoked_cond.info['custom_ref_applied'] - assert len(evoked_cond.info['dig']) == n_card + n_eeg + n_ref - assert evoked_cond.info['device_info']['type'] == 'HydroCel GSN 256 1.0' + assert evoked_cond.info["nchan"] == n_eeg + n_ref + n_pns # 259 + assert evoked_cond.info["sfreq"] == 250.0 + assert not evoked_cond.info["custom_ref_applied"] + assert len(evoked_cond.info["dig"]) == n_card + n_eeg + n_ref + assert evoked_cond.info["device_info"]["type"] == "HydroCel GSN 256 1.0" -@pytest.mark.filterwarnings('ignore::FutureWarning') -@requires_version('mffpy', '0.5.7') +@pytest.mark.filterwarnings("ignore::FutureWarning") +@requires_version("mffpy", "0.5.7") @requires_testing_data def test_read_evokeds_mff_bad_input(): """Test errors are thrown when reading invalid input file.""" @@ -422,8 +457,10 @@ def test_read_evokeds_mff_bad_input(): # Test continuous MFF with pytest.raises(ValueError) as exc_info: read_evokeds_mff(egi_mff_fname) - message = f'{egi_mff_fname} is a continuous MFF file. ' \ - 'fname must be the path to an averaged MFF file.' + message = ( + f"{egi_mff_fname} is a continuous MFF file. " + "fname must be the path to an averaged MFF file." 
+ ) assert str(exc_info.value) == message @@ -437,56 +474,62 @@ def test_egi_coord_frame(): FIFF.FIFFV_POINT_RPA, ) for ii, want in enumerate(want_idents): - d = info['dig'][ii] - assert d['kind'] == FIFF.FIFFV_POINT_CARDINAL - assert d['ident'] == want - loc = d['r'] + d = info["dig"][ii] + assert d["kind"] == FIFF.FIFFV_POINT_CARDINAL + assert d["ident"] == want + loc = d["r"] if ii == 0: - assert 0.05 < -loc[0] < 0.1, 'LPA' - assert_allclose(loc[1:], 0, atol=1e-7, err_msg='LPA') + assert 0.05 < -loc[0] < 0.1, "LPA" + assert_allclose(loc[1:], 0, atol=1e-7, err_msg="LPA") elif ii == 1: - assert 0.05 < loc[1] < 0.11, 'Nasion' - assert_allclose(loc[::2], 0, atol=1e-7, err_msg='Nasion') + assert 0.05 < loc[1] < 0.11, "Nasion" + assert_allclose(loc[::2], 0, atol=1e-7, err_msg="Nasion") else: assert ii == 2 - assert 0.05 < loc[0] < 0.1, 'RPA' - assert_allclose(loc[1:], 0, atol=1e-7, err_msg='RPA') - for d in info['dig'][3:]: - assert d['kind'] == FIFF.FIFFV_POINT_EEG + assert 0.05 < loc[0] < 0.1, "RPA" + assert_allclose(loc[1:], 0, atol=1e-7, err_msg="RPA") + for d in info["dig"][3:]: + assert d["kind"] == FIFF.FIFFV_POINT_EEG @requires_testing_data -@pytest.mark.parametrize('fname, timestamp, utc_offset', [ - (egi_mff_fname, '2017-02-23T11:35:13.220824+01:00', '+0100'), - (egi_mff_pns_fname, '2017-09-20T09:55:44.072000+01:00', '+0100'), - (egi_eprime_pause_fname, '2018-07-30T10:46:09.621673-04:00', '-0400'), - (egi_pause_w1337_fname, '2019-10-14T10:54:27.395210-07:00', '-0700'), -]) +@pytest.mark.parametrize( + "fname, timestamp, utc_offset", + [ + (egi_mff_fname, "2017-02-23T11:35:13.220824+01:00", "+0100"), + (egi_mff_pns_fname, "2017-09-20T09:55:44.072000+01:00", "+0100"), + (egi_eprime_pause_fname, "2018-07-30T10:46:09.621673-04:00", "-0400"), + (egi_pause_w1337_fname, "2019-10-14T10:54:27.395210-07:00", "-0700"), + ], +) def test_meas_date(fname, timestamp, utc_offset): """Test meas date conversion.""" - raw = read_raw_egi(fname, verbose='warning') - dt = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%f%z') + raw = read_raw_egi(fname, verbose="warning") + dt = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f%z") measdate = dt.astimezone(timezone.utc) - hour_local = int(dt.strftime('%H')) - hour_utc = int(raw.info['meas_date'].strftime('%H')) + hour_local = int(dt.strftime("%H")) + hour_utc = int(raw.info["meas_date"].strftime("%H")) local_utc_diff = hour_local - hour_utc - assert raw.info['meas_date'] == measdate - assert raw.info['utc_offset'] == utc_offset + assert raw.info["meas_date"] == measdate + assert raw.info["utc_offset"] == utc_offset assert local_utc_diff == int(utc_offset[:-2]) @requires_testing_data -@pytest.mark.parametrize('fname, standard_montage', [ - (egi_mff_fname, 'GSN-HydroCel-129'), # 129 chan EGI file - (egi_mff_pns_fname, 'GSN-HydroCel-257') # 257 chan EGI file -]) +@pytest.mark.parametrize( + "fname, standard_montage", + [ + (egi_mff_fname, "GSN-HydroCel-129"), # 129 chan EGI file + (egi_mff_pns_fname, "GSN-HydroCel-257"), # 257 chan EGI file + ], +) def test_set_standard_montage(fname, standard_montage): """Test setting a standard montage.""" - raw = read_raw_egi(fname, verbose='warning') - dig_before_mon = raw.info['dig'] + raw = read_raw_egi(fname, verbose="warning") + dig_before_mon = raw.info["dig"] - raw.set_montage(standard_montage, match_alias=True, on_missing='ignore') - dig_after_mon = raw.info['dig'] + raw.set_montage(standard_montage, match_alias=True, on_missing="ignore") + dig_after_mon = raw.info["dig"] # No dig entries should have been 
dropped while setting montage
     assert len(dig_before_mon) == len(dig_after_mon)

diff --git a/mne/io/eximia/eximia.py b/mne/io/eximia/eximia.py
index af6060f7709..dc3d9b445a5 100644
--- a/mne/io/eximia/eximia.py
+++ b/mne/io/eximia/eximia.py
@@ -53,41 +53,52 @@ class RawEximia(BaseRaw):
 
     @verbose
     def __init__(self, fname, preload=False, verbose=None):
-        fname = str(_check_fname(fname, 'read', True, 'fname'))
+        fname = str(_check_fname(fname, "read", True, "fname"))
         data_name = op.basename(fname)
-        logger.info('Loading %s' % data_name)
+        logger.info("Loading %s" % data_name)
         # Create vhdr and vmrk files so that we can use mne_brain_vision2fiff
         n_chan = 64
-        sfreq = 1450.
+        sfreq = 1450.0
         # data are multiplexed int16
-        ch_names = ['GateIn', 'Trig1', 'Trig2', 'EOG']
-        ch_types = ['stim', 'stim', 'stim', 'eog']
-        cals = [0.0015259021896696422, 0.0015259021896696422,
-                0.0015259021896696422, 0.3814755474174106]
-        ch_names += ('Fp1 Fpz Fp2 AF1 AFz AF2 '
-                     'F7 F3 F1 Fz F2 F4 F8 '
-                     'FT9 FT7 FC5 FC3 FC1 FCz FC2 FC4 FC6 FT8 FT10 '
-                     'T7 C5 C3 C1 Cz C2 C4 C6 T8 '
-                     'TP9 TP7 CP5 CP3 CP1 CPz CP2 CP4 CP6 TP8 TP10 '
-                     'P9 P7 P3 P1 Pz P2 P4 P8 '
-                     'P10 PO3 POz PO4 O1 Oz O2 Iz'.split())
+        ch_names = ["GateIn", "Trig1", "Trig2", "EOG"]
+        ch_types = ["stim", "stim", "stim", "eog"]
+        cals = [
+            0.0015259021896696422,
+            0.0015259021896696422,
+            0.0015259021896696422,
+            0.3814755474174106,
+        ]
+        ch_names += (
+            "Fp1 Fpz Fp2 AF1 AFz AF2 "
+            "F7 F3 F1 Fz F2 F4 F8 "
+            "FT9 FT7 FC5 FC3 FC1 FCz FC2 FC4 FC6 FT8 FT10 "
+            "T7 C5 C3 C1 Cz C2 C4 C6 T8 "
+            "TP9 TP7 CP5 CP3 CP1 CPz CP2 CP4 CP6 TP8 TP10 "
+            "P9 P7 P3 P1 Pz P2 P4 P8 "
+            "P10 PO3 POz PO4 O1 Oz O2 Iz".split()
+        )
         n_eeg = len(ch_names) - len(cals)
         cals += [0.07629510948348212] * n_eeg
-        ch_types += ['eeg'] * n_eeg
+        ch_types += ["eeg"] * n_eeg
         assert len(ch_names) == n_chan
         info = create_info(ch_names, sfreq, ch_types)
         n_bytes = _file_size(fname)
         n_samples, extra = divmod(n_bytes, (n_chan * 2))
         if extra != 0:
-            warn('Incorrect number of samples in file (%s), the file is '
-                 'likely truncated' % (n_samples,))
-        for ch, cal in zip(info['chs'], cals):
-            ch['cal'] = cal
+            warn(
+                "Incorrect number of samples in file (%s), the file is "
+                "likely truncated" % (n_samples,)
+            )
+        for ch, cal in zip(info["chs"], cals):
+            ch["cal"] = cal
         super(RawEximia, self).__init__(
-            info, preload=preload, last_samps=(n_samples - 1,),
-            filenames=[fname], orig_format='short')
+            info,
+            preload=preload,
+            last_samps=(n_samples - 1,),
+            filenames=[fname],
+            orig_format="short",
+        )
 
     def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
         """Read a chunk of raw data."""
-        _read_segments_file(
-            self, data, idx, fi, start, stop, cals, mult, dtype='<i2')
+        _read_segments_file(self, data, idx, fi, start, stop, cals, mult, dtype="<i2")
diff --git a/mne/io/eyelink/eyelink.py b/mne/io/eyelink/eyelink.py
--- a/mne/io/eyelink/eyelink.py
+++ b/mne/io/eyelink/eyelink.py
-        if len(self.dataframes['recording_blocks']) > 1:
+        if len(self.dataframes["recording_blocks"]) > 1:
             gap_annots = self._make_gap_annots()
         eye_annots = None
         if create_annotations:
-            eye_annots = self._make_eyelink_annots(self.dataframes,
-                                                   create_annotations,
-                                                   apply_offsets)
+            eye_annots = self._make_eyelink_annots(
+                self.dataframes, create_annotations, apply_offsets
+            )
         if gap_annots and eye_annots:  # set both
             self.set_annotations(gap_annots + eye_annots)
         elif gap_annots:
@@ -449,7 +484,7 @@ def __init__(self, fname, preload=False, verbose=None,
         elif eye_annots:
             self.set_annotations(eye_annots)
         else:
-            logger.info('Not creating any annotations')
+            logger.info("Not creating any annotations")
 
     def _parse_recording_blocks(self):
         """Parse Eyelink ASCII file.
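
For orientation, a minimal usage sketch (not part of the patch) of how the annotation plumbing reformatted above is reached through the public reader; the file path is hypothetical, and the keyword values mirror those exercised by the eyelink tests later in this diff:

import mne

# Read an EyeLink ASCII (.asc) export; create_annotations may be True or a
# list drawn from ["blinks", "saccades", "fixations", "messages"], and
# find_overlaps=True merges near-simultaneous left/right eye events.
raw = mne.io.read_raw_eyelink(
    "sub-01_eyetrack.asc",  # hypothetical path
    create_annotations=["blinks", "saccades", "fixations", "messages"],
    find_overlaps=True,
)
print(raw.annotations)  # blink/saccade/fixation (and block-gap) annotations
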
@@ -464,15 +499,24 @@ def _parse_recording_blocks(self): with self.fname.open() as file: block_num = 1 self._sample_lines = [] - self._event_lines = {'START': [], 'END': [], 'SAMPLES': [], - 'EVENTS': [], 'ESACC': [], 'EBLINK': [], - 'EFIX': [], 'MSG': [], 'INPUT': [], - 'BUTTON': [], 'PUPIL': []} + self._event_lines = { + "START": [], + "END": [], + "SAMPLES": [], + "EVENTS": [], + "ESACC": [], + "EBLINK": [], + "EFIX": [], + "MSG": [], + "INPUT": [], + "BUTTON": [], + "PUPIL": [], + } self._system_lines = [] is_recording_block = False for line in file: - if line.startswith('START'): # start of recording block + if line.startswith("START"): # start of recording block is_recording_block = True if is_recording_block: if _is_sys_msg(line): @@ -485,73 +529,84 @@ def _parse_recording_blocks(self): elif tokens[0] in self._event_lines.keys(): event_key, event_info = tokens[0], tokens[1:] self._event_lines[event_key].append(event_info) - if tokens[0] == 'END': # end of recording block + if tokens[0] == "END": # end of recording block is_recording_block = False block_num += 1 - if not self._event_lines['START']: - raise ValueError('Could not determine the start of the' - ' recording. When converting to ASCII, START' - ' events should not be suppressed.') + if not self._event_lines["START"]: + raise ValueError( + "Could not determine the start of the" + " recording. When converting to ASCII, START" + " events should not be suppressed." + ) if not self._sample_lines: # no samples parsed raise ValueError(f"Couldn't find any samples in {self.fname}") self._validate_data() def _validate_data(self): """Check the incoming data for some known problems that can occur.""" - self._rec_info = self._event_lines['SAMPLES'][0] - pupil_info = self._event_lines['PUPIL'][0] - n_blocks = len(self._event_lines['START']) + self._rec_info = self._event_lines["SAMPLES"][0] + pupil_info = self._event_lines["PUPIL"][0] + n_blocks = len(self._event_lines["START"]) sfreq = int(_get_sfreq(self._rec_info)) - first_samp = self._event_lines['START'][0][0] - if ('LEFT' in self._rec_info) and ('RIGHT' in self._rec_info): - self._tracking_mode = 'binocular' + first_samp = self._event_lines["START"][0][0] + if ("LEFT" in self._rec_info) and ("RIGHT" in self._rec_info): + self._tracking_mode = "binocular" else: - self._tracking_mode = 'monocular' + self._tracking_mode = "monocular" # Detect the datatypes that are in file. - if 'GAZE' in self._rec_info: - logger.info('Pixel coordinate data detected.') - logger.warning('Pass `scalings=dict(eyegaze=1e3)` when using plot' - ' method to make traces more legible.') - elif 'HREF' in self._rec_info: - logger.info('Head-referenced eye angle data detected.') - elif 'PUPIL' in self._rec_info: - logger.warning('Raw eyegaze coordinates detected. Analyze with' - ' caution.') - if 'AREA' in pupil_info: - logger.info('Pupil-size area reported.') - elif 'DIAMETER' in pupil_info: - logger.info('Pupil-size diameter reported.') + if "GAZE" in self._rec_info: + logger.info("Pixel coordinate data detected.") + logger.warning( + "Pass `scalings=dict(eyegaze=1e3)` when using plot" + " method to make traces more legible." + ) + elif "HREF" in self._rec_info: + logger.info("Head-referenced eye angle data detected.") + elif "PUPIL" in self._rec_info: + logger.warning("Raw eyegaze coordinates detected. Analyze with" " caution.") + if "AREA" in pupil_info: + logger.info("Pupil-size area reported.") + elif "DIAMETER" in pupil_info: + logger.info("Pupil-size diameter reported.") # Check sampling frequency. 
if sfreq == 2000 and isinstance(first_samp, int): - raise ValueError(f'The sampling rate is {sfreq}Hz but the' - ' timestamps were not output as float values.' - ' Check the settings in the EDF2ASC application.') + raise ValueError( + f"The sampling rate is {sfreq}Hz but the" + " timestamps were not output as float values." + " Check the settings in the EDF2ASC application." + ) elif sfreq != 2000 and isinstance(first_samp, float): - raise ValueError('For recordings with a sampling rate less than' - ' 2000Hz, timestamps should not be output to the' - ' ASCII file as float values. Check the' - ' settings in the EDF2ASC application. Got a' - f' sampling rate of {sfreq}Hz.') + raise ValueError( + "For recordings with a sampling rate less than" + " 2000Hz, timestamps should not be output to the" + " ASCII file as float values. Check the" + " settings in the EDF2ASC application. Got a" + f" sampling rate of {sfreq}Hz." + ) # If more than 1 recording period, make sure sfreq didn't change. if n_blocks > 1: - err_msg = 'The sampling frequency changed during the recording.'\ - ' This file cannot be read into MNE.' - for block_info in self._event_lines['SAMPLES'][1:]: + err_msg = ( + "The sampling frequency changed during the recording." + " This file cannot be read into MNE." + ) + for block_info in self._event_lines["SAMPLES"][1:]: block_sfreq = int(_get_sfreq(block_info)) if block_sfreq != sfreq: - raise ValueError(err_msg + - f' Got both {sfreq} and {block_sfreq} Hz.' - ) - if self._tracking_mode == 'monocular': - assert self._rec_info[1] in ['LEFT', 'RIGHT'] + raise ValueError( + err_msg + f" Got both {sfreq} and {block_sfreq} Hz." + ) + if self._tracking_mode == "monocular": + assert self._rec_info[1] in ["LEFT", "RIGHT"] eye = self._rec_info[1] - blocks_list = self._event_lines['SAMPLES'] + blocks_list = self._event_lines["SAMPLES"] eye_per_block = [block_info[1] for block_info in blocks_list] if not all([this_eye == eye for this_eye in eye_per_block]): - logger.warning('The eye being tracked changed during the' - ' recording. The channel names will reflect' - ' the eye that was tracked at the start of' - ' the recording.') + logger.warning( + "The eye being tracked changed during the" + " recording. The channel names will reflect" + " the eye that was tracked at the start of" + " the recording." + ) def _get_recording_datetime(self): """Create a datetime object from the datetime in ASCII file.""" @@ -561,11 +616,11 @@ def _get_recording_datetime(self): with self.fname.open() as file: for line in file: # header lines are at top of file and start with ** - if line.startswith('**'): + if line.startswith("**"): in_header = True if in_header: - if line.startswith('** DATE:'): - dt_str = line.replace('** DATE:', '').strip() + if line.startswith("** DATE:"): + dt_str = line.replace("** DATE:", "").strip() fmt = "%a %b %d %H:%M:%S %Y" try: # Eyelink measdate timestamps are timezone naive. @@ -575,9 +630,11 @@ def _get_recording_datetime(self): dt_aware = dt_naive.replace(tzinfo=tz) self._meas_date = dt_aware except Exception: - msg = ('Extraction of measurement date failed.' - ' Please report this as a github issue.' - ' The date is being set to None') + msg = ( + "Extraction of measurement date failed." + " Please report this as a github issue." 
+ " The date is being set to None" + ) logger.warning(msg) break @@ -613,66 +670,67 @@ def _infer_col_names(self): """ col_names = {} # initiate the column names for the sample lines - col_names['sample'] = list(EYELINK_COLS['timestamp']) + col_names["sample"] = list(EYELINK_COLS["timestamp"]) # and for the eye message lines - col_names['blink'] = list(EYELINK_COLS['eye_event']) - col_names['fixation'] = list(EYELINK_COLS['eye_event'] + - EYELINK_COLS['fixation']) - col_names['saccade'] = list(EYELINK_COLS['eye_event'] + - EYELINK_COLS['saccade']) + col_names["blink"] = list(EYELINK_COLS["eye_event"]) + col_names["fixation"] = list( + EYELINK_COLS["eye_event"] + EYELINK_COLS["fixation"] + ) + col_names["saccade"] = list(EYELINK_COLS["eye_event"] + EYELINK_COLS["saccade"]) # Recording was either binocular or monocular # If monocular, find out which eye was tracked and append to ch_name - if self._tracking_mode == 'monocular': - assert self._rec_info[1] in ['LEFT', 'RIGHT'] + if self._tracking_mode == "monocular": + assert self._rec_info[1] in ["LEFT", "RIGHT"] eye = self._rec_info[1].lower() - ch_names = list(EYELINK_COLS['pos'][eye]) - elif self._tracking_mode == 'binocular': - ch_names = list(EYELINK_COLS['pos']['left'] + - EYELINK_COLS['pos']['right']) - col_names['sample'].extend(ch_names) + ch_names = list(EYELINK_COLS["pos"][eye]) + elif self._tracking_mode == "binocular": + ch_names = list(EYELINK_COLS["pos"]["left"] + EYELINK_COLS["pos"]["right"]) + col_names["sample"].extend(ch_names) # The order of these if statements should not be changed. - if 'VEL' in self._rec_info: # If velocity data are reported - if self._tracking_mode == 'monocular': - ch_names.extend(EYELINK_COLS['velocity'][eye]) - col_names['sample'].extend(EYELINK_COLS['velocity'][eye]) - elif self._tracking_mode == 'binocular': - ch_names.extend(EYELINK_COLS['velocity']['left'] + - EYELINK_COLS['velocity']['right']) - col_names['sample'].extend(EYELINK_COLS['velocity']['left'] + - EYELINK_COLS['velocity']['right']) + if "VEL" in self._rec_info: # If velocity data are reported + if self._tracking_mode == "monocular": + ch_names.extend(EYELINK_COLS["velocity"][eye]) + col_names["sample"].extend(EYELINK_COLS["velocity"][eye]) + elif self._tracking_mode == "binocular": + ch_names.extend( + EYELINK_COLS["velocity"]["left"] + EYELINK_COLS["velocity"]["right"] + ) + col_names["sample"].extend( + EYELINK_COLS["velocity"]["left"] + EYELINK_COLS["velocity"]["right"] + ) # if resolution data are reported - if 'RES' in self._rec_info: - ch_names.extend(EYELINK_COLS['resolution']) - col_names['sample'].extend(EYELINK_COLS['resolution']) - col_names['fixation'].extend(EYELINK_COLS['resolution']) - col_names['saccade'].extend(EYELINK_COLS['resolution']) + if "RES" in self._rec_info: + ch_names.extend(EYELINK_COLS["resolution"]) + col_names["sample"].extend(EYELINK_COLS["resolution"]) + col_names["fixation"].extend(EYELINK_COLS["resolution"]) + col_names["saccade"].extend(EYELINK_COLS["resolution"]) # if digital input port values are reported - if 'INPUT' in self._rec_info: - ch_names.extend(EYELINK_COLS['input']) - col_names['sample'].extend(EYELINK_COLS['input']) + if "INPUT" in self._rec_info: + ch_names.extend(EYELINK_COLS["input"]) + col_names["sample"].extend(EYELINK_COLS["input"]) # add flags column - col_names['sample'].extend(EYELINK_COLS['flags']) + col_names["sample"].extend(EYELINK_COLS["flags"]) # if head target info was reported, add its cols after flags col. 
- if 'HTARGET' in self._rec_info: - ch_names.extend(EYELINK_COLS['remote']) - col_names['sample'].extend(EYELINK_COLS['remote'] - + EYELINK_COLS['remote_flags']) + if "HTARGET" in self._rec_info: + ch_names.extend(EYELINK_COLS["remote"]) + col_names["sample"].extend( + EYELINK_COLS["remote"] + EYELINK_COLS["remote_flags"] + ) # finally add a column for recording block number # FYI this column does not exist in the asc file.. # but it is added during _parse_recording_blocks for col in col_names.values(): - col.extend(EYELINK_COLS['block_num']) + col.extend(EYELINK_COLS["block_num"]) return col_names, ch_names - def _create_dataframes(self, col_names, sfreq, find_overlaps=False, - threshold=0.05): + def _create_dataframes(self, col_names, sfreq, find_overlaps=False, threshold=0.05): """Create pandas.DataFrame for Eyelink samples and events. Creates a pandas DataFrame for self._sample_lines and for each @@ -681,159 +739,166 @@ def _create_dataframes(self, col_names, sfreq, find_overlaps=False, pd = _check_pandas_installed() # First sample should be the first line of the first recording block - first_samp = self._event_lines['START'][0][0] + first_samp = self._event_lines["START"][0][0] # dataframe for samples - self.dataframes['samples'] = pd.DataFrame(self._sample_lines, - columns=col_names['sample']) - if 'HREF' in self._rec_info: - pos_names = (EYELINK_COLS['pos']['left'][:-1] - + EYELINK_COLS['pos']['right'][:-1]) - for col in self.dataframes['samples'].columns: + self.dataframes["samples"] = pd.DataFrame( + self._sample_lines, columns=col_names["sample"] + ) + if "HREF" in self._rec_info: + pos_names = ( + EYELINK_COLS["pos"]["left"][:-1] + EYELINK_COLS["pos"]["right"][:-1] + ) + for col in self.dataframes["samples"].columns: if col not in pos_names: # 'xpos_left' ... 'ypos_right' continue - series = self._href_to_radian(self.dataframes['samples'][col]) - self.dataframes['samples'][col] = series + series = self._href_to_radian(self.dataframes["samples"][col]) + self.dataframes["samples"][col] = series - n_block = len(self._event_lines['START']) + n_block = len(self._event_lines["START"]) if n_block > 1: - logger.info(f'There are {n_block} recording blocks in this' - ' file. Times between blocks will be annotated with' - f' {self._gap_desc}.') + logger.info( + f"There are {n_block} recording blocks in this" + " file. Times between blocks will be annotated with" + f" {self._gap_desc}." 
+ ) # if there is more than 1 recording block we must account for # the missing timestamps and samples bt the blocks - self.dataframes['samples'] = _fill_times(self.dataframes - ['samples'], - sfreq=sfreq) - _convert_times(self.dataframes['samples'], first_samp) + self.dataframes["samples"] = _fill_times( + self.dataframes["samples"], sfreq=sfreq + ) + _convert_times(self.dataframes["samples"], first_samp) # dataframe for each type of occular event - for event, columns, label in zip(['EFIX', 'ESACC', 'EBLINK'], - [col_names['fixation'], - col_names['saccade'], - col_names['blink']], - ['fixations', - 'saccades', - 'blinks'] - ): + for event, columns, label in zip( + ["EFIX", "ESACC", "EBLINK"], + [col_names["fixation"], col_names["saccade"], col_names["blink"]], + ["fixations", "saccades", "blinks"], + ): if self._event_lines[event]: # an empty list returns False - self.dataframes[label] = pd.DataFrame(self._event_lines[event], - columns=columns) + self.dataframes[label] = pd.DataFrame( + self._event_lines[event], columns=columns + ) _convert_times(self.dataframes[label], first_samp) if find_overlaps is True: - if self._tracking_mode == 'monocular': - raise ValueError('find_overlaps is only valid with' - ' binocular recordings, this file is' - f' {self._tracking_mode}') - df = _find_overlaps(self.dataframes[label], - max_time=threshold) + if self._tracking_mode == "monocular": + raise ValueError( + "find_overlaps is only valid with" + " binocular recordings, this file is" + f" {self._tracking_mode}" + ) + df = _find_overlaps(self.dataframes[label], max_time=threshold) self.dataframes[label] = df else: - logger.info(f'No {label} were found in this file. ' - f'Not returning any info on {label}.') + logger.info( + f"No {label} were found in this file. " + f"Not returning any info on {label}." 
+ ) # make dataframe for experiment messages - if self._event_lines['MSG']: + if self._event_lines["MSG"]: msgs = [] - for tokens in self._event_lines['MSG']: + for tokens in self._event_lines["MSG"]: timestamp = tokens[0] block = tokens[-1] # if offset token exists, it will be the 1st index # and is an int or float if isinstance(tokens[1], (int, float)): offset = tokens[1] - msg = ' '.join(str(x) for x in tokens[2:-1]) + msg = " ".join(str(x) for x in tokens[2:-1]) else: # there is no offset token offset = np.nan - msg = ' '.join(str(x) for x in tokens[1:-1]) + msg = " ".join(str(x) for x in tokens[1:-1]) msgs.append([timestamp, offset, msg, block]) - cols = ['time', 'offset', 'event_msg', 'block'] - self.dataframes['messages'] = (pd.DataFrame(msgs, - columns=cols)) - _convert_times(self.dataframes['messages'], first_samp) + cols = ["time", "offset", "event_msg", "block"] + self.dataframes["messages"] = pd.DataFrame(msgs, columns=cols) + _convert_times(self.dataframes["messages"], first_samp) # make dataframe for recording block start, end times - assert (len(self._event_lines['START']) - == len(self._event_lines['END']) - ) - blocks = [[bgn[0], end[0], bgn[-1]] # start, end, block_num - for bgn, end in zip(self._event_lines['START'], - self._event_lines['END']) - ] - cols = ['time', 'end_time', 'block'] - self.dataframes['recording_blocks'] = pd.DataFrame(blocks, - columns=cols) - _convert_times(self.dataframes['recording_blocks'], first_samp) + assert len(self._event_lines["START"]) == len(self._event_lines["END"]) + blocks = [ + [bgn[0], end[0], bgn[-1]] # start, end, block_num + for bgn, end in zip(self._event_lines["START"], self._event_lines["END"]) + ] + cols = ["time", "end_time", "block"] + self.dataframes["recording_blocks"] = pd.DataFrame(blocks, columns=cols) + _convert_times(self.dataframes["recording_blocks"], first_samp) # make dataframe for digital input port - if self._event_lines['INPUT']: - cols = ['time', 'DIN', 'block'] - self.dataframes['DINS'] = pd.DataFrame(self._event_lines['INPUT'], - columns=cols) - _convert_times(self.dataframes['DINS'], first_samp) + if self._event_lines["INPUT"]: + cols = ["time", "DIN", "block"] + self.dataframes["DINS"] = pd.DataFrame( + self._event_lines["INPUT"], columns=cols + ) + _convert_times(self.dataframes["DINS"], first_samp) # TODO: Make dataframes for other eyelink events (Buttons) def _create_info(self, ch_names, sfreq): """Create info object for RawEyelink.""" # assign channel type from ch_name - pos_names = (EYELINK_COLS['pos']['left'][:-1] - + EYELINK_COLS['pos']['right'][:-1]) - pupil_names = (EYELINK_COLS['pos']['left'][-1] - + EYELINK_COLS['pos']['right'][-1]) - ch_types = ['eyegaze' if ch in pos_names - else 'pupil' if ch in pupil_names - else 'stim' if ch == 'DIN' - else 'misc' - for ch in ch_names] - info = create_info(ch_names, - sfreq, - ch_types) + pos_names = EYELINK_COLS["pos"]["left"][:-1] + EYELINK_COLS["pos"]["right"][:-1] + pupil_names = EYELINK_COLS["pos"]["left"][-1] + EYELINK_COLS["pos"]["right"][-1] + ch_types = [ + "eyegaze" + if ch in pos_names + else "pupil" + if ch in pupil_names + else "stim" + if ch == "DIN" + else "misc" + for ch in ch_names + ] + info = create_info(ch_names, sfreq, ch_types) # set correct loc for eyepos and pupil channels - for ch_dict in info['chs']: + for ch_dict in info["chs"]: # loc index 3 can indicate left or right eye - if ch_dict['ch_name'].endswith('left'): # [x,y,pupil]_left - ch_dict['loc'][3] = -1 # left eye - elif ch_dict['ch_name'].endswith('right'): # 
[x,y,pupil]_right - ch_dict['loc'][3] = 1 # right eye + if ch_dict["ch_name"].endswith("left"): # [x,y,pupil]_left + ch_dict["loc"][3] = -1 # left eye + elif ch_dict["ch_name"].endswith("right"): # [x,y,pupil]_right + ch_dict["loc"][3] = 1 # right eye else: - logger.debug(f"leaving index 3 of loc array as" - f" {ch_dict['loc'][3]} for {ch_dict['ch_name']}") + logger.debug( + f"leaving index 3 of loc array as" + f" {ch_dict['loc'][3]} for {ch_dict['ch_name']}" + ) # loc index 4 can indicate x/y coord - if ch_dict['ch_name'].startswith('x'): - ch_dict['loc'][4] = -1 # x-coord - elif ch_dict['ch_name'].startswith('y'): - ch_dict['loc'][4] = 1 # y-coord + if ch_dict["ch_name"].startswith("x"): + ch_dict["loc"][4] = -1 # x-coord + elif ch_dict["ch_name"].startswith("y"): + ch_dict["loc"][4] = 1 # y-coord else: - logger.debug(f"leaving index 4 of loc array as" - f" {ch_dict['loc'][4]} for {ch_dict['ch_name']}") - if 'HREF' in self._rec_info: - if ch_dict['ch_name'].startswith(('xpos', 'ypos')): - ch_dict['unit'] = FIFF.FIFF_UNIT_RAD + logger.debug( + f"leaving index 4 of loc array as" + f" {ch_dict['loc'][4]} for {ch_dict['ch_name']}" + ) + if "HREF" in self._rec_info: + if ch_dict["ch_name"].startswith(("xpos", "ypos")): + ch_dict["unit"] = FIFF.FIFF_UNIT_RAD return info - def _make_gap_annots(self, key='recording_blocks'): + def _make_gap_annots(self, key="recording_blocks"): """Create Annotations for gap periods between recording blocks.""" df = self.dataframes[key] gap_desc = self._gap_desc - onsets = df['end_time'].iloc[:-1] - diffs = df['time'].shift(-1) - df['end_time'] + onsets = df["end_time"].iloc[:-1] + diffs = df["time"].shift(-1) - df["end_time"] durations = diffs.iloc[:-1] descriptions = [gap_desc] * len(onsets) - return Annotations(onset=onsets, - duration=durations, - description=descriptions) + return Annotations(onset=onsets, duration=durations, description=descriptions) def _make_eyelink_annots(self, df_dict, create_annots, apply_offsets): """Create Annotations for each df in self.dataframes.""" - valid_descs = ['blinks', 'saccades', 'fixations', 'messages'] - msg = ("create_annotations must be True or a list containing one or" - f" more of {valid_descs}.") - wrong_type = (msg + f' Got a {type(create_annots)} instead.') + valid_descs = ["blinks", "saccades", "fixations", "messages"] + msg = ( + "create_annotations must be True or a list containing one or" + f" more of {valid_descs}." + ) + wrong_type = msg + f" Got a {type(create_annots)} instead." 
if create_annots is True: descs = valid_descs else: @@ -844,31 +909,34 @@ def _make_eyelink_annots(self, df_dict, create_annots, apply_offsets): annots = None for key, df in df_dict.items(): - eye_annot_cond = ((key in ['blinks', 'fixations', 'saccades']) - and (key in descs)) + eye_annot_cond = (key in ["blinks", "fixations", "saccades"]) and ( + key in descs + ) if eye_annot_cond: - onsets = df['time'] - durations = df['duration'] + onsets = df["time"] + durations = df["duration"] # Create annotations for both eyes - descriptions = f'{key[:-1]}_' + df['eye'] # i.e "blink_r" - this_annot = Annotations(onset=onsets, - duration=durations, - description=descriptions) - elif (key in ['messages']) and (key in descs): + descriptions = f"{key[:-1]}_" + df["eye"] # i.e "blink_r" + this_annot = Annotations( + onset=onsets, duration=durations, description=descriptions + ) + elif (key in ["messages"]) and (key in descs): if apply_offsets: - if df['offset'].isnull().all(): - logger.warning('There are no offsets for the messages' - f' in {self.fname}. Not applying any' - ' offset') + if df["offset"].isnull().all(): + logger.warning( + "There are no offsets for the messages" + f" in {self.fname}. Not applying any" + " offset" + ) # If df['offset] is all NaNs, time is not changed - onsets = df['time'] + df['offset'].fillna(0) + onsets = df["time"] + df["offset"].fillna(0) else: - onsets = df['time'] + onsets = df["time"] durations = [0] * onsets - descriptions = df['event_msg'] - this_annot = Annotations(onset=onsets, - duration=durations, - description=descriptions) + descriptions = df["event_msg"] + this_annot = Annotations( + onset=onsets, duration=durations, description=descriptions + ) else: continue # TODO make df and annotations for Buttons if not annots: @@ -876,7 +944,8 @@ def _make_eyelink_annots(self, df_dict, create_annots, apply_offsets): elif annots: annots += this_annot if not annots: - logger.warning(f'Annotations for {descs} were requested but' - ' none could be made.') + logger.warning( + f"Annotations for {descs} were requested but" " none could be made." + ) return return annots diff --git a/mne/io/eyelink/tests/test_eyelink.py b/mne/io/eyelink/tests/test_eyelink.py index 0aa5e4d4e0b..51d64ea5ed5 100644 --- a/mne/io/eyelink/tests/test_eyelink.py +++ b/mne/io/eyelink/tests/test_eyelink.py @@ -9,87 +9,94 @@ from mne.utils import _check_pandas_installed, requires_pandas testing_path = data_path(download=False) -fname = testing_path / 'eyetrack' / 'test_eyelink.asc' -fname_href = testing_path / 'eyetrack' / 'test_eyelink_HREF.asc' +fname = testing_path / "eyetrack" / "test_eyelink.asc" +fname_href = testing_path / "eyetrack" / "test_eyelink_HREF.asc" def test_eyetrack_not_data_ch(): """Eyetrack channels are not data channels.""" - msg = 'eyetrack channels are not data channels. Refer to MNE definition'\ - ' of data channels in the glossary section of the documentation.' - assert 'eyegaze' not in _DATA_CH_TYPES_SPLIT, msg - assert 'pupil' not in _DATA_CH_TYPES_SPLIT, msg + msg = ( + "eyetrack channels are not data channels. Refer to MNE definition" + " of data channels in the glossary section of the documentation." 
+ ) + assert "eyegaze" not in _DATA_CH_TYPES_SPLIT, msg + assert "pupil" not in _DATA_CH_TYPES_SPLIT, msg @requires_testing_data @requires_pandas -@pytest.mark.parametrize('fname, create_annotations, find_overlaps', - [(fname, False, False), - (fname, True, False), - (fname, True, True), - (fname, ['fixations', 'saccades', 'blinks'], True)]) +@pytest.mark.parametrize( + "fname, create_annotations, find_overlaps", + [ + (fname, False, False), + (fname, True, False), + (fname, True, True), + (fname, ["fixations", "saccades", "blinks"], True), + ], +) def test_eyelink(fname, create_annotations, find_overlaps): """Test reading eyelink asc files.""" - raw = read_raw_eyelink(fname, create_annotations=create_annotations, - find_overlaps=find_overlaps) + raw = read_raw_eyelink( + fname, create_annotations=create_annotations, find_overlaps=find_overlaps + ) # First, tests that shouldn't change based on function arguments - assert raw.info['sfreq'] == 500 # True for this file - assert raw.info['meas_date'].month == 3 - assert raw.info['meas_date'].day == 10 - assert raw.info['meas_date'].year == 2022 + assert raw.info["sfreq"] == 500 # True for this file + assert raw.info["meas_date"].month == 3 + assert raw.info["meas_date"].day == 10 + assert raw.info["meas_date"].year == 2022 - assert len(raw.info['ch_names']) == 6 - assert raw.info['chs'][0]['kind'] == FIFF.FIFFV_EYETRACK_CH - assert raw.info['chs'][0]['coil_type'] == FIFF.FIFFV_COIL_EYETRACK_POS - raw.info['chs'][2]['coil_type'] == FIFF.FIFFV_COIL_EYETRACK_PUPIL + assert len(raw.info["ch_names"]) == 6 + assert raw.info["chs"][0]["kind"] == FIFF.FIFFV_EYETRACK_CH + assert raw.info["chs"][0]["coil_type"] == FIFF.FIFFV_COIL_EYETRACK_POS + raw.info["chs"][2]["coil_type"] == FIFF.FIFFV_COIL_EYETRACK_PUPIL # x_left - assert all(raw.info['chs'][0]['loc'][3:5] == [-1, -1]) + assert all(raw.info["chs"][0]["loc"][3:5] == [-1, -1]) # pupil_left - assert raw.info['chs'][2]['loc'][3] == -1 - assert np.isnan(raw.info['chs'][2]['loc'][4]) + assert raw.info["chs"][2]["loc"][3] == -1 + assert np.isnan(raw.info["chs"][2]["loc"][4]) # y_right - assert all(raw.info['chs'][4]['loc'][3:5] == [1, 1]) - assert 'RawEyelink' in repr(raw) + assert all(raw.info["chs"][4]["loc"][3:5] == [1, 1]) + assert "RawEyelink" in repr(raw) # Test some annotation values for accuracy. if create_annotations is True and find_overlaps: - orig = raw.info['meas_date'] + orig = raw.info["meas_date"] df = raw.annotations.to_data_frame() # Convert annot onset datetimes to seconds, relative to orig_time - df['time_in_sec'] = df['onset'].apply(lambda x: x.timestamp() - - orig.timestamp()) + df["time_in_sec"] = df["onset"].apply( + lambda x: x.timestamp() - orig.timestamp() + ) # There is a blink in this data at 8.9 seconds - cond = (df['time_in_sec'] > 8.899) & (df['time_in_sec'] < 8.95) - assert df[cond]['description'].values[0].startswith('blink') + cond = (df["time_in_sec"] > 8.899) & (df["time_in_sec"] < 8.95) + assert df[cond]["description"].values[0].startswith("blink") if find_overlaps is True: df = raw.annotations.to_data_frame() # these should both be True so long as _find_overlaps is not # majorly refactored. 
- assert 'blink_L' in df['description'].unique() - assert 'blink_both' in df['description'].unique() + assert "blink_L" in df["description"].unique() + assert "blink_both" in df["description"].unique() if isinstance(create_annotations, list) and find_overlaps: # the last pytest parametrize condition should hit this df = raw.annotations.to_data_frame() # Rows 0, 1, 2 should be 'fixation_both', 'saccade_both', 'blink_both' - for i, label in zip([0, 1, 2], ['fixation', 'saccade', 'blink']): - assert df['description'].iloc[i] == f'{label}_both' + for i, label in zip([0, 1, 2], ["fixation", "saccade", "blink"]): + assert df["description"].iloc[i] == f"{label}_both" @requires_testing_data @requires_pandas -@pytest.mark.parametrize('fname_href', - [(fname_href)]) +@pytest.mark.parametrize("fname_href", [(fname_href)]) def test_radian(fname_href): """Test converting HREF position data to radians.""" - raw = read_raw_eyelink(fname_href, create_annotations=['blinks']) + raw = read_raw_eyelink(fname_href, create_annotations=["blinks"]) # Test channel types - assert raw.get_channel_types() == ['eyegaze', 'eyegaze', 'pupil'] + assert raw.get_channel_types() == ["eyegaze", "eyegaze", "pupil"] # Test that eyegaze channels have a radian unit - assert raw.info['chs'][0]['unit'] == FIFF.FIFF_UNIT_RAD - assert raw.info['chs'][1]['unit'] == FIFF.FIFF_UNIT_RAD + assert raw.info["chs"][0]["unit"] == FIFF.FIFF_UNIT_RAD + assert raw.info["chs"][1]["unit"] == FIFF.FIFF_UNIT_RAD # Data in radians should range between -1 and 1 # Test first channel (xpos_right) @@ -99,7 +106,7 @@ def test_radian(fname_href): @requires_testing_data @requires_pandas -@pytest.mark.parametrize('fname', [(fname)]) +@pytest.mark.parametrize("fname", [(fname)]) def test_fill_times(fname): """Test use of pd.merge_asof in _fill_times. @@ -112,17 +119,17 @@ def test_fill_times(fname): from ..eyelink import _fill_times raw = read_raw_eyelink(fname, create_annotations=False) - sfreq = raw.info['sfreq'] + sfreq = raw.info["sfreq"] # just take first 1000 points for testing - df = raw.dataframes['samples'].iloc[:1000].reset_index(drop=True) + df = raw.dataframes["samples"].iloc[:1000].reset_index(drop=True) # even during blinks, pupil val is 0, so there should be no nans # in this column - assert not df['pupil_left'].isna().sum() - nan_count = df['pupil_left'].isna().sum() # i.e 0 + assert not df["pupil_left"].isna().sum() + nan_count = df["pupil_left"].isna().sum() # i.e 0 df_merged = _fill_times(df, sfreq) # If times dont merge correctly, there will be additional rows in # in df_merged with all nan values - assert df_merged['pupil_left'].isna().sum() == nan_count # i.e. 0 + assert df_merged["pupil_left"].isna().sum() == nan_count # i.e. 0 @requires_pandas @@ -137,11 +144,16 @@ def test_find_overlaps(): overlap because they are both left eye events. 
""" from ..eyelink import _find_overlaps + pd = _check_pandas_installed() - blink_df = pd.DataFrame({'eye': ['L', 'R', 'L', 'R', 'L', 'L'], - 'time': [.01, .04, 4.14, 4.20, 6.50, 6.504], - 'end_time': [.05, .08, 4.18, 4.22, 6.60, 6.604]}) + blink_df = pd.DataFrame( + { + "eye": ["L", "R", "L", "R", "L", "L"], + "time": [0.01, 0.04, 4.14, 4.20, 6.50, 6.504], + "end_time": [0.05, 0.08, 4.18, 4.22, 6.60, 6.604], + } + ) overlap_df = _find_overlaps(blink_df) - assert len(overlap_df['eye'].unique()) == 3 # ['both', 'left', 'right'] + assert len(overlap_df["eye"].unique()) == 3 # ['both', 'left', 'right'] assert len(overlap_df) == 5 # ['both', 'L', 'R', 'L', 'L'] - assert overlap_df['eye'].iloc[0] == 'both' + assert overlap_df["eye"].iloc[0] == "both" diff --git a/mne/io/fieldtrip/__init__.py b/mne/io/fieldtrip/__init__.py index 2085c931925..2330e7222c9 100644 --- a/mne/io/fieldtrip/__init__.py +++ b/mne/io/fieldtrip/__init__.py @@ -4,5 +4,4 @@ # # License: BSD-3-Clause -from .fieldtrip import (read_evoked_fieldtrip, read_epochs_fieldtrip, - read_raw_fieldtrip) +from .fieldtrip import read_evoked_fieldtrip, read_epochs_fieldtrip, read_raw_fieldtrip diff --git a/mne/io/fieldtrip/fieldtrip.py b/mne/io/fieldtrip/fieldtrip.py index 3c7cfb3394c..872cf21005a 100644 --- a/mne/io/fieldtrip/fieldtrip.py +++ b/mne/io/fieldtrip/fieldtrip.py @@ -6,15 +6,20 @@ import numpy as np -from .utils import _create_info, _set_tmin, _create_events, \ - _create_event_metadata, _validate_ft_struct +from .utils import ( + _create_info, + _set_tmin, + _create_events, + _create_event_metadata, + _validate_ft_struct, +) from ...utils import _check_fname, _import_pymatreader_funcs from ..array.array import RawArray from ...epochs import EpochsArray from ...evoked import EvokedArray -def read_raw_fieldtrip(fname, info, data_name='data'): +def read_raw_fieldtrip(fname, info, data_name="data"): """Load continuous (raw) data from a FieldTrip preprocessing structure. This function expects to find single trial raw data (FT_DATATYPE_RAW) in @@ -49,12 +54,10 @@ def read_raw_fieldtrip(fname, info, data_name='data'): -------- mne.io.Raw : Documentation of attributes and methods of RawArray. """ - read_mat = _import_pymatreader_funcs('FieldTrip I/O') - fname = _check_fname(fname, overwrite='read', must_exist=True) + read_mat = _import_pymatreader_funcs("FieldTrip I/O") + fname = _check_fname(fname, overwrite="read", must_exist=True) - ft_struct = read_mat(fname, - ignore_fields=['previous'], - variable_names=[data_name]) + ft_struct = read_mat(fname, ignore_fields=["previous"], variable_names=[data_name]) # load data and set ft_struct to the heading dictionary ft_struct = ft_struct[data_name] @@ -62,7 +65,7 @@ def read_raw_fieldtrip(fname, info, data_name='data'): _validate_ft_struct(ft_struct) info = _create_info(ft_struct, info) # create info structure - data = np.array(ft_struct['trial']) # create the main data array + data = np.array(ft_struct["trial"]) # create the main data array if data.ndim > 2: data = np.squeeze(data) @@ -71,15 +74,15 @@ def read_raw_fieldtrip(fname, info, data_name='data'): data = data[np.newaxis, ...] 
if data.ndim != 2: - raise RuntimeError('The data you are trying to load does not seem to ' - 'be raw data') + raise RuntimeError( + "The data you are trying to load does not seem to " "be raw data" + ) raw = RawArray(data, info) # create an MNE RawArray return raw -def read_epochs_fieldtrip(fname, info, data_name='data', - trialinfo_column=0): +def read_epochs_fieldtrip(fname, info, data_name="data", trialinfo_column=0): """Load epoched data from a FieldTrip preprocessing structure. This function expects to find epoched data in the structure data_name is @@ -114,10 +117,8 @@ def read_epochs_fieldtrip(fname, info, data_name='data', epochs : instance of EpochsArray An EpochsArray containing the loaded data. """ - read_mat = _import_pymatreader_funcs('FieldTrip I/O') - ft_struct = read_mat(fname, - ignore_fields=['previous'], - variable_names=[data_name]) + read_mat = _import_pymatreader_funcs("FieldTrip I/O") + ft_struct = read_mat(fname, ignore_fields=["previous"], variable_names=[data_name]) # load data and set ft_struct to the heading dictionary ft_struct = ft_struct[data_name] @@ -125,7 +126,7 @@ def read_epochs_fieldtrip(fname, info, data_name='data', _validate_ft_struct(ft_struct) info = _create_info(ft_struct, info) # create info structure - data = np.array(ft_struct['trial']) # create the epochs data array + data = np.array(ft_struct["trial"]) # create the epochs data array events = _create_events(ft_struct, trialinfo_column) if events is not None: metadata = _create_event_metadata(ft_struct) @@ -133,13 +134,13 @@ def read_epochs_fieldtrip(fname, info, data_name='data', metadata = None tmin = _set_tmin(ft_struct) # create start time - epochs = EpochsArray(data=data, info=info, tmin=tmin, - events=events, metadata=metadata, proj=False) + epochs = EpochsArray( + data=data, info=info, tmin=tmin, events=events, metadata=metadata, proj=False + ) return epochs -def read_evoked_fieldtrip(fname, info, comment=None, - data_name='data'): +def read_evoked_fieldtrip(fname, info, comment=None, data_name="data"): """Load evoked data from a FieldTrip timelocked structure. This function expects to find timelocked data in the structure data_name is @@ -171,16 +172,14 @@ def read_evoked_fieldtrip(fname, info, comment=None, evoked : instance of EvokedArray An EvokedArray containing the loaded data. 
""" - read_mat = _import_pymatreader_funcs('FieldTrip I/O') - ft_struct = read_mat(fname, - ignore_fields=['previous'], - variable_names=[data_name]) + read_mat = _import_pymatreader_funcs("FieldTrip I/O") + ft_struct = read_mat(fname, ignore_fields=["previous"], variable_names=[data_name]) ft_struct = ft_struct[data_name] _validate_ft_struct(ft_struct) info = _create_info(ft_struct, info) # create info structure - data_evoked = ft_struct['avg'] # create evoked data + data_evoked = ft_struct["avg"] # create evoked data evoked = EvokedArray(data_evoked, info, comment=comment) return evoked diff --git a/mne/io/fieldtrip/tests/helpers.py b/mne/io/fieldtrip/tests/helpers.py index 076c7f9053a..d93e6ad126e 100644 --- a/mne/io/fieldtrip/tests/helpers.py +++ b/mne/io/fieldtrip/tests/helpers.py @@ -12,47 +12,78 @@ from mne.utils import object_diff -info_ignored_fields = ('file_id', 'hpi_results', 'hpi_meas', 'meas_id', - 'meas_date', 'highpass', 'lowpass', 'subject_info', - 'hpi_subsystem', 'experimenter', 'description', - 'proj_id', 'proj_name', 'line_freq', 'gantry_angle', - 'dev_head_t', 'bads', 'ctf_head_t', 'dev_ctf_t', - 'dig') - -ch_ignore_fields = ('logno', 'cal', 'range', 'scanno', 'coil_type', 'kind', - 'loc', 'coord_frame', 'unit') - -info_long_fields = ('hpi_meas', 'projs') - -system_to_reader_fn_dict = {'neuromag306': mne.io.read_raw_fif, - 'CNT': partial(mne.io.read_raw_cnt), - 'CTF': partial(mne.io.read_raw_ctf, - clean_names=True), - 'BTI': partial(mne.io.read_raw_bti, - head_shape_fname=None, - rename_channels=False, - sort_by_ch_name=False), - 'EGI': mne.io.read_raw_egi, - 'eximia': mne.io.read_raw_eximia} - -ignore_channels_dict = {'BTI': ['MUz', 'MLx', 'MLy', 'MUx', 'MUy', 'MLz']} - -drop_extra_chans_dict = {'EGI': ['STI 014', 'DIN1', 'DIN3', - 'DIN7', 'DIN4', 'DIN5', 'DIN2'], - 'eximia': ['GateIn', 'Trig1', 'Trig2']} - -system_decimal_accuracy_dict = {'CNT': 2} - -pandas_not_found_warning_msg = 'The Pandas library is not installed. Not ' \ - 'returning the original trialinfo matrix as ' \ - 'metadata.' +info_ignored_fields = ( + "file_id", + "hpi_results", + "hpi_meas", + "meas_id", + "meas_date", + "highpass", + "lowpass", + "subject_info", + "hpi_subsystem", + "experimenter", + "description", + "proj_id", + "proj_name", + "line_freq", + "gantry_angle", + "dev_head_t", + "bads", + "ctf_head_t", + "dev_ctf_t", + "dig", +) + +ch_ignore_fields = ( + "logno", + "cal", + "range", + "scanno", + "coil_type", + "kind", + "loc", + "coord_frame", + "unit", +) + +info_long_fields = ("hpi_meas", "projs") + +system_to_reader_fn_dict = { + "neuromag306": mne.io.read_raw_fif, + "CNT": partial(mne.io.read_raw_cnt), + "CTF": partial(mne.io.read_raw_ctf, clean_names=True), + "BTI": partial( + mne.io.read_raw_bti, + head_shape_fname=None, + rename_channels=False, + sort_by_ch_name=False, + ), + "EGI": mne.io.read_raw_egi, + "eximia": mne.io.read_raw_eximia, +} + +ignore_channels_dict = {"BTI": ["MUz", "MLx", "MLy", "MUx", "MUy", "MLz"]} + +drop_extra_chans_dict = { + "EGI": ["STI 014", "DIN1", "DIN3", "DIN7", "DIN4", "DIN5", "DIN2"], + "eximia": ["GateIn", "Trig1", "Trig2"], +} + +system_decimal_accuracy_dict = {"CNT": 2} + +pandas_not_found_warning_msg = ( + "The Pandas library is not installed. Not " + "returning the original trialinfo matrix as " + "metadata." 
+) testing_path = mne.datasets.testing.data_path(download=False) def _remove_ignored_ch_fields(info): - if 'chs' in info: - for cur_ch in info['chs']: + if "chs" in info: + for cur_ch in info["chs"]: for cur_field in ch_ignore_fields: if cur_field in cur_ch: del cur_ch[cur_field] @@ -80,8 +111,10 @@ def get_data_paths(system): def get_cfg_local(system): """Return cfg_local field for the system.""" from pymatreader import read_mat - cfg_local = read_mat(os.path.join(get_data_paths(system), 'raw_v7.mat'), - ['cfg_local'])['cfg_local'] + + cfg_local = read_mat( + os.path.join(get_data_paths(system), "raw_v7.mat"), ["cfg_local"] + )["cfg_local"] return cfg_local @@ -90,12 +123,12 @@ def get_raw_info(system): """Return the info dict of the raw data.""" cfg_local = get_cfg_local(system) - raw_data_file = os.path.join(testing_path, cfg_local['file_name']) + raw_data_file = os.path.join(testing_path, cfg_local["file_name"]) reader_function = system_to_reader_fn_dict[system] info = reader_function(raw_data_file, preload=False).info with info._unlock(): - info['comps'] = [] + info["comps"] = [] return info @@ -103,23 +136,23 @@ def get_raw_data(system, drop_extra_chs=False): """Find, load and process the raw data.""" cfg_local = get_cfg_local(system) - raw_data_file = os.path.join(testing_path, cfg_local['file_name']) + raw_data_file = os.path.join(testing_path, cfg_local["file_name"]) reader_function = system_to_reader_fn_dict[system] raw_data = reader_function(raw_data_file, preload=True) - crop = min(cfg_local['crop'], np.max(raw_data.times)) - if system == 'eximia': - crop -= 0.5 * (1.0 / raw_data.info['sfreq']) + crop = min(cfg_local["crop"], np.max(raw_data.times)) + if system == "eximia": + crop -= 0.5 * (1.0 / raw_data.info["sfreq"]) raw_data.crop(0, crop) - raw_data.del_proj('all') + raw_data.del_proj("all") with raw_data.info._unlock(): - raw_data.info['comps'] = [] - raw_data.drop_channels(cfg_local['removed_chan_names']) + raw_data.info["comps"] = [] + raw_data.drop_channels(cfg_local["removed_chan_names"]) - if system in ['EGI']: + if system in ["EGI"]: raw_data._data[0:-1, :] = raw_data._data[0:-1, :] * 1e6 - if system in ['CNT']: + if system in ["CNT"]: raw_data._data = raw_data._data * 1e6 if system in ignore_channels_dict: @@ -136,29 +169,32 @@ def get_epochs(system): cfg_local = get_cfg_local(system) raw_data = get_raw_data(system) - if cfg_local['eventtype'] in raw_data.ch_names: - stim_channel = cfg_local['eventtype'] + if cfg_local["eventtype"] in raw_data.ch_names: + stim_channel = cfg_local["eventtype"] else: - stim_channel = 'STI 014' + stim_channel = "STI 014" - if system == 'CNT': + if system == "CNT": events, event_id = mne.events_from_annotations(raw_data) events[:, 0] = events[:, 0] + 1 else: - events = mne.find_events(raw_data, stim_channel=stim_channel, - shortest_event=1) + events = mne.find_events(raw_data, stim_channel=stim_channel, shortest_event=1) - if isinstance(cfg_local['eventvalue'], np.ndarray): - event_id = list(cfg_local['eventvalue'].astype('int')) + if isinstance(cfg_local["eventvalue"], np.ndarray): + event_id = list(cfg_local["eventvalue"].astype("int")) else: - event_id = [int(cfg_local['eventvalue'])] + event_id = [int(cfg_local["eventvalue"])] event_id = [id for id in event_id if id in events[:, 2]] - epochs = mne.Epochs(raw_data, events=events, - event_id=event_id, - tmin=-cfg_local['prestim'], - tmax=cfg_local['poststim'], baseline=None) + epochs = mne.Epochs( + raw_data, + events=events, + event_id=event_id, + tmin=-cfg_local["prestim"], + 
tmax=cfg_local["poststim"], + baseline=None, + ) return epochs @@ -188,12 +224,12 @@ def check_info_fields(expected, actual, has_raw_info, ignore_long=True): # we annoyingly have two ways of representing this, so just always use # an empty list here for obj in (expected, actual): - if obj.get('dig', None) is None: + if obj.get("dig", None) is None: with obj._unlock(): - obj['dig'] = [] + obj["dig"] = [] d = object_diff(actual, expected, allclose=True) - assert d == '', d + assert d == "", d def check_data(expected, actual, system): diff --git a/mne/io/fieldtrip/tests/test_fieldtrip.py b/mne/io/fieldtrip/tests/test_fieldtrip.py index 080ee0a7eda..64690c857fe 100644 --- a/mne/io/fieldtrip/tests/test_fieldtrip.py +++ b/mne/io/fieldtrip/tests/test_fieldtrip.py @@ -17,9 +17,16 @@ from mne.io import read_raw_fieldtrip from mne.io.fieldtrip.utils import NOINFO_WARNING, _create_events from mne.io.fieldtrip.tests.helpers import ( - check_info_fields, get_data_paths, get_raw_data, get_epochs, get_evoked, - pandas_not_found_warning_msg, get_raw_info, check_data, - assert_warning_in_record) + check_info_fields, + get_data_paths, + get_raw_data, + get_epochs, + get_evoked, + pandas_not_found_warning_msg, + get_raw_info, + check_data, + assert_warning_in_record, +) from mne.io.tests.test_raw import _test_raw_reader from mne.utils import _check_pandas_installed, _record_warnings @@ -27,28 +34,26 @@ # names. # EGI: no calibration done in FT. so data is VERY different -all_systems_raw = ['neuromag306', 'CTF', 'CNT', 'BTI', 'eximia'] -all_systems_epochs = ['neuromag306', 'CTF', 'CNT'] -all_versions = ['v7', 'v73'] +all_systems_raw = ["neuromag306", "CTF", "CNT", "BTI", "eximia"] +all_systems_epochs = ["neuromag306", "CTF", "CNT"] +all_versions = ["v7", "v73"] use_info = [True, False] -all_test_params_raw = list(itertools.product(all_systems_raw, all_versions, - use_info)) -all_test_params_epochs = list(itertools.product(all_systems_epochs, - all_versions, - use_info)) +all_test_params_raw = list(itertools.product(all_systems_raw, all_versions, use_info)) +all_test_params_epochs = list( + itertools.product(all_systems_epochs, all_versions, use_info) +) # just for speed we skip some slowest ones -- the coverage should still # be sufficient for obj in (all_test_params_epochs, all_test_params_raw): - for key in [('CTF', 'v73', True), ('neuromag306', 'v73', False)]: + for key in [("CTF", "v73", True), ("neuromag306", "v73", False)]: obj.pop(obj.index(key)) for ki, key in enumerate(obj): - if key[1] == 'v73': + if key[1] == "v73": obj[ki] = pytest.param(*obj[ki], marks=pytest.mark.slowtest) -no_info_warning = {'expected_warning': RuntimeWarning, - 'match': NOINFO_WARNING} +no_info_warning = {"expected_warning": RuntimeWarning, "match": NOINFO_WARNING} -pymatreader = pytest.importorskip('pymatreader') # module-level +pymatreader = pytest.importorskip("pymatreader") # module-level testing_path = mne.datasets.testing.data_path(download=False) @@ -56,10 +61,9 @@ @testing.requires_testing_data # Reading the sample CNT data results in a RuntimeWarning because it cannot # parse the measurement date. We need to ignore that warning. 
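# For context on the parametrization machinery above: the (system, version,
# use_info) grid is a plain Python list, so expensive combinations can be
# dropped or marked before @pytest.mark.parametrize consumes it. A minimal
# self-contained sketch of the same pattern (the systems listed and the
# ``slowtest`` mark are illustrative, not this module's exact setup):

import itertools

import pytest

systems = ["neuromag306", "CTF", "CNT"]
versions = ["v7", "v73"]
params = list(itertools.product(systems, versions))
params.remove(("CTF", "v73"))  # drop one known-slow combination entirely
params = [  # mark the remaining v73 cases so they can be deselected as slow
    pytest.param(*p, marks=pytest.mark.slowtest) if p[1] == "v73" else p
    for p in params
]


@pytest.mark.parametrize("system, version", params)
def test_combination(system, version):
    assert system in systems and version in versions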
-@pytest.mark.filterwarnings('ignore:.*parse meas date.*:RuntimeWarning') -@pytest.mark.filterwarnings('ignore:.*number of bytes.*:RuntimeWarning') -@pytest.mark.parametrize('cur_system, version, use_info', - all_test_params_epochs) +@pytest.mark.filterwarnings("ignore:.*parse meas date.*:RuntimeWarning") +@pytest.mark.filterwarnings("ignore:.*number of bytes.*:RuntimeWarning") +@pytest.mark.parametrize("cur_system, version, use_info", all_test_params_epochs) def test_read_evoked(cur_system, version, use_info): """Test comparing reading an Evoked object and the FieldTrip version.""" test_data_folder_ft = get_data_paths(cur_system) @@ -85,15 +89,16 @@ def test_read_evoked(cur_system, version, use_info): @testing.requires_testing_data # Reading the sample CNT data results in a RuntimeWarning because it cannot # parse the measurement date. We need to ignore that warning. -@pytest.mark.filterwarnings('ignore:.*parse meas date.*:RuntimeWarning') -@pytest.mark.filterwarnings('ignore:.*number of bytes.*:RuntimeWarning') -@pytest.mark.parametrize('cur_system, version, use_info', - all_test_params_epochs) +@pytest.mark.filterwarnings("ignore:.*parse meas date.*:RuntimeWarning") +@pytest.mark.filterwarnings("ignore:.*number of bytes.*:RuntimeWarning") +@pytest.mark.parametrize("cur_system, version, use_info", all_test_params_epochs) # Strange, non-deterministic Pandas errors: # "ValueError: cannot expose native-only dtype 'g' in non-native # byte order '<' via buffer interface" -@pytest.mark.skipif(os.getenv('AZURE_CI_WINDOWS', 'false').lower() == 'true', - reason='Pandas problem on Azure CI') +@pytest.mark.skipif( + os.getenv("AZURE_CI_WINDOWS", "false").lower() == "true", + reason="Pandas problem on Azure CI", +) def test_read_epochs(cur_system, version, use_info, monkeypatch): """Test comparing reading an Epochs object and the FieldTrip version.""" pandas = _check_pandas_installed(strict=False) @@ -130,21 +135,21 @@ def test_read_epochs(cur_system, version, use_info, monkeypatch): # weird sfreq def modify_mat(fname, variable_names=None, ignore_fields=None): out = read_mat(fname, variable_names, ignore_fields) - if 'fsample' in out['data']: - out['data']['fsample'] = np.repeat(out['data']['fsample'], 2) + if "fsample" in out["data"]: + out["data"]["fsample"] = np.repeat(out["data"]["fsample"], 2) return out - monkeypatch.setattr(pymatreader, 'read_mat', modify_mat) - with pytest.warns(RuntimeWarning, match='multiple'): + monkeypatch.setattr(pymatreader, "read_mat", modify_mat) + with pytest.warns(RuntimeWarning, match="multiple"): mne.io.read_epochs_fieldtrip(cur_fname, info) @testing.requires_testing_data # Reading the sample CNT data results in a RuntimeWarning because it cannot # parse the measurement date. We need to ignore that warning. 
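# For context, the public entry point exercised by the raw tests below, in
# its simplest form (the file name is a hypothetical stand-in for the
# testing data; with info=None, MNE guesses channel types and locations and
# emits NOINFO_WARNING):

import mne

raw = mne.io.read_raw_fieldtrip("raw_v7.mat", info=None, data_name="data")
print(raw.ch_names[:5], raw.info["sfreq"])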
-@pytest.mark.filterwarnings('ignore:.*parse meas date.*:RuntimeWarning') -@pytest.mark.filterwarnings('ignore:.*number of bytes.*:RuntimeWarning') -@pytest.mark.parametrize('cur_system, version, use_info', all_test_params_raw) +@pytest.mark.filterwarnings("ignore:.*parse meas date.*:RuntimeWarning") +@pytest.mark.filterwarnings("ignore:.*number of bytes.*:RuntimeWarning") +@pytest.mark.parametrize("cur_system, version, use_info", all_test_params_raw) def test_read_raw_fieldtrip(cur_system, version, use_info): """Test comparing reading a raw fiff file and the FieldTrip version.""" # Load the raw fiff file with mne @@ -152,8 +157,8 @@ def test_read_raw_fieldtrip(cur_system, version, use_info): raw_fiff_mne = get_raw_data(cur_system, drop_extra_chs=True) if use_info: info = get_raw_info(cur_system) - if cur_system in ('BTI', 'eximia'): - ctx = pytest.warns(RuntimeWarning, match='cannot be found in') + if cur_system in ("BTI", "eximia"): + ctx = pytest.warns(RuntimeWarning, match="cannot be found in") else: ctx = nullcontext() else: @@ -165,24 +170,24 @@ def test_read_raw_fieldtrip(cur_system, version, use_info): with ctx: raw_fiff_ft = mne.io.read_raw_fieldtrip(cur_fname, info) - if cur_system == 'BTI' and not use_info: - raw_fiff_ft.drop_channels(['MzA', 'MxA', 'MyaA', - 'MyA', 'MxaA', 'MzaA']) + if cur_system == "BTI" and not use_info: + raw_fiff_ft.drop_channels(["MzA", "MxA", "MyaA", "MyA", "MxaA", "MzaA"]) - if cur_system == 'eximia' and not use_info: - raw_fiff_ft.drop_channels(['TRIG2', 'TRIG1', 'GATE']) + if cur_system == "eximia" and not use_info: + raw_fiff_ft.drop_channels(["TRIG2", "TRIG1", "GATE"]) # Check that the data was loaded correctly - check_data(raw_fiff_mne.get_data(), - raw_fiff_ft.get_data(), - cur_system) + check_data(raw_fiff_mne.get_data(), raw_fiff_ft.get_data(), cur_system) # standard tests with _record_warnings(): _test_raw_reader( - read_raw_fieldtrip, fname=cur_fname, info=info, + read_raw_fieldtrip, + fname=cur_fname, + info=info, test_preloading=False, - test_kwargs=False) # TODO: This should probably work + test_kwargs=False, + ) # TODO: This should probably work # Check info field check_info_fields(raw_fiff_mne, raw_fiff_ft, use_info) @@ -191,8 +196,8 @@ def test_read_raw_fieldtrip(cur_system, version, use_info): @testing.requires_testing_data def test_load_epoched_as_raw(): """Test whether exception is thrown when loading epochs as raw.""" - test_data_folder_ft = get_data_paths('neuromag306') - info = get_raw_info('neuromag306') + test_data_folder_ft = get_data_paths("neuromag306") + info = get_raw_info("neuromag306") cur_fname = test_data_folder_ft / "epoched_v7.mat" with pytest.raises(RuntimeError): @@ -202,8 +207,8 @@ def test_load_epoched_as_raw(): @testing.requires_testing_data def test_invalid_trialinfocolumn(): """Test for exceptions when using wrong values for trialinfo parameter.""" - test_data_folder_ft = get_data_paths('neuromag306') - info = get_raw_info('neuromag306') + test_data_folder_ft = get_data_paths("neuromag306") + info = get_raw_info("neuromag306") cur_fname = test_data_folder_ft / "epoched_v7.mat" with pytest.raises(ValueError): @@ -216,14 +221,17 @@ def test_invalid_trialinfocolumn(): @testing.requires_testing_data def test_create_events(): """Test 2dim trialinfo fields.""" - test_data_folder_ft = get_data_paths('neuromag306') + test_data_folder_ft = get_data_paths("neuromag306") cur_fname = test_data_folder_ft / "epoched_v7.mat" - original_data = pymatreader.read_mat(cur_fname, ['data', ]) + original_data = pymatreader.read_mat( + 
cur_fname, + [ + "data", + ], + ) new_data = copy.deepcopy(original_data) - new_data['trialinfo'] = np.array([[1, 2, 3, 4], - [1, 2, 3, 4], - [1, 2, 3, 4]]) + new_data["trialinfo"] = np.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]) with pytest.raises(ValueError): _create_events(new_data, -1) @@ -237,14 +245,10 @@ def test_create_events(): @testing.requires_testing_data -@pytest.mark.parametrize('version', all_versions) +@pytest.mark.parametrize("version", all_versions) def test_one_channel_elec_bug(version): """Test if loading data having only one elec in the elec field works.""" - fname = ( - testing_path - / "fieldtrip" - / f"one_channel_elec_bug_data_{version}.mat" - ) + fname = testing_path / "fieldtrip" / f"one_channel_elec_bug_data_{version}.mat" with pytest.warns(**no_info_warning): mne.io.read_raw_fieldtrip(fname, info=None) @@ -253,65 +257,70 @@ def test_one_channel_elec_bug(version): @testing.requires_testing_data # Reading the sample CNT data results in a RuntimeWarning because it cannot # parse the measurement date. We need to ignore that warning. -@pytest.mark.filterwarnings('ignore:.*parse meas date.*:RuntimeWarning') -@pytest.mark.filterwarnings('ignore:.*number of bytes.*:RuntimeWarning') -@pytest.mark.parametrize('version', all_versions) -@pytest.mark.parametrize('type', ['averaged', 'epoched', 'raw']) +@pytest.mark.filterwarnings("ignore:.*parse meas date.*:RuntimeWarning") +@pytest.mark.filterwarnings("ignore:.*number of bytes.*:RuntimeWarning") +@pytest.mark.parametrize("version", all_versions) +@pytest.mark.parametrize("type", ["averaged", "epoched", "raw"]) def test_throw_exception_on_cellarray(version, type): """Test for a meaningful exception when the data is a cell array.""" fname = get_data_paths("cellarray") / f"{type}_{version}.mat" - info = get_raw_info('CNT') - with pytest.raises(RuntimeError, match='Loading of data in cell arrays ' - 'is not supported'): - if type == 'averaged': + info = get_raw_info("CNT") + with pytest.raises( + RuntimeError, match="Loading of data in cell arrays " "is not supported" + ): + if type == "averaged": mne.read_evoked_fieldtrip(fname, info) - elif type == 'epoched': + elif type == "epoched": mne.read_epochs_fieldtrip(fname, info) - elif type == 'raw': + elif type == "raw": mne.io.read_raw_fieldtrip(fname, info) @testing.requires_testing_data def test_with_missing_channels(): """Test _create_info when channels are missing from info.""" - cur_system = 'neuromag306' + cur_system = "neuromag306" test_data_folder_ft = get_data_paths(cur_system) info = get_raw_info(cur_system) - del info['chs'][1:20] + del info["chs"][1:20] info._update_redundant() with pytest.warns(RuntimeWarning): mne.io.read_raw_fieldtrip(test_data_folder_ft / "raw_v7.mat", info) - mne.read_evoked_fieldtrip( - test_data_folder_ft / "averaged_v7.mat", info) + mne.read_evoked_fieldtrip(test_data_folder_ft / "averaged_v7.mat", info) mne.read_epochs_fieldtrip(test_data_folder_ft / "epoched_v7.mat", info) @testing.requires_testing_data -@pytest.mark.filterwarnings('ignore: Importing FieldTrip data without an info') -@pytest.mark.filterwarnings('ignore: Cannot guess the correct type') +@pytest.mark.filterwarnings("ignore: Importing FieldTrip data without an info") +@pytest.mark.filterwarnings("ignore: Cannot guess the correct type") def test_throw_error_on_non_uniform_time_field(): """Test if an error is thrown when time fields are not uniform.""" fname = testing_path / "fieldtrip" / "not_uniform_time.mat" - with pytest.raises(RuntimeError, match='Loading data with 
non-uniform ' - 'times per epoch is not supported'): + with pytest.raises( + RuntimeError, + match="Loading data with non-uniform " "times per epoch is not supported", + ): mne.io.read_epochs_fieldtrip(fname, info=None) @testing.requires_testing_data -@pytest.mark.filterwarnings('ignore: Importing FieldTrip data without an info') +@pytest.mark.filterwarnings("ignore: Importing FieldTrip data without an info") def test_throw_error_when_importing_old_ft_version_data(): """Test if an error is thrown if the data was saved with an old version.""" fname = testing_path / "fieldtrip" / "old_version.mat" - with pytest.raises(RuntimeError, match='This file was created with ' - 'an old version of FieldTrip. You ' - 'can convert the data to the new ' - 'version by loading it into ' - 'FieldTrip and applying ' - 'ft_selectdata with an ' - 'empty cfg structure on it. ' - 'Otherwise you can supply ' - 'the Info field.'): + with pytest.raises( + RuntimeError, + match="This file was created with " + "an old version of FieldTrip. You " + "can convert the data to the new " + "version by loading it into " + "FieldTrip and applying " + "ft_selectdata with an " + "empty cfg structure on it. " + "Otherwise you can supply " + "the Info field.", + ): mne.io.read_epochs_fieldtrip(fname, info=None) diff --git a/mne/io/fieldtrip/utils.py b/mne/io/fieldtrip/utils.py index 7127f63ab54..aed95ed0347 100644 --- a/mne/io/fieldtrip/utils.py +++ b/mne/io/fieldtrip/utils.py @@ -12,28 +12,32 @@ from ...transforms import rotation3d_align_z_axis from ...utils import warn, _check_pandas_installed -_supported_megs = ['neuromag306'] - -_unit_dict = {'m': 1, - 'cm': 1e-2, - 'mm': 1e-3, - 'V': 1, - 'mV': 1e-3, - 'uV': 1e-6, - 'T': 1, - 'T/m': 1, - 'T/cm': 1e2} - -NOINFO_WARNING = 'Importing FieldTrip data without an info dict from the ' \ - 'original file. Channel locations, orientations and types ' \ - 'will be incorrect. The imported data cannot be used for ' \ - 'source analysis, channel interpolation etc.' +_supported_megs = ["neuromag306"] + +_unit_dict = { + "m": 1, + "cm": 1e-2, + "mm": 1e-3, + "V": 1, + "mV": 1e-3, + "uV": 1e-6, + "T": 1, + "T/m": 1, + "T/cm": 1e2, +} + +NOINFO_WARNING = ( + "Importing FieldTrip data without an info dict from the " + "original file. Channel locations, orientations and types " + "will be incorrect. The imported data cannot be used for " + "source analysis, channel interpolation etc." +) def _validate_ft_struct(ft_struct): """Run validation checks on the ft_structure.""" if isinstance(ft_struct, list): - raise RuntimeError('Loading of data in cell arrays is not supported') + raise RuntimeError("Loading of data in cell arrays is not supported") def _create_info(ft_struct, raw_info): @@ -42,36 +46,37 @@ def _create_info(ft_struct, raw_info): warn(NOINFO_WARNING) sfreq = _set_sfreq(ft_struct) - ch_names = ft_struct['label'] + ch_names = ft_struct["label"] if raw_info: info = raw_info.copy() - missing_channels = set(ch_names) - set(info['ch_names']) + missing_channels = set(ch_names) - set(info["ch_names"]) if missing_channels: - warn('The following channels are present in the FieldTrip data ' - 'but cannot be found in the provided info: %s.\n' - 'These channels will be removed from the resulting data!' - % (str(missing_channels), )) + warn( + "The following channels are present in the FieldTrip data " + "but cannot be found in the provided info: %s.\n" + "These channels will be removed from the resulting data!" 
+ % (str(missing_channels),)
+ )
 missing_chan_idx = [ch_names.index(ch) for ch in missing_channels]
 new_chs = [ch for ch in ch_names if ch not in missing_channels]
 ch_names = new_chs
- ft_struct['label'] = ch_names
+ ft_struct["label"] = ch_names

- if 'trial' in ft_struct:
- ft_struct['trial'] = _remove_missing_channels_from_trial(
- ft_struct['trial'],
- missing_chan_idx
+ if "trial" in ft_struct:
+ ft_struct["trial"] = _remove_missing_channels_from_trial(
+ ft_struct["trial"], missing_chan_idx
 )

- if 'avg' in ft_struct:
- if ft_struct['avg'].ndim == 2:
- ft_struct['avg'] = np.delete(ft_struct['avg'],
- missing_chan_idx,
- axis=0)
+ if "avg" in ft_struct:
+ if ft_struct["avg"].ndim == 2:
+ ft_struct["avg"] = np.delete(
+ ft_struct["avg"], missing_chan_idx, axis=0
+ )

 with info._unlock():
- info['sfreq'] = sfreq
- ch_idx = [info['ch_names'].index(ch) for ch in ch_names]
+ info["sfreq"] = sfreq
+ ch_idx = [info["ch_names"].index(ch) for ch in ch_names]
 pick_info(info, ch_idx, copy=False)
 else:
 info = create_info(ch_names, sfreq)
@@ -90,80 +95,89 @@ def _remove_missing_channels_from_trial(trial, missing_chan_idx):
 )
 elif isinstance(trial, np.ndarray):
 if trial.ndim == 2:
- trial = np.delete(trial,
- missing_chan_idx,
- axis=0)
+ trial = np.delete(trial, missing_chan_idx, axis=0)
 else:
- raise ValueError('"trial" field of the FieldTrip structure '
- 'has an unknown format.')
+ raise ValueError(
+ '"trial" field of the FieldTrip structure ' "has an unknown format."
+ )

 return trial


 def _create_info_chs_dig(ft_struct):
 """Create the chs info field from the FieldTrip structure."""
- all_channels = ft_struct['label']
- ch_defaults = dict(coord_frame=FIFF.FIFFV_COORD_UNKNOWN,
- cal=1.0,
- range=1.0,
- unit_mul=FIFF.FIFF_UNITM_NONE,
- loc=np.array([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1]),
- unit=FIFF.FIFF_UNIT_V)
+ all_channels = ft_struct["label"]
+ ch_defaults = dict(
+ coord_frame=FIFF.FIFFV_COORD_UNKNOWN,
+ cal=1.0,
+ range=1.0,
+ unit_mul=FIFF.FIFF_UNITM_NONE,
+ loc=np.array([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1]),
+ unit=FIFF.FIFF_UNIT_V,
+ )
 try:
- elec = ft_struct['elec']
+ elec = ft_struct["elec"]
 except KeyError:
 elec = None

 try:
- grad = ft_struct['grad']
+ grad = ft_struct["grad"]
 except KeyError:
 grad = None

 if elec is None and grad is None:
- warn('The supplied FieldTrip structure does not have an elec or grad '
- 'field. No channel locations will be extracted and the kind of '
- 'channel might be inaccurate.')
- if 'chanpos' not in (elec or grad or {'chanpos': None}):
+ warn(
+ "The supplied FieldTrip structure does not have an elec or grad "
+ "field. No channel locations will be extracted and the kind of "
+ "channel might be inaccurate."
+ )
+ if "chanpos" not in (elec or grad or {"chanpos": None}):
 raise RuntimeError(
- 'This file was created with an old version of FieldTrip. You can '
- 'convert the data to the new version by loading it into FieldTrip '
- 'and applying ft_selectdata with an empty cfg structure on it. '
- 'Otherwise you can supply the Info field.')
+ "This file was created with an old version of FieldTrip. You can "
+ "convert the data to the new version by loading it into FieldTrip "
+ "and applying ft_selectdata with an empty cfg structure on it. "
+ "Otherwise you can supply the Info field."
+ ) chs = list() dig = list() counter = 0 for idx_chan, cur_channel_label in enumerate(all_channels): cur_ch = ch_defaults.copy() - cur_ch['ch_name'] = cur_channel_label - cur_ch['logno'] = idx_chan + 1 - cur_ch['scanno'] = idx_chan + 1 - if elec and cur_channel_label in elec['label']: + cur_ch["ch_name"] = cur_channel_label + cur_ch["logno"] = idx_chan + 1 + cur_ch["scanno"] = idx_chan + 1 + if elec and cur_channel_label in elec["label"]: cur_ch = _process_channel_eeg(cur_ch, elec) - assert cur_ch['coord_frame'] == FIFF.FIFFV_COORD_HEAD + assert cur_ch["coord_frame"] == FIFF.FIFFV_COORD_HEAD # Ref gets ident=0 and we don't have it, so start at 1 counter += 1 d = DigPoint( - r=cur_ch['loc'][:3], coord_frame=FIFF.FIFFV_COORD_HEAD, - kind=FIFF.FIFFV_POINT_EEG, ident=counter) + r=cur_ch["loc"][:3], + coord_frame=FIFF.FIFFV_COORD_HEAD, + kind=FIFF.FIFFV_POINT_EEG, + ident=counter, + ) dig.append(d) - elif grad and cur_channel_label in grad['label']: + elif grad and cur_channel_label in grad["label"]: cur_ch = _process_channel_meg(cur_ch, grad) else: - if cur_channel_label.startswith('EOG'): - cur_ch['kind'] = FIFF.FIFFV_EOG_CH - cur_ch['coil_type'] = FIFF.FIFFV_COIL_EEG - elif cur_channel_label.startswith('ECG'): - cur_ch['kind'] = FIFF.FIFFV_ECG_CH - cur_ch['coil_type'] = FIFF.FIFFV_COIL_EEG_BIPOLAR - elif cur_channel_label.startswith('STI'): - cur_ch['kind'] = FIFF.FIFFV_STIM_CH - cur_ch['coil_type'] = FIFF.FIFFV_COIL_NONE + if cur_channel_label.startswith("EOG"): + cur_ch["kind"] = FIFF.FIFFV_EOG_CH + cur_ch["coil_type"] = FIFF.FIFFV_COIL_EEG + elif cur_channel_label.startswith("ECG"): + cur_ch["kind"] = FIFF.FIFFV_ECG_CH + cur_ch["coil_type"] = FIFF.FIFFV_COIL_EEG_BIPOLAR + elif cur_channel_label.startswith("STI"): + cur_ch["kind"] = FIFF.FIFFV_STIM_CH + cur_ch["coil_type"] = FIFF.FIFFV_COIL_NONE else: - warn('Cannot guess the correct type of channel %s. Making ' - 'it a MISC channel.' % (cur_channel_label,)) - cur_ch['kind'] = FIFF.FIFFV_MISC_CH - cur_ch['coil_type'] = FIFF.FIFFV_COIL_NONE + warn( + "Cannot guess the correct type of channel %s. Making " + "it a MISC channel." 
% (cur_channel_label,)
+ )
+ cur_ch["kind"] = FIFF.FIFFV_MISC_CH
+ cur_ch["coil_type"] = FIFF.FIFFV_COIL_NONE
 chs.append(cur_ch)
 _ensure_fiducials_head(dig)
@@ -174,63 +188,67 @@ def _create_info_chs_dig(ft_struct):
 def _set_sfreq(ft_struct):
 """Set the sample frequency."""
 try:
- sfreq = ft_struct['fsample']
+ sfreq = ft_struct["fsample"]
 except KeyError:
 try:
- time = ft_struct['time']
+ time = ft_struct["time"]
 except KeyError:
- raise ValueError('No source for sfreq found')
+ raise ValueError("No source for sfreq found")
 else:
 t1, t2 = float(time[0]), float(time[1])
 sfreq = 1 / (t2 - t1)
 try:
 sfreq = float(sfreq)
 except TypeError:
- warn('FieldTrip structure contained multiple sample rates, trying the '
- f'first of:\n{sfreq} Hz')
+ warn(
+ "FieldTrip structure contained multiple sample rates, trying the "
+ f"first of:\n{sfreq} Hz"
+ )
 sfreq = float(sfreq.ravel()[0])
 return sfreq


 def _set_tmin(ft_struct):
 """Set the start time before the event in evoked data if possible."""
- times = ft_struct['time']
- time_check = all(times[i][0] == times[i - 1][0]
- for i, x in enumerate(times))
+ times = ft_struct["time"]
+ time_check = all(times[i][0] == times[i - 1][0] for i, x in enumerate(times))
 if time_check:
 tmin = times[0][0]
 else:
- raise RuntimeError('Loading data with non-uniform '
- 'times per epoch is not supported')
+ raise RuntimeError(
+ "Loading data with non-uniform " "times per epoch is not supported"
+ )
 return tmin


 def _create_events(ft_struct, trialinfo_column):
 """Create an event matrix from the FieldTrip structure."""
- if 'trialinfo' not in ft_struct:
+ if "trialinfo" not in ft_struct:
 return None

- event_type = ft_struct['trialinfo']
+ event_type = ft_struct["trialinfo"]
 event_number = range(len(event_type))

 if trialinfo_column < 0:
- raise ValueError('trialinfo_column must be non-negative')
+ raise ValueError("trialinfo_column must be non-negative")

 available_ti_cols = 1
 if event_type.ndim == 2:
 available_ti_cols = event_type.shape[1]

 if trialinfo_column > (available_ti_cols - 1):
- raise ValueError('trialinfo_column is higher than the number of '
- 'columns in trialinfo.')
+ raise ValueError(
+ "trialinfo_column is higher than the number of " "columns in trialinfo."
+ )

 event_trans_val = np.zeros(len(event_type))

 if event_type.ndim == 2:
 event_type = event_type[:, trialinfo_column]

- events = np.vstack([np.array(event_number), event_trans_val,
- event_type]).astype('int').T
+ events = (
+ np.vstack([np.array(event_number), event_trans_val, event_type]).astype("int").T
+ )
 return events


@@ -239,11 +257,13 @@ def _create_event_metadata(ft_struct):
 """Create event metadata from trialinfo."""
 pandas = _check_pandas_installed(strict=False)
 if not pandas:
- warn('The Pandas library is not installed. Not returning the original '
- 'trialinfo matrix as metadata.')
+ warn(
+ "The Pandas library is not installed. Not returning the original "
+ "trialinfo matrix as metadata."
+ ) return None - metadata = pandas.DataFrame(ft_struct['trialinfo']) + metadata = pandas.DataFrame(ft_struct["trialinfo"]) return metadata @@ -264,18 +284,18 @@ def _process_channel_eeg(cur_ch, elec): cur_ch: dict The original dict (cur_ch) with the added information """ - all_labels = np.asanyarray(elec['label']) - chan_idx_in_elec = np.where(all_labels == cur_ch['ch_name'])[0][0] - position = np.squeeze(elec['chanpos'][chan_idx_in_elec, :]) + all_labels = np.asanyarray(elec["label"]) + chan_idx_in_elec = np.where(all_labels == cur_ch["ch_name"])[0][0] + position = np.squeeze(elec["chanpos"][chan_idx_in_elec, :]) # chanunit = elec['chanunit'][chan_idx_in_elec] # not used/needed yet - position_unit = elec['unit'] + position_unit = elec["unit"] position = position * _unit_dict[position_unit] - cur_ch['loc'] = np.hstack((position, np.zeros((9,)))) - cur_ch['unit'] = FIFF.FIFF_UNIT_V - cur_ch['kind'] = FIFF.FIFFV_EEG_CH - cur_ch['coil_type'] = FIFF.FIFFV_COIL_EEG - cur_ch['coord_frame'] = FIFF.FIFFV_COORD_HEAD + cur_ch["loc"] = np.hstack((position, np.zeros((9,)))) + cur_ch["unit"] = FIFF.FIFF_UNIT_V + cur_ch["kind"] = FIFF.FIFFV_EEG_CH + cur_ch["coil_type"] = FIFF.FIFFV_COIL_EEG + cur_ch["coord_frame"] = FIFF.FIFFV_COORD_HEAD return cur_ch @@ -295,27 +315,27 @@ def _process_channel_meg(cur_ch, grad): ------- dict: The original dict (cur_ch) with the added information """ - all_labels = np.asanyarray(grad['label']) - chan_idx_in_grad = np.where(all_labels == cur_ch['ch_name'])[0][0] - gradtype = grad['type'] - chantype = grad['chantype'][chan_idx_in_grad] - position_unit = grad['unit'] - position = np.squeeze(grad['chanpos'][chan_idx_in_grad, :]) + all_labels = np.asanyarray(grad["label"]) + chan_idx_in_grad = np.where(all_labels == cur_ch["ch_name"])[0][0] + gradtype = grad["type"] + chantype = grad["chantype"][chan_idx_in_grad] + position_unit = grad["unit"] + position = np.squeeze(grad["chanpos"][chan_idx_in_grad, :]) position = position * _unit_dict[position_unit] - if gradtype == 'neuromag306' and 'tra' in grad and 'coilpos' in grad: + if gradtype == "neuromag306" and "tra" in grad and "coilpos" in grad: # Try to regenerate original channel pos. 
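# FieldTrip's ``coilpos``/``coilori`` hold one row per integration point, and
# the nonzero entries of this channel's row in ``tra`` select the points that
# belong to it. After scaling to meters, the code below steps 0.0003 m back
# along each coil normal to approximate the physical sensor location: a
# single point for a magnetometer, the average over the selected points for
# a planar gradiometer.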
- idx_in_coilpos = np.where(grad['tra'][chan_idx_in_grad, :] != 0)[0] - cur_coilpos = grad['coilpos'][idx_in_coilpos, :] + idx_in_coilpos = np.where(grad["tra"][chan_idx_in_grad, :] != 0)[0] + cur_coilpos = grad["coilpos"][idx_in_coilpos, :] cur_coilpos = cur_coilpos * _unit_dict[position_unit] - cur_coilori = grad['coilori'][idx_in_coilpos, :] - if chantype == 'megmag': + cur_coilori = grad["coilori"][idx_in_coilpos, :] + if chantype == "megmag": position = cur_coilpos[0] - 0.0003 * cur_coilori[0] - if chantype == 'megplanar': + if chantype == "megplanar": tmp_pos = cur_coilpos - 0.0003 * cur_coilori position = np.average(tmp_pos, axis=0) - original_orientation = np.squeeze(grad['chanori'][chan_idx_in_grad, :]) + original_orientation = np.squeeze(grad["chanori"][chan_idx_in_grad, :]) try: orientation = rotation3d_align_z_axis(original_orientation).T except AssertionError: @@ -324,27 +344,26 @@ def _process_channel_meg(cur_ch, grad): orientation = orientation.flatten() # chanunit = grad['chanunit'][chan_idx_in_grad] # not used/needed yet - cur_ch['loc'] = np.hstack((position, orientation)) - cur_ch['kind'] = FIFF.FIFFV_MEG_CH - if chantype == 'megmag': - cur_ch['coil_type'] = FIFF.FIFFV_COIL_POINT_MAGNETOMETER - cur_ch['unit'] = FIFF.FIFF_UNIT_T - elif chantype == 'megplanar': - cur_ch['coil_type'] = FIFF.FIFFV_COIL_VV_PLANAR_T1 - cur_ch['unit'] = FIFF.FIFF_UNIT_T_M - elif chantype == 'refmag': - cur_ch['coil_type'] = FIFF.FIFFV_COIL_MAGNES_REF_MAG - cur_ch['unit'] = FIFF.FIFF_UNIT_T - elif chantype == 'refgrad': - cur_ch['coil_type'] = FIFF.FIFFV_COIL_MAGNES_REF_GRAD - cur_ch['unit'] = FIFF.FIFF_UNIT_T - elif chantype == 'meggrad': - cur_ch['coil_type'] = FIFF.FIFFV_COIL_AXIAL_GRAD_5CM - cur_ch['unit'] = FIFF.FIFF_UNIT_T + cur_ch["loc"] = np.hstack((position, orientation)) + cur_ch["kind"] = FIFF.FIFFV_MEG_CH + if chantype == "megmag": + cur_ch["coil_type"] = FIFF.FIFFV_COIL_POINT_MAGNETOMETER + cur_ch["unit"] = FIFF.FIFF_UNIT_T + elif chantype == "megplanar": + cur_ch["coil_type"] = FIFF.FIFFV_COIL_VV_PLANAR_T1 + cur_ch["unit"] = FIFF.FIFF_UNIT_T_M + elif chantype == "refmag": + cur_ch["coil_type"] = FIFF.FIFFV_COIL_MAGNES_REF_MAG + cur_ch["unit"] = FIFF.FIFF_UNIT_T + elif chantype == "refgrad": + cur_ch["coil_type"] = FIFF.FIFFV_COIL_MAGNES_REF_GRAD + cur_ch["unit"] = FIFF.FIFF_UNIT_T + elif chantype == "meggrad": + cur_ch["coil_type"] = FIFF.FIFFV_COIL_AXIAL_GRAD_5CM + cur_ch["unit"] = FIFF.FIFF_UNIT_T else: - raise RuntimeError('Unexpected coil type: %s.' % ( - chantype,)) + raise RuntimeError("Unexpected coil type: %s." 
% (chantype,)) - cur_ch['coord_frame'] = FIFF.FIFFV_COORD_HEAD + cur_ch["coord_frame"] = FIFF.FIFFV_COORD_HEAD return cur_ch diff --git a/mne/io/fiff/raw.py b/mne/io/fiff/raw.py index b71154a0208..30b9fb826a7 100644 --- a/mne/io/fiff/raw.py +++ b/mne/io/fiff/raw.py @@ -17,15 +17,22 @@ from ..meas_info import read_meas_info from ..tree import dir_tree_find from ..tag import read_tag, read_tag_info -from ..base import (BaseRaw, _RawShell, _check_raw_compatibility, - _check_maxshield) +from ..base import BaseRaw, _RawShell, _check_raw_compatibility, _check_maxshield from ..utils import _mult_cal_one from ...annotations import Annotations, _read_annotations_fif from ...event import AcqParserFIF -from ...utils import (check_fname, logger, verbose, warn, fill_doc, _file_like, - _on_missing, _check_fname) +from ...utils import ( + check_fname, + logger, + verbose, + warn, + fill_doc, + _file_like, + _on_missing, + _check_fname, +) @fill_doc @@ -71,27 +78,34 @@ class Raw(BaseRaw): """ @verbose - def __init__(self, fname, allow_maxshield=False, preload=False, - on_split_missing='raise', verbose=None): # noqa: D102 + def __init__( + self, + fname, + allow_maxshield=False, + preload=False, + on_split_missing="raise", + verbose=None, + ): # noqa: D102 raws = [] do_check_ext = not _file_like(fname) next_fname = fname while next_fname is not None: - raw, next_fname, buffer_size_sec = \ - self._read_raw_file(next_fname, allow_maxshield, - preload, do_check_ext) + raw, next_fname, buffer_size_sec = self._read_raw_file( + next_fname, allow_maxshield, preload, do_check_ext + ) do_check_ext = False raws.append(raw) if next_fname is not None: if not op.exists(next_fname): msg = ( - f'Split raw file detected but next file {next_fname} ' - 'does not exist. Ensure all files were transferred ' - 'properly and that split and original files were not ' - 'manually renamed on disk (split files should be ' - 'renamed by loading and re-saving with MNE-Python to ' - 'preserve proper filename linkage).') - _on_missing(on_split_missing, msg, name='on_split_missing') + f"Split raw file detected but next file {next_fname} " + "does not exist. Ensure all files were transferred " + "properly and that split and original files were not " + "manually renamed on disk (split files should be " + "renamed by loading and re-saving with MNE-Python to " + "preserve proper filename linkage)." + ) + _on_missing(on_split_missing, msg, name="on_split_missing") break if _file_like(fname): # avoid serialization error when copying file-like @@ -99,25 +113,33 @@ def __init__(self, fname, allow_maxshield=False, preload=False, _check_raw_compatibility(raws) super(Raw, self).__init__( - copy.deepcopy(raws[0].info), False, - [r.first_samp for r in raws], [r.last_samp for r in raws], - [r.filename for r in raws], [r._raw_extras for r in raws], - raws[0].orig_format, None, buffer_size_sec=buffer_size_sec, - verbose=verbose) + copy.deepcopy(raws[0].info), + False, + [r.first_samp for r in raws], + [r.last_samp for r in raws], + [r.filename for r in raws], + [r._raw_extras for r in raws], + raws[0].orig_format, + None, + buffer_size_sec=buffer_size_sec, + verbose=verbose, + ) # combine annotations self.set_annotations(raws[0].annotations, emit_warning=False) # Add annotations for in-data skips for extra in self._raw_extras: - mask = [ent is None for ent in extra['ent']] - start = extra['bounds'][:-1][mask] - stop = extra['bounds'][1:][mask] - 1 - duration = (stop - start + 1.) 
/ self.info['sfreq'] - annot = Annotations(onset=(start / self.info['sfreq']), - duration=duration, - description='BAD_ACQ_SKIP', - orig_time=self.info['meas_date']) + mask = [ent is None for ent in extra["ent"]] + start = extra["bounds"][:-1][mask] + stop = extra["bounds"][1:][mask] - 1 + duration = (stop - start + 1.0) / self.info["sfreq"] + annot = Annotations( + onset=(start / self.info["sfreq"]), + duration=duration, + description="BAD_ACQ_SKIP", + orig_time=self.info["meas_date"], + ) self._annotations += annot @@ -130,27 +152,34 @@ def __init__(self, fname, allow_maxshield=False, preload=False, self._filenames = [_get_fname_rep(fname) for fname in self._filenames] @verbose - def _read_raw_file(self, fname, allow_maxshield, preload, - do_check_ext=True, verbose=None): + def _read_raw_file( + self, fname, allow_maxshield, preload, do_check_ext=True, verbose=None + ): """Read in header information from a raw file.""" - logger.info('Opening raw data file %s...' % fname) + logger.info("Opening raw data file %s..." % fname) # Read in the whole file if preload is on and .fif.gz (saves time) if not _file_like(fname): if do_check_ext: - endings = ('raw.fif', 'raw_sss.fif', 'raw_tsss.fif', - '_meg.fif', '_eeg.fif', '_ieeg.fif') - endings += tuple([f'{e}.gz' for e in endings]) - check_fname(fname, 'raw', endings) + endings = ( + "raw.fif", + "raw_sss.fif", + "raw_tsss.fif", + "_meg.fif", + "_eeg.fif", + "_ieeg.fif", + ) + endings += tuple([f"{e}.gz" for e in endings]) + check_fname(fname, "raw", endings) # filename fname = str(_check_fname(fname, "read", True, "fname")) ext = os.path.splitext(fname)[1].lower() - whole_file = preload if '.gz' in ext else False + whole_file = preload if ".gz" in ext else False del ext else: # file-like if not preload: - raise ValueError('preload must be used with file-like objects') + raise ValueError("preload must be used with file-like objects") whole_file = True fname_rep = _get_fname_rep(fname) ff, tree, _ = fiff_open(fname, preload=whole_file) @@ -164,22 +193,22 @@ def _read_raw_file(self, fname, allow_maxshield, preload, raw_node = dir_tree_find(meas, FIFF.FIFFB_RAW_DATA) if len(raw_node) == 0: raw_node = dir_tree_find(meas, FIFF.FIFFB_CONTINUOUS_DATA) - if (len(raw_node) == 0): + if len(raw_node) == 0: raw_node = dir_tree_find(meas, FIFF.FIFFB_IAS_RAW_DATA) - if (len(raw_node) == 0): - raise ValueError('No raw data in %s' % fname_rep) + if len(raw_node) == 0: + raise ValueError("No raw data in %s" % fname_rep) _check_maxshield(allow_maxshield) with info._unlock(): - info['maxshield'] = True + info["maxshield"] = True del meas if len(raw_node) == 1: raw_node = raw_node[0] # Process the directory - directory = raw_node['directory'] - nent = raw_node['nent'] - nchan = int(info['nchan']) + directory = raw_node["directory"] + nent = raw_node["nent"] + nchan = int(info["nchan"]) first = 0 first_samp = 0 first_skip = 0 @@ -202,11 +231,11 @@ def _read_raw_file(self, fname, allow_maxshield, preload, raw = _RawShell() raw.filename = fname raw.first_samp = first_samp - if info['meas_date'] is None and annotations is not None: + if info["meas_date"] is None and annotations is not None: # we need to adjust annotations.onset as when there is no meas # date set_annotations considers that the origin of time is the # first available sample (ignores first_samp) - annotations.onset -= first_samp / info['sfreq'] + annotations.onset -= first_samp / info["sfreq"] raw.set_annotations(annotations) # Go through the remaining tags in the directory @@ -238,23 +267,24 @@ def 
_read_raw_file(self, fname, allow_maxshield, preload, elif ent.type == FIFF.FIFFT_COMPLEX_DOUBLE: nsamp = ent.size // (16 * nchan) else: - raise ValueError('Cannot handle data buffers of type ' - '%d' % ent.type) + raise ValueError( + "Cannot handle data buffers of type " "%d" % ent.type + ) if orig_format is None: if ent.type == FIFF.FIFFT_DAU_PACK16: - orig_format = 'short' + orig_format = "short" elif ent.type == FIFF.FIFFT_SHORT: - orig_format = 'short' + orig_format = "short" elif ent.type == FIFF.FIFFT_FLOAT: - orig_format = 'single' + orig_format = "single" elif ent.type == FIFF.FIFFT_DOUBLE: - orig_format = 'double' + orig_format = "double" elif ent.type == FIFF.FIFFT_INT: - orig_format = 'int' + orig_format = "int" elif ent.type == FIFF.FIFFT_COMPLEX_FLOAT: - orig_format = 'single' + orig_format = "single" elif ent.type == FIFF.FIFFT_COMPLEX_DOUBLE: - orig_format = 'double' + orig_format = "double" # Do we have an initial skip pending? if first_skip > 0: @@ -264,58 +294,72 @@ def _read_raw_file(self, fname, allow_maxshield, preload, # Do we have a skip pending? if nskip > 0: - raw_extras.append(dict( - ent=None, first=first_samp, nsamp=nskip * nsamp, - last=first_samp + nskip * nsamp - 1)) + raw_extras.append( + dict( + ent=None, + first=first_samp, + nsamp=nskip * nsamp, + last=first_samp + nskip * nsamp - 1, + ) + ) first_samp += nskip * nsamp nskip = 0 # Add a data buffer - raw_extras.append(dict(ent=ent, first=first_samp, - last=first_samp + nsamp - 1, - nsamp=nsamp)) + raw_extras.append( + dict( + ent=ent, + first=first_samp, + last=first_samp + nsamp - 1, + nsamp=nsamp, + ) + ) first_samp += nsamp next_fname = _get_next_fname(fid, fname_rep, tree) # reformat raw_extras to be a dict of list/ndarray rather than # list of dict (faster access) - raw_extras = {key: [r[key] for r in raw_extras] - for key in raw_extras[0]} + raw_extras = {key: [r[key] for r in raw_extras] for key in raw_extras[0]} for key in raw_extras: - if key != 'ent': # dict or None + if key != "ent": # dict or None raw_extras[key] = np.array(raw_extras[key], int) - if not np.array_equal(raw_extras['last'][:-1], - raw_extras['first'][1:] - 1): - raise RuntimeError('FIF file appears to be broken') - bounds = np.cumsum(np.concatenate( - [raw_extras['first'][:1], raw_extras['nsamp']])) - raw_extras['bounds'] = bounds - assert len(raw_extras['bounds']) == len(raw_extras['ent']) + 1 + if not np.array_equal(raw_extras["last"][:-1], raw_extras["first"][1:] - 1): + raise RuntimeError("FIF file appears to be broken") + bounds = np.cumsum( + np.concatenate([raw_extras["first"][:1], raw_extras["nsamp"]]) + ) + raw_extras["bounds"] = bounds + assert len(raw_extras["bounds"]) == len(raw_extras["ent"]) + 1 # store the original buffer size - buffer_size_sec = np.median(raw_extras['nsamp']) / info['sfreq'] - del raw_extras['first'] - del raw_extras['last'] - del raw_extras['nsamp'] + buffer_size_sec = np.median(raw_extras["nsamp"]) / info["sfreq"] + del raw_extras["first"] + del raw_extras["last"] + del raw_extras["nsamp"] raw.last_samp = first_samp - 1 raw.orig_format = orig_format # Add the calibration factors - cals = np.zeros(info['nchan']) - for k in range(info['nchan']): - cals[k] = info['chs'][k]['range'] * info['chs'][k]['cal'] + cals = np.zeros(info["nchan"]) + for k in range(info["nchan"]): + cals[k] = info["chs"][k]["range"] * info["chs"][k]["cal"] raw._cals = cals raw._raw_extras = raw_extras - logger.info(' Range : %d ... %d = %9.3f ... 
%9.3f secs' % ( - raw.first_samp, raw.last_samp, - float(raw.first_samp) / info['sfreq'], - float(raw.last_samp) / info['sfreq'])) + logger.info( + " Range : %d ... %d = %9.3f ... %9.3f secs" + % ( + raw.first_samp, + raw.last_samp, + float(raw.first_samp) / info["sfreq"], + float(raw.last_samp) / info["sfreq"], + ) + ) raw.info = info - logger.info('Ready.') + logger.info("Ready.") return raw, next_fname, buffer_size_sec @@ -326,14 +370,16 @@ def _dtype(self): return self._dtype_ dtype = None for raw_extra, filename in zip(self._raw_extras, self._filenames): - for ent in raw_extra['ent']: + for ent in raw_extra["ent"]: if ent is not None: with _fiff_get_fid(filename) as fid: fid.seek(ent.pos, 0) tag = read_tag_info(fid) if tag is not None: - if tag.type in (FIFF.FIFFT_COMPLEX_FLOAT, - FIFF.FIFFT_COMPLEX_DOUBLE): + if tag.type in ( + FIFF.FIFFT_COMPLEX_FLOAT, + FIFF.FIFFT_COMPLEX_DOUBLE, + ): dtype = np.complex128 else: dtype = np.float64 @@ -342,7 +388,7 @@ def _dtype(self): if dtype is not None: break if dtype is None: - raise RuntimeError('bug in reading') + raise RuntimeError("bug in reading") self._dtype_ = dtype return dtype @@ -350,9 +396,9 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a segment of data from a file.""" n_bad = 0 with _fiff_get_fid(self._filenames[fi]) as fid: - bounds = self._raw_extras[fi]['bounds'] - ents = self._raw_extras[fi]['ent'] - nchan = self._raw_extras[fi]['orig_nchan'] + bounds = self._raw_extras[fi]["bounds"] + ents = self._raw_extras[fi]["ent"] + nchan = self._raw_extras[fi]["orig_nchan"] use = (stop > bounds[:-1]) & (start < bounds[1:]) offset = 0 for ei in np.where(use)[0]: @@ -365,20 +411,30 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): picksamp = last_pick - first_pick # only read data if it exists if ent is not None: - one = read_tag(fid, ent.pos, - shape=(nsamp, nchan), - rlims=(first_pick, last_pick)).data + one = read_tag( + fid, + ent.pos, + shape=(nsamp, nchan), + rlims=(first_pick, last_pick), + ).data try: one.shape = (picksamp, nchan) except AttributeError: # one is None n_bad += picksamp else: - _mult_cal_one(data[:, offset:(offset + picksamp)], - one.T, idx, cals, mult) + _mult_cal_one( + data[:, offset : (offset + picksamp)], + one.T, + idx, + cals, + mult, + ) offset += picksamp if n_bad: - warn(f'FIF raw buffer could not be read, acquisition error ' - f'likely: {n_bad} samples set to zero') + warn( + f"FIF raw buffer could not be read, acquisition error " + f"likely: {n_bad} samples set to zero" + ) assert offset == stop - start def fix_mag_coil_types(self): @@ -410,6 +466,7 @@ def fix_mag_coil_types(self): Therefore the use of mne_fix_mag_coil_types is not mandatory. 
""" from ...channels import fix_mag_coil_types + fix_mag_coil_types(self.info) return self @@ -421,7 +478,7 @@ def acqparser(self): -------- mne.AcqParserFIF """ - if getattr(self, '_acqparser', None) is None: + if getattr(self, "_acqparser", None) is None: self._acqparser = AcqParserFIF(self.info) return self._acqparser @@ -430,18 +487,19 @@ def _get_fname_rep(fname): if not _file_like(fname): return fname else: - return 'File-like' + return "File-like" def _check_entry(first, nent): """Sanity check entries.""" if first >= nent: - raise OSError('Could not read data, perhaps this is a corrupt file') + raise OSError("Could not read data, perhaps this is a corrupt file") @fill_doc -def read_raw_fif(fname, allow_maxshield=False, preload=False, - on_split_missing='raise', verbose=None): +def read_raw_fif( + fname, allow_maxshield=False, preload=False, on_split_missing="raise", verbose=None +): """Reader function for Raw FIF data. Parameters @@ -479,6 +537,10 @@ def read_raw_fif(fname, allow_maxshield=False, preload=False, ``raw.n_times`` parameters but ``raw.first_samp`` and ``raw.first_time`` are updated accordingly. """ - return Raw(fname=fname, allow_maxshield=allow_maxshield, - preload=preload, verbose=verbose, - on_split_missing=on_split_missing) + return Raw( + fname=fname, + allow_maxshield=allow_maxshield, + preload=preload, + verbose=verbose, + on_split_missing=on_split_missing, + ) diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index febd699d2cd..806a7ce4dc4 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -14,24 +14,36 @@ import sys import numpy as np -from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_allclose) +from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_allclose import pytest from mne.datasets import testing from mne.filter import filter_data from mne.io.constants import FIFF -from mne.io import (RawArray, concatenate_raws, read_raw_fif, - match_channel_orders, base) +from mne.io import RawArray, concatenate_raws, read_raw_fif, match_channel_orders, base from mne.io.open import read_tag, read_tag_info from mne.io.tag import _read_tag_header from mne.io.tests.test_raw import _test_concat, _test_raw_reader -from mne import (concatenate_events, find_events, equalize_channels, - compute_proj_raw, pick_types, pick_channels, create_info, - pick_info, make_fixed_length_epochs) -from mne.utils import (requires_pandas, assert_object_equal, _dt_to_stamp, - requires_mne, run_subprocess, _record_warnings, - assert_and_remove_boundary_annot) +from mne import ( + concatenate_events, + find_events, + equalize_channels, + compute_proj_raw, + pick_types, + pick_channels, + create_info, + pick_info, + make_fixed_length_epochs, +) +from mne.utils import ( + requires_pandas, + assert_object_equal, + _dt_to_stamp, + requires_mne, + run_subprocess, + _record_warnings, + assert_and_remove_boundary_annot, +) from mne.annotations import Annotations testing_path = testing.data_path(download=False) @@ -61,28 +73,29 @@ def test_acq_skip(tmp_path): annotations = raw.annotations assert len(annotations) == 3 # there are 3 skips assert_allclose(annotations.onset, [14, 19, 23]) - assert_allclose(annotations.duration, [2., 2., 3.]) # inclusive! 
- data, times = raw.get_data( - picks, reject_by_annotation='omit', return_times=True) - expected_data, expected_times = zip(raw[picks, :2000], - raw[picks, 4000:7000], - raw[picks, 9000:11000], - raw[picks, 14000:17000]) + assert_allclose(annotations.duration, [2.0, 2.0, 3.0]) # inclusive! + data, times = raw.get_data(picks, reject_by_annotation="omit", return_times=True) + expected_data, expected_times = zip( + raw[picks, :2000], + raw[picks, 4000:7000], + raw[picks, 9000:11000], + raw[picks, 14000:17000], + ) expected_times = np.concatenate(list(expected_times), axis=-1) assert_allclose(times, expected_times) expected_data = list(expected_data) assert_allclose(data, np.concatenate(expected_data, axis=-1), atol=1e-22) # Check that acquisition skips are handled properly in filtering - kwargs = dict(l_freq=None, h_freq=50., fir_design='firwin') + kwargs = dict(l_freq=None, h_freq=50.0, fir_design="firwin") raw_filt = raw.copy().filter(picks=picks, **kwargs) for data in expected_data: - filter_data(data, raw.info['sfreq'], copy=False, **kwargs) - data = raw_filt.get_data(picks, reject_by_annotation='omit') + filter_data(data, raw.info["sfreq"], copy=False, **kwargs) + data = raw_filt.get_data(picks, reject_by_annotation="omit") assert_allclose(data, np.concatenate(expected_data, axis=-1), atol=1e-22) # Check that acquisition skips are handled properly during I/O - fname = tmp_path / 'test_raw.fif' + fname = tmp_path / "test_raw.fif" raw.save(fname, fmt=raw.orig_format) # first: file size should not increase much (orig data is missing # 7 of 17 buffers, so if we write them out it should increase the file @@ -100,8 +113,8 @@ def test_acq_skip(tmp_path): with _record_warnings() as w: raw.save(fname, buffer_size_sec=0.5, overwrite=True) assert len(w) == 0 - with pytest.warns(RuntimeWarning, match='did not fit evenly'): - raw.save(fname, buffer_size_sec=2., overwrite=True) + with pytest.warns(RuntimeWarning, match="did not fit evenly"): + raw.save(fname, buffer_size_sec=2.0, overwrite=True) def test_fix_types(): @@ -109,35 +122,34 @@ def test_fix_types(): for fname, change, bads in ( (hp_fif_fname, True, ["MEG0111"]), (test_fif_fname, False, []), - (ctf_fname, False, []) + (ctf_fname, False, []), ): raw = read_raw_fif(fname) raw.info["bads"] = bads - mag_picks = pick_types(raw.info, meg='mag', exclude=[]) + mag_picks = pick_types(raw.info, meg="mag", exclude=[]) other_picks = np.setdiff1d(np.arange(len(raw.ch_names)), mag_picks) # we don't actually have any files suffering from this problem, so # fake it if change: for ii in mag_picks: - raw.info['chs'][ii]['coil_type'] = FIFF.FIFFV_COIL_VV_MAG_T2 - orig_types = np.array([ch['coil_type'] for ch in raw.info['chs']]) + raw.info["chs"][ii]["coil_type"] = FIFF.FIFFV_COIL_VV_MAG_T2 + orig_types = np.array([ch["coil_type"] for ch in raw.info["chs"]]) raw.fix_mag_coil_types() - new_types = np.array([ch['coil_type'] for ch in raw.info['chs']]) + new_types = np.array([ch["coil_type"] for ch in raw.info["chs"]]) if not change: assert_array_equal(orig_types, new_types) else: assert_array_equal(orig_types[other_picks], new_types[other_picks]) - assert ((orig_types[mag_picks] != new_types[mag_picks]).all()) - assert ((new_types[mag_picks] == - FIFF.FIFFV_COIL_VV_MAG_T3).all()) + assert (orig_types[mag_picks] != new_types[mag_picks]).all() + assert (new_types[mag_picks] == FIFF.FIFFV_COIL_VV_MAG_T3).all() def test_concat(tmp_path): """Test RawFIF concatenation.""" # we trim the file to save lots of memory and some time raw = read_raw_fif(test_fif_fname) - 
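# For context, the behavior test_fix_types above exercises: MNE rewrites
# legacy Elekta magnetometer coil codes (e.g. VV_MAG_T2) to the current T3
# definition, leaving all other channels untouched. A minimal sketch (the
# file name is a hypothetical stand-in):

import mne
from mne.io.constants import FIFF

raw = mne.io.read_raw_fif("sample_raw.fif")
raw.fix_mag_coil_types()  # upgrades magnetometer coil_type entries in place
assert all(ch["coil_type"] != FIFF.FIFFV_COIL_VV_MAG_T2 for ch in raw.info["chs"])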
raw.crop(0, 2.) - test_name = tmp_path / 'test_raw.fif' + raw.crop(0, 2.0) + test_name = tmp_path / "test_raw.fif" raw.save(test_name) # now run the standard test _test_concat(partial(read_raw_fif), test_name) @@ -152,7 +164,7 @@ def test_hash_raw(): raw_size = raw._size raw.load_data() raw_load_size = raw._size - assert (raw_size < raw_load_size) + assert raw_size < raw_load_size raw_2 = read_raw_fif(fif_fname).crop(0, 0.5) raw_2.load_data() assert hash(raw) == hash(raw_2) @@ -166,35 +178,35 @@ def test_hash_raw(): @testing.requires_testing_data def test_maxshield(): """Test maxshield warning.""" - with pytest.warns(RuntimeWarning, match='Internal Active Shielding') as w: + with pytest.warns(RuntimeWarning, match="Internal Active Shielding") as w: read_raw_fif(ms_fname, allow_maxshield=True) - assert ('test_raw_fiff.py' in w[0].filename) + assert "test_raw_fiff.py" in w[0].filename @testing.requires_testing_data def test_subject_info(tmp_path): """Test reading subject information.""" raw = read_raw_fif(fif_fname).crop(0, 1) - assert (raw.info['subject_info'] is None) + assert raw.info["subject_info"] is None # fake some subject data - keys = ['id', 'his_id', 'last_name', 'first_name', 'birthday', 'sex', - 'hand'] - vals = [1, 'foobar', 'bar', 'foo', (1901, 2, 3), 0, 1] + keys = ["id", "his_id", "last_name", "first_name", "birthday", "sex", "hand"] + vals = [1, "foobar", "bar", "foo", (1901, 2, 3), 0, 1] subject_info = dict() for key, val in zip(keys, vals): subject_info[key] = val - raw.info['subject_info'] = subject_info - out_fname = tmp_path / 'test_subj_info_raw.fif' + raw.info["subject_info"] = subject_info + out_fname = tmp_path / "test_subj_info_raw.fif" raw.save(out_fname, overwrite=True) raw_read = read_raw_fif(out_fname) for key in keys: - assert subject_info[key] == raw_read.info['subject_info'][key] - assert raw.info['meas_date'] == raw_read.info['meas_date'] + assert subject_info[key] == raw_read.info["subject_info"][key] + assert raw.info["meas_date"] == raw_read.info["meas_date"] - for key in ['secs', 'usecs', 'version']: - assert raw.info['meas_id'][key] == raw_read.info['meas_id'][key] - assert_array_equal(raw.info['meas_id']['machid'], - raw_read.info['meas_id']['machid']) + for key in ["secs", "usecs", "version"]: + assert raw.info["meas_id"][key] == raw_read.info["meas_id"][key] + assert_array_equal( + raw.info["meas_id"]["machid"], raw_read.info["meas_id"]["machid"] + ) @testing.requires_testing_data @@ -210,13 +222,13 @@ def test_copy_append(): @testing.requires_testing_data def test_output_formats(tmp_path): """Test saving and loading raw data using multiple formats.""" - formats = ['short', 'int', 'single', 'double'] + formats = ["short", "int", "single", "double"] tols = [1e-4, 1e-7, 1e-7, 1e-15] # let's fake a raw file with different formats raw = read_raw_fif(test_fif_fname).crop(0, 1) - temp_file = tmp_path / 'raw.fif' + temp_file = tmp_path / "raw.fif" for ii, (fmt, tol) in enumerate(zip(formats, tols)): # Let's test the overwriting error throwing while we're at it if ii > 0: @@ -244,10 +256,10 @@ def test_multiple_files(tmp_path): raw = read_raw_fif(fif_fname).crop(0, 10) raw.load_data() raw.load_data() # test no operation - split_size = 3. 
# in seconds - sfreq = raw.info['sfreq'] - nsamp = (raw.last_samp - raw.first_samp) - tmins = np.round(np.arange(0., nsamp, split_size * sfreq)) + split_size = 3.0 # in seconds + sfreq = raw.info["sfreq"] + nsamp = raw.last_samp - raw.first_samp + tmins = np.round(np.arange(0.0, nsamp, split_size * sfreq)) tmaxs = np.concatenate((tmins[1:] - 1, [nsamp])) tmaxs /= sfreq tmins /= sfreq @@ -256,20 +268,20 @@ def test_multiple_files(tmp_path): # going in reverse order so the last fname is the first file (need later) raws = [None] * len(tmins) for ri in range(len(tmins) - 1, -1, -1): - fname = tmp_path / ('test_raw_split-%d_raw.fif' % ri) + fname = tmp_path / ("test_raw_split-%d_raw.fif" % ri) raw.save(fname, tmin=tmins[ri], tmax=tmaxs[ri]) raws[ri] = read_raw_fif(fname) - assert (len(raws[ri].times) == - int(round((tmaxs[ri] - tmins[ri]) * - raw.info['sfreq'])) + 1) # + 1 b/c inclusive - events = [find_events(r, stim_channel='STI 014') for r in raws] + assert ( + len(raws[ri].times) + == int(round((tmaxs[ri] - tmins[ri]) * raw.info["sfreq"])) + 1 + ) # + 1 b/c inclusive + events = [find_events(r, stim_channel="STI 014") for r in raws] last_samps = [r.last_samp for r in raws] first_samps = [r.first_samp for r in raws] # test concatenation of split file pytest.raises(ValueError, concatenate_raws, raws, True, events[1:]) - all_raw_1, events1 = concatenate_raws(raws, preload=False, - events_list=events) + all_raw_1, events1 = concatenate_raws(raws, preload=False, events_list=events) assert_allclose(all_raw_1.times, raw.times) assert raw.first_samp == all_raw_1.first_samp assert raw.last_samp == all_raw_1.last_samp @@ -280,7 +292,7 @@ def test_multiple_files(tmp_path): # test proper event treatment for split files events2 = concatenate_events(events, first_samps, last_samps) - events3 = find_events(all_raw_2, stim_channel='STI 014') + events3 = find_events(all_raw_2, stim_channel="STI 014") assert_array_equal(events1, events2) assert_array_equal(events1, events3) @@ -292,17 +304,17 @@ def test_multiple_files(tmp_path): # add potentially problematic points times.extend([n_times - 1, n_times, 2 * n_times - 1]) - raw_combo0 = concatenate_raws([read_raw_fif(f) - for f in [fif_fname, fif_fname]], - preload=True) + raw_combo0 = concatenate_raws( + [read_raw_fif(f) for f in [fif_fname, fif_fname]], preload=True + ) _compare_combo(raw, raw_combo0, times, n_times) - raw_combo = concatenate_raws([read_raw_fif(f) - for f in [fif_fname, fif_fname]], - preload=False) + raw_combo = concatenate_raws( + [read_raw_fif(f) for f in [fif_fname, fif_fname]], preload=False + ) _compare_combo(raw, raw_combo, times, n_times) - raw_combo = concatenate_raws([read_raw_fif(f) - for f in [fif_fname, fif_fname]], - preload='memmap8.dat') + raw_combo = concatenate_raws( + [read_raw_fif(f) for f in [fif_fname, fif_fname]], preload="memmap8.dat" + ) _compare_combo(raw, raw_combo, times, n_times) assert raw[:, :][0].shape[1] * 2 == raw_combo0[:, :][0].shape[1] assert raw_combo0[:, :][0].shape[1] == raw_combo0.n_times @@ -310,52 +322,63 @@ def test_multiple_files(tmp_path): # with all data preloaded, result should be preloaded raw_combo = read_raw_fif(fif_fname, preload=True) raw_combo.append(read_raw_fif(fif_fname, preload=True)) - assert (raw_combo.preload is True) + assert raw_combo.preload is True assert raw_combo.n_times == raw_combo._data.shape[1] _compare_combo(raw, raw_combo, times, n_times) # with any data not preloaded, don't set result as preloaded - raw_combo = concatenate_raws([read_raw_fif(fif_fname, preload=True), - 
read_raw_fif(fif_fname, preload=False)]) - assert (raw_combo.preload is False) - assert_array_equal(find_events(raw_combo, stim_channel='STI 014'), - find_events(raw_combo0, stim_channel='STI 014')) + raw_combo = concatenate_raws( + [read_raw_fif(fif_fname, preload=True), read_raw_fif(fif_fname, preload=False)] + ) + assert raw_combo.preload is False + assert_array_equal( + find_events(raw_combo, stim_channel="STI 014"), + find_events(raw_combo0, stim_channel="STI 014"), + ) _compare_combo(raw, raw_combo, times, n_times) # user should be able to force data to be preloaded upon concat - raw_combo = concatenate_raws([read_raw_fif(fif_fname, preload=False), - read_raw_fif(fif_fname, preload=True)], - preload=True) - assert (raw_combo.preload is True) + raw_combo = concatenate_raws( + [read_raw_fif(fif_fname, preload=False), read_raw_fif(fif_fname, preload=True)], + preload=True, + ) + assert raw_combo.preload is True _compare_combo(raw, raw_combo, times, n_times) - raw_combo = concatenate_raws([read_raw_fif(fif_fname, preload=False), - read_raw_fif(fif_fname, preload=True)], - preload='memmap3.dat') + raw_combo = concatenate_raws( + [read_raw_fif(fif_fname, preload=False), read_raw_fif(fif_fname, preload=True)], + preload="memmap3.dat", + ) _compare_combo(raw, raw_combo, times, n_times) - raw_combo = concatenate_raws([ - read_raw_fif(fif_fname, preload=True), - read_raw_fif(fif_fname, preload=True)], preload='memmap4.dat') + raw_combo = concatenate_raws( + [read_raw_fif(fif_fname, preload=True), read_raw_fif(fif_fname, preload=True)], + preload="memmap4.dat", + ) _compare_combo(raw, raw_combo, times, n_times) - raw_combo = concatenate_raws([ - read_raw_fif(fif_fname, preload=False), - read_raw_fif(fif_fname, preload=False)], preload='memmap5.dat') + raw_combo = concatenate_raws( + [ + read_raw_fif(fif_fname, preload=False), + read_raw_fif(fif_fname, preload=False), + ], + preload="memmap5.dat", + ) _compare_combo(raw, raw_combo, times, n_times) # verify that combining raws with different projectors throws an exception raw.add_proj([], remove_existing=True) - pytest.raises(ValueError, raw.append, - read_raw_fif(fif_fname, preload=True)) + pytest.raises(ValueError, raw.append, read_raw_fif(fif_fname, preload=True)) # now test event treatment for concatenated raw files - events = [find_events(raw, stim_channel='STI 014'), - find_events(raw, stim_channel='STI 014')] + events = [ + find_events(raw, stim_channel="STI 014"), + find_events(raw, stim_channel="STI 014"), + ] last_samps = [raw.last_samp, raw.last_samp] first_samps = [raw.first_samp, raw.first_samp] events = concatenate_events(events, first_samps, last_samps) - events2 = find_events(raw_combo0, stim_channel='STI 014') + events2 = find_events(raw_combo0, stim_channel="STI 014") assert_array_equal(events, events2) # check out the len method @@ -364,21 +387,21 @@ def test_multiple_files(tmp_path): @testing.requires_testing_data -@pytest.mark.parametrize('on_mismatch', ('ignore', 'warn', 'raise')) +@pytest.mark.parametrize("on_mismatch", ("ignore", "warn", "raise")) def test_concatenate_raws(on_mismatch): """Test error handling during raw concatenation.""" raw = read_raw_fif(fif_fname).crop(0, 10) raws = [raw, raw.copy()] - raws[1].info['dev_head_t']['trans'] += 0.1 + raws[1].info["dev_head_t"]["trans"] += 0.1 kws = dict(raws=raws, on_mismatch=on_mismatch) - if on_mismatch == 'ignore': + if on_mismatch == "ignore": concatenate_raws(**kws) - elif on_mismatch == 'warn': - with pytest.warns(RuntimeWarning, match='different head positions'): + 
elif on_mismatch == "warn": + with pytest.warns(RuntimeWarning, match="different head positions"): concatenate_raws(**kws) - elif on_mismatch == 'raise': - with pytest.raises(ValueError, match='different head positions'): + elif on_mismatch == "raise": + with pytest.raises(ValueError, match="different head positions"): concatenate_raws(**kws) @@ -464,67 +487,82 @@ def test_concatenate_raws_order(): @testing.requires_testing_data -@pytest.mark.parametrize('mod', ( - 'meg', - pytest.param('raw', marks=[ - pytest.mark.filterwarnings( - 'ignore:.*naming conventions.*:RuntimeWarning'), - pytest.mark.slowtest]), -)) +@pytest.mark.parametrize( + "mod", + ( + "meg", + pytest.param( + "raw", + marks=[ + pytest.mark.filterwarnings( + "ignore:.*naming conventions.*:RuntimeWarning" + ), + pytest.mark.slowtest, + ], + ), + ), +) def test_split_files(tmp_path, mod, monkeypatch): """Test writing and reading of split raw files.""" raw_1 = read_raw_fif(fif_fname, preload=True) # Test a very close corner case - assert_allclose(raw_1.buffer_size_sec, 10., atol=1e-2) # samp rate - split_fname = tmp_path / f'split_raw_{mod}.fif' + assert_allclose(raw_1.buffer_size_sec, 10.0, atol=1e-2) # samp rate + split_fname = tmp_path / f"split_raw_{mod}.fif" # intended filenames - split_fname_elekta_part2 = tmp_path / f'split_raw_{mod}-1.fif' - split_fname_bids_part1 = tmp_path / f'split_raw_split-01_{mod}.fif' - split_fname_bids_part2 = tmp_path / f'split_raw_split-02_{mod}.fif' - raw_1.set_annotations(Annotations([2.], [5.5], 'test')) + split_fname_elekta_part2 = tmp_path / f"split_raw_{mod}-1.fif" + split_fname_bids_part1 = tmp_path / f"split_raw_split-01_{mod}.fif" + split_fname_bids_part2 = tmp_path / f"split_raw_split-02_{mod}.fif" + raw_1.set_annotations(Annotations([2.0], [5.5], "test")) # Check that if BIDS is used and no split is needed it defaults to # simple writing without _split- entity. 
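
The parametrize hunk above is typical black output: once the decorator call
exceeds the 88-column limit, each argument moves onto its own line, and the
trailing comma black adds keeps the structure exploded on later runs (the
"magic trailing comma"). A minimal sketch of the same pattern, with
hypothetical names ("mode", "fast", "slow"); slowtest is assumed to be a
registered custom marker, as in this test suite:

    import pytest

    @pytest.mark.parametrize(
        "mode",
        (
            "fast",
            pytest.param("slow", marks=pytest.mark.slowtest),
        ),
    )
    def test_mode(mode):
        # Runs once per parameter; black preserves the exploded layout above.
        assert mode in ("fast", "slow")
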
- raw_1.save(split_fname, split_naming='bids', verbose=True) + raw_1.save(split_fname, split_naming="bids", verbose=True) assert split_fname.is_file() assert not split_fname_bids_part1.is_file() - for split_naming in ('neuromag', 'bids'): - with pytest.raises(FileExistsError, match='Destination file'): + for split_naming in ("neuromag", "bids"): + with pytest.raises(FileExistsError, match="Destination file"): raw_1.save(split_fname, split_naming=split_naming, verbose=True) os.remove(split_fname) - with open(split_fname_bids_part1, 'w'): + with open(split_fname_bids_part1, "w"): pass - with pytest.raises(FileExistsError, match='Destination file'): - raw_1.save(split_fname, split_naming='bids', verbose=True) + with pytest.raises(FileExistsError, match="Destination file"): + raw_1.save(split_fname, split_naming="bids", verbose=True) assert not split_fname.is_file() - raw_1.save(split_fname, split_naming='neuromag', verbose=True) # okay + raw_1.save(split_fname, split_naming="neuromag", verbose=True) # okay os.remove(split_fname) os.remove(split_fname_bids_part1) - raw_1.save(split_fname, buffer_size_sec=1.0, split_size='10MB', - verbose=True) + raw_1.save(split_fname, buffer_size_sec=1.0, split_size="10MB", verbose=True) # check that the filenames match the intended pattern assert split_fname.is_file() assert split_fname_elekta_part2.is_file() # check that filenames are being formatted correctly for BIDS - raw_1.save(split_fname, buffer_size_sec=1.0, split_size='10MB', - split_naming='bids', overwrite=True, verbose=True) + raw_1.save( + split_fname, + buffer_size_sec=1.0, + split_size="10MB", + split_naming="bids", + overwrite=True, + verbose=True, + ) assert split_fname_bids_part1.is_file() assert split_fname_bids_part2.is_file() - annot = Annotations(np.arange(20), np.ones((20,)), 'test') + annot = Annotations(np.arange(20), np.ones((20,)), "test") raw_1.set_annotations(annot) split_fname = tmp_path / "split_raw.fif" - raw_1.save(split_fname, buffer_size_sec=1.0, split_size='10MB') + raw_1.save(split_fname, buffer_size_sec=1.0, split_size="10MB") raw_2 = read_raw_fif(split_fname) - assert_allclose(raw_2.buffer_size_sec, 1., atol=1e-2) # samp rate + assert_allclose(raw_2.buffer_size_sec, 1.0, atol=1e-2) # samp rate assert_allclose(raw_1.annotations.onset, raw_2.annotations.onset) - assert_allclose(raw_1.annotations.duration, raw_2.annotations.duration, - rtol=0.001 / raw_2.info['sfreq']) - assert_array_equal(raw_1.annotations.description, - raw_2.annotations.description) + assert_allclose( + raw_1.annotations.duration, + raw_2.annotations.duration, + rtol=0.001 / raw_2.info["sfreq"], + ) + assert_array_equal(raw_1.annotations.description, raw_2.annotations.description) data_1, times_1 = raw_1[:, :] data_2, times_2 = raw_2[:, :] @@ -538,11 +576,11 @@ def test_split_files(tmp_path, mod, monkeypatch): del raw_bids # split missing behaviors os.remove(split_fname_bids_part2) - with pytest.raises(ValueError, match='manually renamed'): - read_raw_fif(split_fname_bids_part1, on_split_missing='raise') - with pytest.warns(RuntimeWarning, match='Split raw file detected'): - read_raw_fif(split_fname_bids_part1, on_split_missing='warn') - read_raw_fif(split_fname_bids_part1, on_split_missing='ignore') + with pytest.raises(ValueError, match="manually renamed"): + read_raw_fif(split_fname_bids_part1, on_split_missing="raise") + with pytest.warns(RuntimeWarning, match="Split raw file detected"): + read_raw_fif(split_fname_bids_part1, on_split_missing="warn") + read_raw_fif(split_fname_bids_part1, 
on_split_missing="ignore") # test the case where we only end up with one buffer to write # (GH#3210). These tests rely on writing meas info and annotations @@ -550,21 +588,24 @@ def test_split_files(tmp_path, mod, monkeypatch): # somehow, the numbers below for e.g. split_size might need to be # adjusted. raw_crop = raw_1.copy().crop(0, 5) - raw_crop.set_annotations(Annotations([2.], [5.5], 'test'), - emit_warning=False) - with pytest.raises(ValueError, - match='after writing measurement information'): - raw_crop.save(split_fname, split_size='1MB', # too small a size - buffer_size_sec=1., overwrite=True) - with pytest.raises(ValueError, - match='too large for the given split size'): - raw_crop.save(split_fname, - split_size=3003000, # still too small, now after Info - buffer_size_sec=1., overwrite=True) + raw_crop.set_annotations(Annotations([2.0], [5.5], "test"), emit_warning=False) + with pytest.raises(ValueError, match="after writing measurement information"): + raw_crop.save( + split_fname, + split_size="1MB", # too small a size + buffer_size_sec=1.0, + overwrite=True, + ) + with pytest.raises(ValueError, match="too large for the given split size"): + raw_crop.save( + split_fname, + split_size=3003000, # still too small, now after Info + buffer_size_sec=1.0, + overwrite=True, + ) # just barely big enough here; the right size to write exactly one buffer # at a time so we hit GH#3210 if we aren't careful - raw_crop.save(split_fname, split_size='4.5MB', - buffer_size_sec=1., overwrite=True) + raw_crop.save(split_fname, split_size="4.5MB", buffer_size_sec=1.0, overwrite=True) raw_read = read_raw_fif(split_fname) assert_allclose(raw_crop[:][0], raw_read[:][0], atol=1e-20) @@ -572,46 +613,45 @@ def test_split_files(tmp_path, mod, monkeypatch): # 1 buffer required raw_crop = raw_1.copy().crop(0, 1) - raw_crop.save(split_fname, buffer_size_sec=1., overwrite=True) + raw_crop.save(split_fname, buffer_size_sec=1.0, overwrite=True) raw_read = read_raw_fif(split_fname) - assert_array_equal(np.diff(raw_read._raw_extras[0]['bounds']), (301,)) + assert_array_equal(np.diff(raw_read._raw_extras[0]["bounds"]), (301,)) assert_allclose(raw_crop[:][0], raw_read[:][0]) # 2 buffers required raw_crop.save(split_fname, buffer_size_sec=0.5, overwrite=True) raw_read = read_raw_fif(split_fname) - assert_array_equal(np.diff(raw_read._raw_extras[0]['bounds']), (151, 150)) + assert_array_equal(np.diff(raw_read._raw_extras[0]["bounds"]), (151, 150)) assert_allclose(raw_crop[:][0], raw_read[:][0]) # 2 buffers required - raw_crop.save(split_fname, - buffer_size_sec=1. - 1.01 / raw_crop.info['sfreq'], - overwrite=True) + raw_crop.save( + split_fname, buffer_size_sec=1.0 - 1.01 / raw_crop.info["sfreq"], overwrite=True + ) raw_read = read_raw_fif(split_fname) - assert_array_equal(np.diff(raw_read._raw_extras[0]['bounds']), (300, 1)) + assert_array_equal(np.diff(raw_read._raw_extras[0]["bounds"]), (300, 1)) assert_allclose(raw_crop[:][0], raw_read[:][0]) - raw_crop.save(split_fname, - buffer_size_sec=1. 
- 2.01 / raw_crop.info['sfreq'], - overwrite=True) + raw_crop.save( + split_fname, buffer_size_sec=1.0 - 2.01 / raw_crop.info["sfreq"], overwrite=True + ) raw_read = read_raw_fif(split_fname) - assert_array_equal(np.diff(raw_read._raw_extras[0]['bounds']), (299, 2)) + assert_array_equal(np.diff(raw_read._raw_extras[0]["bounds"]), (299, 2)) assert_allclose(raw_crop[:][0], raw_read[:][0]) # proper ending assert tmp_path.is_dir() - with pytest.raises(ValueError, match='must end with an underscore'): - raw_crop.save( - tmp_path / 'test.fif', split_naming='bids', verbose='error') + with pytest.raises(ValueError, match="must end with an underscore"): + raw_crop.save(tmp_path / "test.fif", split_naming="bids", verbose="error") # reserved file is deleted - fname = tmp_path / 'test_raw.fif' - monkeypatch.setattr(base, '_write_raw_fid', _err) - with pytest.raises(RuntimeError, match='Killed mid-write'): - raw_1.save(fname, split_size='10MB', split_naming='bids') + fname = tmp_path / "test_raw.fif" + monkeypatch.setattr(base, "_write_raw_fid", _err) + with pytest.raises(RuntimeError, match="Killed mid-write"): + raw_1.save(fname, split_size="10MB", split_naming="bids") assert fname.is_file() assert not (tmp_path / "test_split-01_raw.fif").is_file() def _err(*args, **kwargs): - raise RuntimeError('Killed mid-write') + raise RuntimeError("Killed mid-write") def _no_write_file_name(fid, kind, data): @@ -621,12 +661,11 @@ def _no_write_file_name(fid, kind, data): def test_split_numbers(tmp_path, monkeypatch): """Test handling of split files using numbers instead of names.""" - monkeypatch.setattr(base, 'write_string', _no_write_file_name) - raw = read_raw_fif(test_fif_fname).pick('eeg') + monkeypatch.setattr(base, "write_string", _no_write_file_name) + raw = read_raw_fif(test_fif_fname).pick("eeg") # gh-8339 - dashes_fname = tmp_path / 'sub-1_ses-2_task-3_raw.fif' - raw.save(dashes_fname, split_size='5MB', - buffer_size_sec=1.) 
+ dashes_fname = tmp_path / "sub-1_ses-2_task-3_raw.fif" + raw.save(dashes_fname, split_size="5MB", buffer_size_sec=1.0) assert dashes_fname.is_file() next_fname = Path(str(dashes_fname)[:-4] + "-1.fif") assert next_fname.is_file() @@ -639,37 +678,37 @@ def test_load_bad_channels(tmp_path): """Test reading/writing of bad channels.""" # Load correctly marked file (manually done in mne_process_raw) raw_marked = read_raw_fif(fif_bad_marked_fname) - correct_bads = raw_marked.info['bads'] + correct_bads = raw_marked.info["bads"] raw = read_raw_fif(test_fif_fname) # Make sure it starts clean - assert_array_equal(raw.info['bads'], []) + assert_array_equal(raw.info["bads"], []) # Test normal case raw.load_bad_channels(bad_file_works) # Write it out, read it in, and check - raw.save(tmp_path / 'foo_raw.fif') - raw_new = read_raw_fif(tmp_path / 'foo_raw.fif') - assert correct_bads == raw_new.info['bads'] + raw.save(tmp_path / "foo_raw.fif") + raw_new = read_raw_fif(tmp_path / "foo_raw.fif") + assert correct_bads == raw_new.info["bads"] # Reset it - raw.info['bads'] = [] + raw.info["bads"] = [] # Test bad case pytest.raises(ValueError, raw.load_bad_channels, bad_file_wrong) # Test forcing the bad case - with pytest.warns(RuntimeWarning, match='1 bad channel'): + with pytest.warns(RuntimeWarning, match="1 bad channel"): raw.load_bad_channels(bad_file_wrong, force=True) # write it out, read it in, and check - raw.save(tmp_path / 'foo_raw.fif', overwrite=True) - raw_new = read_raw_fif(tmp_path / 'foo_raw.fif') - assert correct_bads == raw_new.info['bads'] + raw.save(tmp_path / "foo_raw.fif", overwrite=True) + raw_new = read_raw_fif(tmp_path / "foo_raw.fif") + assert correct_bads == raw_new.info["bads"] # Check that bad channels are cleared raw.load_bad_channels(None) - raw.save(tmp_path / 'foo_raw.fif', overwrite=True) - raw_new = read_raw_fif(tmp_path / 'foo_raw.fif') - assert raw_new.info['bads'] == [] + raw.save(tmp_path / "foo_raw.fif", overwrite=True) + raw_new = read_raw_fif(tmp_path / "foo_raw.fif") + assert raw_new.info["bads"] == [] @pytest.mark.slowtest @@ -678,15 +717,15 @@ def test_io_raw(tmp_path): """Test IO for raw data (Neuromag).""" rng = np.random.RandomState(0) # test unicode io - for chars in [u'äöé', 'a']: + for chars in ["äöé", "a"]: with read_raw_fif(fif_fname) as r: - assert ('Raw' in repr(r)) - assert (fif_fname.name in repr(r)) - r.info['description'] = chars - temp_file = tmp_path / 'raw.fif' + assert "Raw" in repr(r) + assert fif_fname.name in repr(r) + r.info["description"] = chars + temp_file = tmp_path / "raw.fif" r.save(temp_file, overwrite=True) with read_raw_fif(temp_file) as r2: - desc2 = r2.info['description'] + desc2 = r2.info["description"] assert desc2 == chars # Let's construct a simple test for IO first @@ -696,7 +735,7 @@ def test_io_raw(tmp_path): data = rng.randn(raw._data.shape[0], raw._data.shape[1]) raw._data[:, :] = data # save it somewhere - fname = tmp_path / 'test_copy_raw.fif' + fname = tmp_path / "test_copy_raw.fif" raw.save(fname, buffer_size_sec=1.0) # read it in, make sure the whole thing matches raw = read_raw_fif(fname) @@ -707,48 +746,63 @@ def test_io_raw(tmp_path): assert_allclose(data[:, sl], raw[:, sl][0], rtol=1e-6, atol=1e-20) -@pytest.mark.parametrize('fname_in, fname_out', [ - (test_fif_fname, 'raw.fif'), - pytest.param(test_fif_gz_fname, 'raw.fif.gz', marks=pytest.mark.slowtest), - (ctf_fname, 'raw.fif')]) +@pytest.mark.parametrize( + "fname_in, fname_out", + [ + (test_fif_fname, "raw.fif"), + pytest.param(test_fif_gz_fname, 
"raw.fif.gz", marks=pytest.mark.slowtest), + (ctf_fname, "raw.fif"), + ], +) def test_io_raw_additional(fname_in, fname_out, tmp_path): """Test IO for raw data (Neuromag + CTF + gz).""" fname_out = tmp_path / fname_out raw = read_raw_fif(fname_in).crop(0, 2) - nchan = raw.info['nchan'] - ch_names = raw.info['ch_names'] - meg_channels_idx = [k for k in range(nchan) - if ch_names[k][0] == 'M'] + nchan = raw.info["nchan"] + ch_names = raw.info["ch_names"] + meg_channels_idx = [k for k in range(nchan) if ch_names[k][0] == "M"] n_channels = 100 meg_channels_idx = meg_channels_idx[:n_channels] start, stop = raw.time_as_index([0, 5], use_rounding=True) - data, times = raw[meg_channels_idx, start:(stop + 1)] + data, times = raw[meg_channels_idx, start : (stop + 1)] meg_ch_names = [ch_names[k] for k in meg_channels_idx] # Set up pick list: MEG + STI 014 - bad channels - include = ['STI 014'] + include = ["STI 014"] include += meg_ch_names - picks = pick_types(raw.info, meg=True, eeg=False, stim=True, - misc=True, ref_meg=True, include=include, - exclude='bads') + picks = pick_types( + raw.info, + meg=True, + eeg=False, + stim=True, + misc=True, + ref_meg=True, + include=include, + exclude="bads", + ) # Writing with drop_small_buffer True - raw.save(fname_out, picks, tmin=0, tmax=4, buffer_size_sec=3, - drop_small_buffer=True, overwrite=True) + raw.save( + fname_out, + picks, + tmin=0, + tmax=4, + buffer_size_sec=3, + drop_small_buffer=True, + overwrite=True, + ) raw2 = read_raw_fif(fname_out) sel = pick_channels(raw2.ch_names, meg_ch_names) data2, times2 = raw2[sel, :] - assert (times2.max() <= 3) + assert times2.max() <= 3 # Writing raw.save(fname_out, picks, tmin=0, tmax=5, overwrite=True) - if fname_in in ( - fif_fname, fif_fname.with_suffix(fif_fname.suffix + ".gz") - ): - assert len(raw.info['dig']) == 146 + if fname_in in (fif_fname, fif_fname.with_suffix(fif_fname.suffix + ".gz")): + assert len(raw.info["dig"]) == 146 raw2 = read_raw_fif(fname_out) @@ -757,44 +811,41 @@ def test_io_raw_additional(fname_in, fname_out, tmp_path): assert_allclose(data, data2, rtol=1e-6, atol=1e-20) assert_allclose(times, times2) - assert_allclose(raw.info['sfreq'], raw2.info['sfreq'], rtol=1e-5) + assert_allclose(raw.info["sfreq"], raw2.info["sfreq"], rtol=1e-5) # check transformations - for trans in ['dev_head_t', 'dev_ctf_t', 'ctf_head_t']: + for trans in ["dev_head_t", "dev_ctf_t", "ctf_head_t"]: if raw.info[trans] is None: - assert (raw2.info[trans] is None) + assert raw2.info[trans] is None else: - assert_array_equal(raw.info[trans]['trans'], - raw2.info[trans]['trans']) + assert_array_equal(raw.info[trans]["trans"], raw2.info[trans]["trans"]) # check transformation 'from' and 'to' - if trans.startswith('dev'): + if trans.startswith("dev"): from_id = FIFF.FIFFV_COORD_DEVICE else: from_id = FIFF.FIFFV_MNE_COORD_CTF_HEAD - if trans[4:8] == 'head': + if trans[4:8] == "head": to_id = FIFF.FIFFV_COORD_HEAD else: to_id = FIFF.FIFFV_MNE_COORD_CTF_HEAD for raw_ in [raw, raw2]: - assert raw_.info[trans]['from'] == from_id - assert raw_.info[trans]['to'] == to_id + assert raw_.info[trans]["from"] == from_id + assert raw_.info[trans]["to"] == to_id - if fname_in in ( - fif_fname, fif_fname.with_suffix(fif_fname.suffix + ".gz") - ): - assert_allclose(raw.info['dig'][0]['r'], raw2.info['dig'][0]['r']) + if fname_in in (fif_fname, fif_fname.with_suffix(fif_fname.suffix + ".gz")): + assert_allclose(raw.info["dig"][0]["r"], raw2.info["dig"][0]["r"]) # test warnings on bad filenames - raw_badname = tmp_path / 
'test-bad-name.fif.gz' - with pytest.warns(RuntimeWarning, match='raw.fif'): + raw_badname = tmp_path / "test-bad-name.fif.gz" + with pytest.warns(RuntimeWarning, match="raw.fif"): raw.save(raw_badname) - with pytest.warns(RuntimeWarning, match='raw.fif'): + with pytest.warns(RuntimeWarning, match="raw.fif"): read_raw_fif(raw_badname) @testing.requires_testing_data -@pytest.mark.parametrize('dtype', ('complex128', 'complex64')) +@pytest.mark.parametrize("dtype", ("complex128", "complex64")) def test_io_complex(tmp_path, dtype): """Test IO with complex data types.""" rng = np.random.RandomState(0) @@ -805,14 +856,14 @@ def test_io_complex(tmp_path, dtype): raw_cp = raw.copy() raw_cp._data = np.array(raw_cp._data, dtype) raw_cp._data += imag_rand - with pytest.warns(RuntimeWarning, match='Saving .* complex data.'): - raw_cp.save(tmp_path / 'raw.fif', overwrite=True) + with pytest.warns(RuntimeWarning, match="Saving .* complex data."): + raw_cp.save(tmp_path / "raw.fif", overwrite=True) - raw2 = read_raw_fif(tmp_path / 'raw.fif') + raw2 = read_raw_fif(tmp_path / "raw.fif") raw2_data, _ = raw2[:] assert_allclose(raw2_data, raw_cp._data) # with preloading - raw2 = read_raw_fif(tmp_path / 'raw.fif', preload=True) + raw2 = read_raw_fif(tmp_path / "raw.fif", preload=True) raw2_data, _ = raw2[:] assert_allclose(raw2_data, raw_cp._data) assert_allclose(data_orig, raw_cp._data.real) @@ -821,7 +872,7 @@ def test_io_complex(tmp_path, dtype): @testing.requires_testing_data def test_getitem(): """Test getitem/indexing of Raw.""" - for preload in [False, True, 'memmap.dat']: + for preload in [False, True, "memmap.dat"]: raw = read_raw_fif(fif_fname, preload=preload) data, times = raw[0, :] data1, times1 = raw[0] @@ -836,11 +887,11 @@ def test_getitem(): assert_array_equal(times, times1) assert_array_equal(raw[raw.ch_names[0]][0][0], raw[0][0][0]) assert_array_equal( - raw[-10:-1, :][0], - raw[len(raw.ch_names) - 10:len(raw.ch_names) - 1, :][0]) - with pytest.raises(ValueError, match='No appropriate channels'): + raw[-10:-1, :][0], raw[len(raw.ch_names) - 10 : len(raw.ch_names) - 1, :][0] + ) + with pytest.raises(ValueError, match="No appropriate channels"): raw[slice(-len(raw.ch_names) - 1), slice(None)] - with pytest.raises(ValueError, match='must be'): + with pytest.raises(ValueError, match="must be"): raw[-1000] @@ -851,7 +902,7 @@ def test_proj(tmp_path): raw = read_raw_fif(fif_fname, preload=False) if proj: raw.apply_proj() - assert (all(p['active'] == proj for p in raw.info['projs'])) + assert all(p["active"] == proj for p in raw.info["projs"]) data, times = raw[0:2, :] data1, times1 = raw[0:2] @@ -860,19 +911,18 @@ def test_proj(tmp_path): # test adding / deleting proj if proj: - pytest.raises(ValueError, raw.add_proj, [], - {'remove_existing': True}) + pytest.raises(ValueError, raw.add_proj, [], {"remove_existing": True}) pytest.raises(ValueError, raw.del_proj, 0) else: - projs = deepcopy(raw.info['projs']) - n_proj = len(raw.info['projs']) + projs = deepcopy(raw.info["projs"]) + n_proj = len(raw.info["projs"]) raw.del_proj(0) - assert len(raw.info['projs']) == n_proj - 1 + assert len(raw.info["projs"]) == n_proj - 1 raw.add_proj(projs, remove_existing=False) # Test that already existing projections are not added. 
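
A recurring change through these hunks is black removing parentheses that wrap
an entire assert condition, as in assert (raw_size < raw_load_size) becoming
assert raw_size < raw_load_size earlier in this file. A small standalone
before/after illustration, not taken from the patch:

    # Before black: legal but redundant parentheses around the condition.
    assert (len([1, 2, 3]) == 3)
    # After black 23.x: the wrapping parentheses are dropped; parentheses
    # that affect evaluation order are left alone.
    assert len([1, 2, 3]) == 3
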
- assert len(raw.info['projs']) == n_proj + assert len(raw.info["projs"]) == n_proj raw.add_proj(projs[:-1], remove_existing=True) - assert len(raw.info['projs']) == n_proj - 1 + assert len(raw.info["projs"]) == n_proj - 1 # test apply_proj() with and without preload for preload in [True, False]: @@ -885,18 +935,18 @@ def test_proj(tmp_path): raw = read_raw_fif(fif_fname, preload=preload) # write the file with proj. activated, make sure proj has been applied - raw.save(tmp_path / 'raw.fif', proj=True, overwrite=True) - raw2 = read_raw_fif(tmp_path / 'raw.fif') + raw.save(tmp_path / "raw.fif", proj=True, overwrite=True) + raw2 = read_raw_fif(tmp_path / "raw.fif") data_proj_2, _ = raw2[:, 0:2] assert_allclose(data_proj_1, data_proj_2) - assert (all(p['active'] for p in raw2.info['projs'])) + assert all(p["active"] for p in raw2.info["projs"]) # read orig file with proj. active raw2 = read_raw_fif(fif_fname, preload=preload) raw2.apply_proj() data_proj_2, _ = raw2[:, 0:2] assert_allclose(data_proj_1, data_proj_2) - assert (all(p['active'] for p in raw2.info['projs'])) + assert all(p["active"] for p in raw2.info["projs"]) # test that apply_proj works raw.apply_proj() @@ -906,25 +956,25 @@ def test_proj(tmp_path): # Test that picking removes projectors ... raw = read_raw_fif(fif_fname) - n_projs = len(raw.info['projs']) + n_projs = len(raw.info["projs"]) raw.pick_types(meg=False, eeg=True) - assert len(raw.info['projs']) == n_projs - 3 + assert len(raw.info["projs"]) == n_projs - 3 # ... but only if it doesn't apply to any channels in the dataset anymore. raw = read_raw_fif(fif_fname) - n_projs = len(raw.info['projs']) - raw.pick_types(meg='mag', eeg=True) - assert len(raw.info['projs']) == n_projs + n_projs = len(raw.info["projs"]) + raw.pick_types(meg="mag", eeg=True) + assert len(raw.info["projs"]) == n_projs # I/O roundtrip of an MEG projector with a Raw that only contains EEG # data. - out_fname = tmp_path / 'test_raw.fif' + out_fname = tmp_path / "test_raw.fif" raw = read_raw_fif(test_fif_fname, preload=True).crop(0, 0.002) - proj = raw.info['projs'][-1] + proj = raw.info["projs"][-1] raw.pick_types(meg=False, eeg=True) raw.add_proj(proj) # Restore, because picking removed it! raw._data.fill(0) - raw._data[-1] = 1. 
+ raw._data[-1] = 1.0 raw.save(out_fname) raw = read_raw_fif(out_fname, preload=False) raw.apply_proj() @@ -932,30 +982,30 @@ def test_proj(tmp_path): @testing.requires_testing_data -@pytest.mark.parametrize('preload', [False, True, 'memmap.dat']) +@pytest.mark.parametrize("preload", [False, True, "memmap.dat"]) def test_preload_modify(preload, tmp_path): """Test preloading and modifying data.""" rng = np.random.RandomState(0) raw = read_raw_fif(fif_fname, preload=preload) nsamp = raw.last_samp - raw.first_samp + 1 - picks = pick_types(raw.info, meg='grad', exclude='bads') + picks = pick_types(raw.info, meg="grad", exclude="bads") data = rng.randn(len(picks), nsamp // 2) try: - raw[picks, :nsamp // 2] = data + raw[picks, : nsamp // 2] = data except RuntimeError: if not preload: return else: raise - tmp_fname = tmp_path / 'raw.fif' + tmp_fname = tmp_path / "raw.fif" raw.save(tmp_fname, overwrite=True) raw_new = read_raw_fif(tmp_fname) - data_new, _ = raw_new[picks, :nsamp // 2] + data_new, _ = raw_new[picks, : nsamp // 2] assert_allclose(data, data_new) @@ -968,13 +1018,17 @@ def test_filter(): raw.load_data() sig_dec_notch = 12 sig_dec_notch_fit = 12 - picks_meg = pick_types(raw.info, meg=True, exclude='bads') + picks_meg = pick_types(raw.info, meg=True, exclude="bads") picks = picks_meg[:4] trans = 2.0 - filter_params = dict(picks=picks, filter_length='auto', - h_trans_bandwidth=trans, l_trans_bandwidth=trans, - fir_design='firwin') + filter_params = dict( + picks=picks, + filter_length="auto", + h_trans_bandwidth=trans, + l_trans_bandwidth=trans, + fir_design="firwin", + ) raw_lp = raw.copy().filter(None, 8.0, **filter_params) raw_hp = raw.copy().filter(16.0, None, **filter_params) raw_bp = raw.copy().filter(8.0 + trans, 16.0 - trans, **filter_params) @@ -992,8 +1046,9 @@ def test_filter(): assert_allclose(data, lp_data + bp_data + hp_data, **tols) assert_allclose(data, bp_data + bs_data, **tols) - filter_params_iir = dict(picks=picks, n_jobs=2, method='iir', - iir_params=dict(output='ba')) + filter_params_iir = dict( + picks=picks, n_jobs=2, method="iir", iir_params=dict(output="ba") + ) raw_lp_iir = raw.copy().filter(None, 4.0, **filter_params_iir) raw_hp_iir = raw.copy().filter(8.0, None, **filter_params_iir) raw_bp_iir = raw.copy().filter(4.0, 8.0, **filter_params_iir) @@ -1017,83 +1072,98 @@ def test_filter(): assert not np.may_share_memory(raw_copy._data, raw._data) # this could be assert_array_equal but we do this to mirror the call below assert (raw._data[0] == raw_copy._data[0]).all() - raw_copy.filter(None, 20., n_jobs=2, **filter_params) + raw_copy.filter(None, 20.0, n_jobs=2, **filter_params) assert not (raw._data[0] == raw_copy._data[0]).all() - assert_array_equal(raw.copy().filter(None, 20., **filter_params)._data, - raw_copy._data) + assert_array_equal( + raw.copy().filter(None, 20.0, **filter_params)._data, raw_copy._data + ) # do a very simple check on line filtering raw_bs = raw.copy().filter(60.0 + trans, 60.0 - trans, **filter_params) data_bs, _ = raw_bs[picks, :] raw_notch = raw.copy().notch_filter( - 60.0, picks=picks, n_jobs=2, method='fir', - trans_bandwidth=2 * trans) + 60.0, picks=picks, n_jobs=2, method="fir", trans_bandwidth=2 * trans + ) data_notch, _ = raw_notch[picks, :] assert_array_almost_equal(data_bs, data_notch, sig_dec_notch) # now use the sinusoidal fitting assert raw.times[-1] < 10 # catch error with filter_length > n_times raw_notch = raw.copy().notch_filter( - None, picks=picks, n_jobs=2, method='spectrum_fit', - filter_length='10s') + None, 
picks=picks, n_jobs=2, method="spectrum_fit", filter_length="10s" + ) data_notch, _ = raw_notch[picks, :] data, _ = raw[picks, :] assert_array_almost_equal(data, data_notch, sig_dec_notch_fit) # filter should set the "lowpass" and "highpass" parameters - raw = RawArray(np.random.randn(3, 1000), - create_info(3, 1000., ['eeg'] * 2 + ['stim'])) + raw = RawArray( + np.random.randn(3, 1000), create_info(3, 1000.0, ["eeg"] * 2 + ["stim"]) + ) with raw.info._unlock(): - raw.info['lowpass'] = raw.info['highpass'] = None - for kind in ('none', 'lowpass', 'highpass', 'bandpass', 'bandstop'): + raw.info["lowpass"] = raw.info["highpass"] = None + for kind in ("none", "lowpass", "highpass", "bandpass", "bandstop"): print(kind) h_freq = l_freq = None - if kind in ('lowpass', 'bandpass'): + if kind in ("lowpass", "bandpass"): h_freq = 70 - if kind in ('highpass', 'bandpass'): + if kind in ("highpass", "bandpass"): l_freq = 30 - if kind == 'bandstop': + if kind == "bandstop": l_freq, h_freq = 70, 30 - assert (raw.info['lowpass'] is None) - assert (raw.info['highpass'] is None) - kwargs = dict(l_trans_bandwidth=20, h_trans_bandwidth=20, - filter_length='auto', phase='zero', fir_design='firwin') - raw_filt = raw.copy().filter(l_freq, h_freq, picks=np.arange(1), - **kwargs) - assert (raw.info['lowpass'] is None) - assert (raw.info['highpass'] is None) + assert raw.info["lowpass"] is None + assert raw.info["highpass"] is None + kwargs = dict( + l_trans_bandwidth=20, + h_trans_bandwidth=20, + filter_length="auto", + phase="zero", + fir_design="firwin", + ) + raw_filt = raw.copy().filter(l_freq, h_freq, picks=np.arange(1), **kwargs) + assert raw.info["lowpass"] is None + assert raw.info["highpass"] is None raw_filt = raw.copy().filter(l_freq, h_freq, **kwargs) - wanted_h = h_freq if kind != 'bandstop' else None - wanted_l = l_freq if kind != 'bandstop' else None - assert raw_filt.info['lowpass'] == wanted_h - assert raw_filt.info['highpass'] == wanted_l + wanted_h = h_freq if kind != "bandstop" else None + wanted_l = l_freq if kind != "bandstop" else None + assert raw_filt.info["lowpass"] == wanted_h + assert raw_filt.info["highpass"] == wanted_l # Using all data channels should still set the params (GH#3259) - raw_filt = raw.copy().filter(l_freq, h_freq, picks=np.arange(2), - **kwargs) - assert raw_filt.info['lowpass'] == wanted_h - assert raw_filt.info['highpass'] == wanted_l + raw_filt = raw.copy().filter(l_freq, h_freq, picks=np.arange(2), **kwargs) + assert raw_filt.info["lowpass"] == wanted_h + assert raw_filt.info["highpass"] == wanted_l def test_filter_picks(): """Test filtering default channel picks.""" - ch_types = ['mag', 'grad', 'eeg', 'seeg', 'dbs', 'misc', 'stim', 'ecog', - 'hbo', 'hbr'] + ch_types = [ + "mag", + "grad", + "eeg", + "seeg", + "dbs", + "misc", + "stim", + "ecog", + "hbo", + "hbr", + ] info = create_info(ch_names=ch_types, ch_types=ch_types, sfreq=256) raw = RawArray(data=np.zeros((len(ch_types), 1000)), info=info) # -- Deal with meg mag grad and fnirs exceptions - ch_types = ('misc', 'stim', 'meg', 'eeg', 'seeg', 'dbs', 'ecog') + ch_types = ("misc", "stim", "meg", "eeg", "seeg", "dbs", "ecog") # -- Filter data channels - for ch_type in ('mag', 'grad', 'eeg', 'seeg', 'dbs', 'ecog', 'hbo', 'hbr'): + for ch_type in ("mag", "grad", "eeg", "seeg", "dbs", "ecog", "hbo", "hbr"): picks = {ch: ch == ch_type for ch in ch_types} - picks['meg'] = ch_type if ch_type in ('mag', 'grad') else False - picks['fnirs'] = ch_type if ch_type in ('hbo', 'hbr') else False + picks["meg"] = ch_type if 
ch_type in ("mag", "grad") else False + picks["fnirs"] = ch_type if ch_type in ("hbo", "hbr") else False raw_ = raw.copy().pick_types(**picks) - raw_.filter(10, 30, fir_design='firwin') + raw_.filter(10, 30, fir_design="firwin") # -- Error if no data channel - for ch_type in ('misc', 'stim'): + for ch_type in ("misc", "stim"): picks = {ch: ch == ch_type for ch in ch_types} raw_ = raw.copy().pick_types(**picks) pytest.raises(ValueError, raw_.filter, 10, 30) @@ -1103,14 +1173,13 @@ def test_filter_picks(): def test_crop(): """Test cropping raw files.""" # split a concatenated file to test a difficult case - raw = concatenate_raws([read_raw_fif(f) - for f in [fif_fname, fif_fname]]) - split_size = 10. # in seconds - sfreq = raw.info['sfreq'] - nsamp = (raw.last_samp - raw.first_samp + 1) + raw = concatenate_raws([read_raw_fif(f) for f in [fif_fname, fif_fname]]) + split_size = 10.0 # in seconds + sfreq = raw.info["sfreq"] + nsamp = raw.last_samp - raw.first_samp + 1 # do an annoying case (off-by-one splitting) - tmins = np.r_[1., np.round(np.arange(0., nsamp - 1, split_size * sfreq))] + tmins = np.r_[1.0, np.round(np.arange(0.0, nsamp - 1, split_size * sfreq))] tmins = np.sort(tmins) tmaxs = np.concatenate((tmins[1:] - 1, [nsamp - 1])) tmaxs /= sfreq @@ -1121,14 +1190,15 @@ def test_crop(): if ri < len(tmins) - 1: assert_allclose( raws[ri].times, - raw.copy().crop(tmin, tmins[ri + 1], include_tmax=False).times) + raw.copy().crop(tmin, tmins[ri + 1], include_tmax=False).times, + ) assert raws[ri] all_raw_2 = concatenate_raws(raws, preload=False) assert raw.first_samp == all_raw_2.first_samp assert raw.last_samp == all_raw_2.last_samp assert_array_equal(raw[:, :][0], all_raw_2[:, :][0]) - tmins = np.round(np.arange(0., nsamp - 1, split_size * sfreq)) + tmins = np.round(np.arange(0.0, nsamp - 1, split_size * sfreq)) tmaxs = np.concatenate((tmins[1:] - 1, [nsamp - 1])) tmaxs /= sfreq tmins /= sfreq @@ -1156,20 +1226,20 @@ def test_crop(): assert raw1[:][0].shape == (1, 2001) # degenerate - with pytest.raises(ValueError, match='No samples.*when include_tmax=Fals'): + with pytest.raises(ValueError, match="No samples.*when include_tmax=Fals"): raw.crop(0, 0, include_tmax=False) # edge cases cropping to exact duration +/- 1 sample data = np.zeros((1, 100)) info = create_info(1, 100) raw = RawArray(data, info) - with pytest.raises(ValueError, match='tmax \\(1\\) must be less than or '): + with pytest.raises(ValueError, match="tmax \\(1\\) must be less than or "): raw.copy().crop(tmax=1, include_tmax=True) - raw1 = raw.copy().crop(tmax=1 - 1 / raw.info['sfreq'], include_tmax=True) + raw1 = raw.copy().crop(tmax=1 - 1 / raw.info["sfreq"], include_tmax=True) assert raw.n_times == raw1.n_times raw2 = raw.copy().crop(tmax=1, include_tmax=False) assert raw.n_times == raw2.n_times - raw3 = raw.copy().crop(tmax=1 - 1 / raw.info['sfreq'], include_tmax=False) + raw3 = raw.copy().crop(tmax=1 - 1 / raw.info["sfreq"], include_tmax=False) assert raw.n_times - 1 == raw3.n_times @@ -1179,16 +1249,19 @@ def test_resample_equiv(): raw = read_raw_fif(fif_fname).crop(0, 1) raw_preload = raw.copy().load_data() for r in (raw, raw_preload): - r.resample(r.info['sfreq'] / 4.) 
+ r.resample(r.info["sfreq"] / 4.0) assert_allclose(raw._data, raw_preload._data) @pytest.mark.slowtest @testing.requires_testing_data -@pytest.mark.parametrize('preload, n, npad', [ - (True, 512, 'auto'), - (False, 512, 0), -]) +@pytest.mark.parametrize( + "preload, n, npad", + [ + (True, 512, "auto"), + (False, 512, 0), + ], +) def test_resample(tmp_path, preload, n, npad): """Test resample (with I/O and multiple files).""" raw = read_raw_fif(fif_fname) @@ -1197,19 +1270,19 @@ def test_resample(tmp_path, preload, n, npad): if preload: raw.load_data() raw_resamp = raw.copy() - sfreq = raw.info['sfreq'] + sfreq = raw.info["sfreq"] # test parallel on upsample raw_resamp.resample(sfreq * 2, n_jobs=2, npad=npad) assert raw_resamp.n_times == len(raw_resamp.times) - raw_resamp.save(tmp_path / 'raw_resamp-raw.fif') - raw_resamp = read_raw_fif(tmp_path / 'raw_resamp-raw.fif', preload=True) - assert sfreq == raw_resamp.info['sfreq'] / 2 + raw_resamp.save(tmp_path / "raw_resamp-raw.fif") + raw_resamp = read_raw_fif(tmp_path / "raw_resamp-raw.fif", preload=True) + assert sfreq == raw_resamp.info["sfreq"] / 2 assert raw.n_times == raw_resamp.n_times // 2 assert raw_resamp.get_data().shape[1] == raw_resamp.n_times assert raw.get_data().shape[0] == raw_resamp._data.shape[0] # test non-parallel on downsample raw_resamp.resample(sfreq, n_jobs=None, npad=npad) - assert raw_resamp.info['sfreq'] == sfreq + assert raw_resamp.info["sfreq"] == sfreq assert raw.get_data().shape == raw_resamp._data.shape assert raw.first_samp == raw_resamp.first_samp assert raw.last_samp == raw.last_samp @@ -1217,12 +1290,18 @@ def test_resample(tmp_path, preload, n, npad): # works (hooray). Note that the stim channels had to be sub-sampled # without filtering to be accurately preserved # note we have to treat MEG and EEG+STIM channels differently (tols) - assert_allclose(raw.get_data()[:306, 200:-200], - raw_resamp._data[:306, 200:-200], - rtol=1e-2, atol=1e-12) - assert_allclose(raw.get_data()[306:, 200:-200], - raw_resamp._data[306:, 200:-200], - rtol=1e-2, atol=1e-7) + assert_allclose( + raw.get_data()[:306, 200:-200], + raw_resamp._data[:306, 200:-200], + rtol=1e-2, + atol=1e-12, + ) + assert_allclose( + raw.get_data()[306:, 200:-200], + raw_resamp._data[306:, 200:-200], + rtol=1e-2, + atol=1e-7, + ) # now check multiple file support w/resampling, as order of operations # (concat, resample) should not affect our data @@ -1231,9 +1310,9 @@ def test_resample(tmp_path, preload, n, npad): raw3 = raw.copy() raw4 = raw.copy() raw1 = concatenate_raws([raw1, raw2]) - raw1.resample(10., npad=npad) - raw3.resample(10., npad=npad) - raw4.resample(10., npad=npad) + raw1.resample(10.0, npad=npad) + raw3.resample(10.0, npad=npad) + raw4.resample(10.0, npad=npad) raw3 = concatenate_raws([raw3, raw4]) assert_array_equal(raw1._data, raw3._data) assert_array_equal(raw1._first_samps, raw3._first_samps) @@ -1241,7 +1320,7 @@ def test_resample(tmp_path, preload, n, npad): assert_array_equal(raw1._raw_lengths, raw3._raw_lengths) assert raw1.first_samp == raw3.first_samp assert raw1.last_samp == raw3.last_samp - assert raw1.info['sfreq'] == raw3.info['sfreq'] + assert raw1.info["sfreq"] == raw3.info["sfreq"] # smoke test crop after resample raw4.crop(tmin=raw4.times[1], tmax=raw4.times[-1]) @@ -1250,34 +1329,33 @@ def test_resample(tmp_path, preload, n, npad): # basic decimation stim = [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0] - raw = RawArray([stim], create_info(1, len(stim), ['stim'])) - assert_allclose(raw.resample(8., npad=npad)._data, 
- [[1, 1, 0, 0, 1, 1, 0, 0]]) + raw = RawArray([stim], create_info(1, len(stim), ["stim"])) + assert_allclose(raw.resample(8.0, npad=npad)._data, [[1, 1, 0, 0, 1, 1, 0, 0]]) # decimation of multiple stim channels - raw = RawArray(2 * [stim], create_info(2, len(stim), 2 * ['stim'])) - assert_allclose(raw.resample(8., npad=npad, verbose='error')._data, - [[1, 1, 0, 0, 1, 1, 0, 0], - [1, 1, 0, 0, 1, 1, 0, 0]]) + raw = RawArray(2 * [stim], create_info(2, len(stim), 2 * ["stim"])) + assert_allclose( + raw.resample(8.0, npad=npad, verbose="error")._data, + [[1, 1, 0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1, 0, 0]], + ) # decimation that could potentially drop events if the decimation is # done naively stim = [0, 0, 0, 1, 1, 0, 0, 0] - raw = RawArray([stim], create_info(1, len(stim), ['stim'])) - assert_allclose(raw.resample(4., npad=npad)._data, - [[0, 1, 1, 0]]) + raw = RawArray([stim], create_info(1, len(stim), ["stim"])) + assert_allclose(raw.resample(4.0, npad=npad)._data, [[0, 1, 1, 0]]) # two events are merged in this case (warning) stim = [0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0] - raw = RawArray([stim], create_info(1, len(stim), ['stim'])) - with pytest.warns(RuntimeWarning, match='become unreliable'): - raw.resample(8., npad=npad) + raw = RawArray([stim], create_info(1, len(stim), ["stim"])) + with pytest.warns(RuntimeWarning, match="become unreliable"): + raw.resample(8.0, npad=npad) # events are dropped in this case (warning) stim = [0, 1, 1, 0, 0, 1, 1, 0] - raw = RawArray([stim], create_info(1, len(stim), ['stim'])) - with pytest.warns(RuntimeWarning, match='become unreliable'): - raw.resample(4., npad=npad) + raw = RawArray([stim], create_info(1, len(stim), ["stim"])) + with pytest.warns(RuntimeWarning, match="become unreliable"): + raw.resample(4.0, npad=npad) # test resampling events: this should no longer give a warning # we often have first_samp != 0, include it here too @@ -1286,8 +1364,7 @@ def test_resample(tmp_path, preload, n, npad): o_sfreq, sfreq_ratio = len(stim), 0.5 n_sfreq = o_sfreq * sfreq_ratio first_samp = len(stim) // 2 - raw = RawArray([stim], create_info(1, o_sfreq, ['stim']), - first_samp=first_samp) + raw = RawArray([stim], create_info(1, o_sfreq, ["stim"]), first_samp=first_samp) events = find_events(raw) raw, events = raw.resample(n_sfreq, events=events, npad=npad) # Try index into raw.times with resampled events: @@ -1297,62 +1374,80 @@ def test_resample(tmp_path, preload, n, npad): # https://docs.scipy.org/doc/numpy/reference/generated/numpy.around.html assert_array_equal( events, - np.array([[np.round(1 * sfreq_ratio) + n_fsamp, 0, 1], - [np.round(10 * sfreq_ratio) + n_fsamp, 0, 1], - [np.minimum(np.round(15 * sfreq_ratio), - raw._data.shape[1] - 1) + n_fsamp, 0, 1]])) + np.array( + [ + [np.round(1 * sfreq_ratio) + n_fsamp, 0, 1], + [np.round(10 * sfreq_ratio) + n_fsamp, 0, 1], + [ + np.minimum(np.round(15 * sfreq_ratio), raw._data.shape[1] - 1) + + n_fsamp, + 0, + 1, + ], + ] + ), + ) # test copy flag stim = [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0] - raw = RawArray([stim], create_info(1, len(stim), ['stim'])) - raw_resampled = raw.copy().resample(4., npad=npad) - assert (raw_resampled is not raw) - raw_resampled = raw.resample(4., npad=npad) - assert (raw_resampled is raw) + raw = RawArray([stim], create_info(1, len(stim), ["stim"])) + raw_resampled = raw.copy().resample(4.0, npad=npad) + assert raw_resampled is not raw + raw_resampled = raw.resample(4.0, npad=npad) + assert raw_resampled is raw # resample should still work even when no stim 
channel is present - raw = RawArray(np.random.randn(1, 100), create_info(1, 100, ['eeg'])) + raw = RawArray(np.random.randn(1, 100), create_info(1, 100, ["eeg"])) with raw.info._unlock(): - raw.info['lowpass'] = 50. + raw.info["lowpass"] = 50.0 raw.resample(10, npad=npad) - assert raw.info['lowpass'] == 5. + assert raw.info["lowpass"] == 5.0 assert len(raw) == 10 def test_resample_stim(): """Test stim_picks argument.""" data = np.ones((2, 1000)) - info = create_info(2, 1000., ('eeg', 'misc')) + info = create_info(2, 1000.0, ("eeg", "misc")) raw = RawArray(data, info) - raw.resample(500., stim_picks='misc') + raw.resample(500.0, stim_picks="misc") @testing.requires_testing_data def test_hilbert(): """Test computation of analytic signal using hilbert.""" raw = read_raw_fif(fif_fname, preload=True) - picks_meg = pick_types(raw.info, meg=True, exclude='bads') + picks_meg = pick_types(raw.info, meg=True, exclude="bads") picks = picks_meg[:4] raw_filt = raw.copy() - raw_filt.filter(10, 20, picks=picks, l_trans_bandwidth='auto', - h_trans_bandwidth='auto', filter_length='auto', - phase='zero', fir_window='blackman', fir_design='firwin') + raw_filt.filter( + 10, + 20, + picks=picks, + l_trans_bandwidth="auto", + h_trans_bandwidth="auto", + filter_length="auto", + phase="zero", + fir_window="blackman", + fir_design="firwin", + ) raw_filt_2 = raw_filt.copy() raw2 = raw.copy() raw3 = raw.copy() - raw.apply_hilbert(picks, n_fft='auto') - raw2.apply_hilbert(picks, n_fft='auto', envelope=True) + raw.apply_hilbert(picks, n_fft="auto") + raw2.apply_hilbert(picks, n_fft="auto", envelope=True) # Test custom n_fft - raw_filt.apply_hilbert(picks, n_fft='auto') + raw_filt.apply_hilbert(picks, n_fft="auto") n_fft = 2 ** int(np.ceil(np.log2(raw_filt_2.n_times + 1000))) raw_filt_2.apply_hilbert(picks, n_fft=n_fft) assert raw_filt._data.shape == raw_filt_2._data.shape - assert_allclose(raw_filt._data[:, 50:-50], raw_filt_2._data[:, 50:-50], - atol=1e-13, rtol=1e-2) - with pytest.raises(ValueError, match='n_fft.*must be at least the number'): + assert_allclose( + raw_filt._data[:, 50:-50], raw_filt_2._data[:, 50:-50], atol=1e-13, rtol=1e-2 + ) + with pytest.raises(ValueError, match="n_fft.*must be at least the number"): raw3.apply_hilbert(picks, n_fft=raw3.n_times - 100) env = np.abs(raw._data[picks, :]) @@ -1381,53 +1476,60 @@ def test_raw_copy(): def test_to_data_frame(): """Test raw Pandas exporter.""" from pandas import Timedelta + raw = read_raw_fif(test_fif_fname).crop(0, 1).load_data() - df = raw.to_data_frame(index='time') - assert ((df.columns == raw.ch_names).all()) + df = raw.to_data_frame(index="time") + assert (df.columns == raw.ch_names).all() df = raw.to_data_frame(index=None) - assert ('time' in df.columns) + assert "time" in df.columns assert_array_equal(df.values[:, 1], raw._data[0] * 1e13) assert_array_equal(df.values[:, 3], raw._data[2] * 1e15) # test long format df_long = raw.to_data_frame(long_format=True) assert len(df_long) == raw.get_data().size - expected = ('time', 'channel', 'ch_type', 'value') + expected = ("time", "channel", "ch_type", "value") assert set(expected) == set(df_long.columns) # test bad time format - with pytest.raises(ValueError, match='not a valid time format. Valid'): - raw.to_data_frame(time_format='foo') + with pytest.raises(ValueError, match="not a valid time format. 
Valid"): + raw.to_data_frame(time_format="foo") # test time format error handling raw.set_meas_date(None) - with pytest.warns(RuntimeWarning, match='Cannot convert to Datetime when'): - df = raw.to_data_frame(time_format='datetime') - assert isinstance(df['time'].iloc[0], Timedelta) + with pytest.warns(RuntimeWarning, match="Cannot convert to Datetime when"): + df = raw.to_data_frame(time_format="datetime") + assert isinstance(df["time"].iloc[0], Timedelta) @requires_pandas -@pytest.mark.parametrize('time_format', (None, 'ms', 'timedelta', 'datetime')) +@pytest.mark.parametrize("time_format", (None, "ms", "timedelta", "datetime")) def test_to_data_frame_time_format(time_format): """Test time conversion in epochs Pandas exporter.""" from pandas import Timedelta, Timestamp, to_timedelta + raw = read_raw_fif(test_fif_fname, preload=True) # test time_format df = raw.to_data_frame(time_format=time_format) - dtypes = {None: np.float64, 'ms': np.int64, 'timedelta': Timedelta, - 'datetime': Timestamp} - assert isinstance(df['time'].iloc[0], dtypes[time_format]) + dtypes = { + None: np.float64, + "ms": np.int64, + "timedelta": Timedelta, + "datetime": Timestamp, + } + assert isinstance(df["time"].iloc[0], dtypes[time_format]) # test values _, times = raw[0, :10] - offset = 0. - if time_format == 'datetime': + offset = 0.0 + if time_format == "datetime": times += raw.first_time - offset = raw.info['meas_date'] - elif time_format == 'timedelta': - offset = Timedelta(0.) - funcs = {None: lambda x: x, - 'ms': lambda x: np.rint(x * 1e3).astype(int), # s → ms - 'timedelta': partial(to_timedelta, unit='s'), - 'datetime': partial(to_timedelta, unit='s') - } - assert_array_equal(funcs[time_format](times) + offset, df['time'][:10]) + offset = raw.info["meas_date"] + elif time_format == "timedelta": + offset = Timedelta(0.0) + funcs = { + None: lambda x: x, + "ms": lambda x: np.rint(x * 1e3).astype(int), # s → ms + "timedelta": partial(to_timedelta, unit="s"), + "datetime": partial(to_timedelta, unit="s"), + } + assert_array_equal(funcs[time_format](times) + offset, df["time"][:10]) def test_add_channels(): @@ -1441,29 +1543,29 @@ def test_add_channels(): raw_meg = raw.copy().pick_types(meg=True) raw_stim = raw.copy().pick_types(stim=True) raw_new = raw_meg.copy().add_channels([raw_eeg, raw_stim]) - assert ( - all(ch in raw_new.ch_names - for ch in list(raw_stim.ch_names) + list(raw_meg.ch_names)) + assert all( + ch in raw_new.ch_names + for ch in list(raw_stim.ch_names) + list(raw_meg.ch_names) ) raw_new = raw_meg.copy().add_channels([raw_eeg]) assert (ch in raw_new.ch_names for ch in raw.ch_names) assert_array_equal(raw_new[:, :][0], raw_eeg_meg[:, :][0]) assert_array_equal(raw_new[:, :][1], raw[:, :][1]) - assert (all(ch not in raw_new.ch_names for ch in raw_stim.ch_names)) + assert all(ch not in raw_new.ch_names for ch in raw_stim.ch_names) # Testing force updates - raw_arr_info = create_info(['1', '2'], raw_meg.info['sfreq'], 'eeg') - orig_head_t = raw_arr_info['dev_head_t'] + raw_arr_info = create_info(["1", "2"], raw_meg.info["sfreq"], "eeg") + orig_head_t = raw_arr_info["dev_head_t"] raw_arr = rng.randn(2, raw_eeg.n_times) raw_arr = RawArray(raw_arr, raw_arr_info) # This should error because of conflicts in Info - raw_arr.info['dev_head_t'] = orig_head_t - with pytest.raises(ValueError, match='mutually inconsistent dev_head_t'): + raw_arr.info["dev_head_t"] = orig_head_t + with pytest.raises(ValueError, match="mutually inconsistent dev_head_t"): raw_meg.copy().add_channels([raw_arr]) 
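
Most hunks in this section are pure normalization like the ones above: black
rewrites single-quoted strings to double quotes and gives floats written with
a bare trailing dot an explicit zero (1. becomes 1.0), with no change in
behavior. A tiny illustrative sketch with hypothetical values:

    # Before black:
    sfreq = 100.
    fname = 'raw.fif'
    # After black (same semantics, normalized spelling):
    sfreq = 100.0
    fname = "raw.fif"
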
raw_meg.copy().add_channels([raw_arr], force_update_info=True) # Make sure that values didn't get overwritten - assert_object_equal(raw_arr.info['dev_head_t'], orig_head_t) + assert_object_equal(raw_arr.info["dev_head_t"], orig_head_t) # Make sure all variants work for simult in (False, True): # simultaneous adding or not raw_new = raw_meg.copy() @@ -1475,13 +1577,14 @@ def test_add_channels(): for other in (raw_meg, raw_stim, raw_eeg): assert_allclose( raw_new.copy().pick_channels(other.ch_names).get_data(), - other.get_data()) + other.get_data(), + ) # Now test errors raw_badsf = raw_eeg.copy() with raw_badsf.info._unlock(): - raw_badsf.info['sfreq'] = 3.1415927 - raw_eeg.crop(.5) + raw_badsf.info["sfreq"] = 3.1415927 + raw_eeg.crop(0.5) pytest.raises(RuntimeError, raw_meg.add_channels, [raw_nopre]) pytest.raises(RuntimeError, raw_meg.add_channels, [raw_badsf]) @@ -1493,23 +1596,23 @@ def test_add_channels(): @testing.requires_testing_data def test_save(tmp_path): """Test saving raw.""" - temp_fname = tmp_path / 'test_raw.fif' + temp_fname = tmp_path / "test_raw.fif" shutil.copyfile(fif_fname, temp_fname) raw = read_raw_fif(temp_fname, preload=False) # can't write over file being read - with pytest.raises(ValueError, match='to the same file'): + with pytest.raises(ValueError, match="to the same file"): raw.save(temp_fname) raw.load_data() # can't overwrite file without overwrite=True - with pytest.raises(OSError, match='file exists'): + with pytest.raises(OSError, match="file exists"): raw.save(fif_fname) # test abspath support and annotations - orig_time = _dt_to_stamp(raw.info['meas_date'])[0] + raw._first_time - annot = Annotations([10], [5], ['test'], orig_time=orig_time) + orig_time = _dt_to_stamp(raw.info["meas_date"])[0] + raw._first_time + annot = Annotations([10], [5], ["test"], orig_time=orig_time) raw.set_annotations(annot) annot = raw.annotations - new_fname = tmp_path / 'break_raw.fif' + new_fname = tmp_path / "break_raw.fif" raw.save(new_fname, overwrite=True) new_raw = read_raw_fif(new_fname, preload=False) pytest.raises(ValueError, new_raw.save, new_fname) @@ -1522,37 +1625,38 @@ def test_save(tmp_path): raw.set_meas_date(None) raw.save(new_fname, overwrite=True) new_raw = read_raw_fif(new_fname, preload=False) - assert new_raw.info['meas_date'] is None + assert new_raw.info["meas_date"] is None @testing.requires_testing_data def test_annotation_crop(tmp_path): """Test annotation sync after cropping and concatenating.""" - annot = Annotations([5., 11., 15.], [2., 1., 3.], ['test', 'test', 'test']) + annot = Annotations([5.0, 11.0, 15.0], [2.0, 1.0, 3.0], ["test", "test", "test"]) raw = read_raw_fif(fif_fname, preload=False) raw.set_annotations(annot) r1 = raw.copy().crop(2.5, 7.5) r2 = raw.copy().crop(12.5, 17.5) - r3 = raw.copy().crop(10., 12.) + r3 = raw.copy().crop(10.0, 12.0) raw = concatenate_raws([r1, r2, r3]) # segments reordered assert_and_remove_boundary_annot(raw, 2) onsets = raw.annotations.onset durations = raw.annotations.duration # 2*5s clips combined with annotations at 2.5s + 2s clip, annotation at 1s assert_array_almost_equal(onsets[:3], [47.95, 52.95, 56.46], decimal=2) - assert_array_almost_equal([2., 2.5, 1.], durations[:3], decimal=2) + assert_array_almost_equal([2.0, 2.5, 1.0], durations[:3], decimal=2) # test annotation clipping - orig_time = _dt_to_stamp(raw.info['meas_date']) - orig_time = orig_time[0] + orig_time[1] * 1e-6 + raw._first_time - 1. 
- annot = Annotations([0., raw.times[-1]], [2., 2.], 'test', orig_time) - with pytest.warns(RuntimeWarning, match='Limited .* expanding outside'): + orig_time = _dt_to_stamp(raw.info["meas_date"]) + orig_time = orig_time[0] + orig_time[1] * 1e-6 + raw._first_time - 1.0 + annot = Annotations([0.0, raw.times[-1]], [2.0, 2.0], "test", orig_time) + with pytest.warns(RuntimeWarning, match="Limited .* expanding outside"): raw.set_annotations(annot) - assert_allclose(raw.annotations.duration, - [1., 1. + 1. / raw.info['sfreq']], atol=1e-3) + assert_allclose( + raw.annotations.duration, [1.0, 1.0 + 1.0 / raw.info["sfreq"]], atol=1e-3 + ) # make sure we can overwrite the file we loaded when preload=True - new_fname = tmp_path / 'break_raw.fif' + new_fname = tmp_path / "break_raw.fif" raw.save(new_fname) new_raw = read_raw_fif(new_fname, preload=True) new_raw.save(new_fname, overwrite=True) @@ -1588,16 +1692,16 @@ def test_compensation_raw(tmp_path): assert raw_0.compensation_grade == 0 data_0, times_new = raw_0[:, :] assert_array_equal(times, times_new) - assert (np.mean(np.abs(data_0 - data_3)) > 1e-12) + assert np.mean(np.abs(data_0 - data_3)) > 1e-12 # change to grade 1 raw_1 = raw_0.copy().apply_gradient_compensation(1) assert raw_1.compensation_grade == 1 data_1, times_new = raw_1[:, :] assert_array_equal(times, times_new) - assert (np.mean(np.abs(data_1 - data_3)) > 1e-12) + assert np.mean(np.abs(data_1 - data_3)) > 1e-12 pytest.raises(ValueError, raw_1.apply_gradient_compensation, 33) raw_bad = raw_0.copy() - raw_bad.add_proj(compute_proj_raw(raw_0, duration=0.5, verbose='error')) + raw_bad.add_proj(compute_proj_raw(raw_0, duration=0.5, verbose="error")) raw_bad.apply_proj() pytest.raises(RuntimeError, raw_bad.apply_gradient_compensation, 1) # with preload @@ -1606,7 +1710,7 @@ def test_compensation_raw(tmp_path): assert raw_1_new.compensation_grade == 1 data_1_new, times_new = raw_1_new[:, :] assert_array_equal(times, times_new) - assert (np.mean(np.abs(data_1_new - data_3)) > 1e-12) + assert np.mean(np.abs(data_1_new - data_3)) > 1e-12 assert_allclose(data_1, data_1_new, **tols) # change back raw_3_new = raw_1.copy().apply_gradient_compensation(3) @@ -1625,11 +1729,11 @@ def test_compensation_raw(tmp_path): assert raw_3_new.compensation_grade == 3 data_3_new, times_new = raw_3_new[:, :] assert_array_equal(times, times_new) - assert (np.mean(np.abs(data_3_new - data_1)) > 1e-12) + assert np.mean(np.abs(data_3_new - data_1)) > 1e-12 assert_allclose(data_3, data_3_new, **tols) # Try IO with compensation - temp_file = tmp_path / 'raw.fif' + temp_file = tmp_path / "raw.fif" raw_3.save(temp_file, overwrite=True) for preload in (True, False): raw_read = read_raw_fif(temp_file, preload=preload) @@ -1667,10 +1771,20 @@ def test_compensation_raw(tmp_path): @requires_mne def test_compensation_raw_mne(tmp_path): """Test Raw compensation by comparing with MNE-C.""" + def compensate_mne(fname, grad): - tmp_fname = tmp_path / 'mne_ctf_test_raw.fif' - cmd = ['mne_process_raw', '--raw', fname, '--save', tmp_fname, - '--grad', str(grad), '--projoff', '--filteroff'] + tmp_fname = tmp_path / "mne_ctf_test_raw.fif" + cmd = [ + "mne_process_raw", + "--raw", + fname, + "--save", + tmp_fname, + "--grad", + str(grad), + "--projoff", + "--filteroff", + ] run_subprocess(cmd) return read_raw_fif(tmp_fname, preload=True) @@ -1679,12 +1793,19 @@ def compensate_mne(fname, grad): raw_py.apply_gradient_compensation(grad) raw_c = compensate_mne(ctf_comp_fname, grad) assert_allclose(raw_py._data, raw_c._data, rtol=1e-6, 
atol=1e-17) - assert raw_py.info['nchan'] == raw_c.info['nchan'] - for ch_py, ch_c in zip(raw_py.info['chs'], raw_c.info['chs']): - for key in ('ch_name', 'coil_type', 'scanno', 'logno', 'unit', - 'coord_frame', 'kind'): + assert raw_py.info["nchan"] == raw_c.info["nchan"] + for ch_py, ch_c in zip(raw_py.info["chs"], raw_c.info["chs"]): + for key in ( + "ch_name", + "coil_type", + "scanno", + "logno", + "unit", + "coord_frame", + "kind", + ): assert ch_py[key] == ch_c[key] - for key in ('loc', 'unit_mul', 'range', 'cal'): + for key in ("loc", "unit_mul", "range", "cal"): assert_allclose(ch_py[key], ch_c[key]) @@ -1709,23 +1830,23 @@ def test_drop_channels_mixin(): # Test that dropping all channels a projector applies to will lead to the # removal of said projector. raw = read_raw_fif(fif_fname).crop(0, 1) - n_projs = len(raw.info['projs']) - eeg_names = raw.info['projs'][-1]['data']['col_names'] - with pytest.raises(RuntimeError, match='loaded'): + n_projs = len(raw.info["projs"]) + eeg_names = raw.info["projs"][-1]["data"]["col_names"] + with pytest.raises(RuntimeError, match="loaded"): raw.copy().apply_proj().drop_channels(eeg_names) raw.load_data().drop_channels(eeg_names) # EEG proj - assert len(raw.info['projs']) == n_projs - 1 + assert len(raw.info["projs"]) == n_projs - 1 # Dropping EEG channels with custom ref removes info['custom_ref_applied'] raw = read_raw_fif(fif_fname).crop(0, 1).load_data() raw.set_eeg_reference() - assert raw.info['custom_ref_applied'] + assert raw.info["custom_ref_applied"] raw.drop_channels(eeg_names) - assert not raw.info['custom_ref_applied'] + assert not raw.info["custom_ref_applied"] @testing.requires_testing_data -@pytest.mark.parametrize('preload', (True, False)) +@pytest.mark.parametrize("preload", (True, False)) def test_pick_channels_mixin(preload): """Test channel-picking functionality.""" raw = read_raw_fif(fif_fname, preload=preload) @@ -1742,7 +1863,7 @@ def test_pick_channels_mixin(preload): assert ch_names == raw.ch_names assert len(ch_names) == len(raw._cals) assert len(ch_names) == raw.get_data().shape[0] - with pytest.raises(ValueError, match='must be'): + with pytest.raises(ValueError, match="must be"): raw.pick_channels(ch_names[0]) assert_allclose(raw[:][0], raw_orig[:3][0]) @@ -1776,13 +1897,13 @@ def test_memmap(tmp_path): # add_channels orig_data = raw_0[:][0] new_ch_info = pick_info(raw_0.info, [0]) - new_ch_info['chs'][0]['ch_name'] = 'foo' + new_ch_info["chs"][0]["ch_name"] = "foo" new_ch_info._update_redundant() new_data = np.linspace(0, 1, len(raw_0.times))[np.newaxis] ch = RawArray(new_data, new_ch_info) raw_0.add_channels([ch]) - if sys.platform == 'darwin': - assert not hasattr(raw_0._data, 'filename') + if sys.platform == "darwin": + assert not hasattr(raw_0._data, "filename") else: assert raw_0._data.filename == memmaps[2] assert_allclose(orig_data, raw_0[:-1][0], atol=1e-7) @@ -1797,7 +1918,7 @@ def test_memmap(tmp_path): raw_1 = raw_0.copy() assert isinstance(raw_1._data, np.memmap) assert raw_1._data.filename is None - raw_0._data[:] = 0. 
+ raw_0._data[:] = 0.0 assert not raw_0._data.any() assert raw_1._data[:1, 3:5].all() # other things like drop_channels and crop work but do not use memmapping, @@ -1806,41 +1927,55 @@ def test_memmap(tmp_path): # These are slow on Azure Windows so let's do a subset -@pytest.mark.parametrize('kind', [ - 'file', - pytest.param('bytes', marks=pytest.mark.slowtest), -]) -@pytest.mark.parametrize('preload', [ - True, - pytest.param(str, marks=pytest.mark.slowtest), -]) -@pytest.mark.parametrize('split', [ - False, - pytest.param(True, marks=pytest.mark.slowtest), -]) +@pytest.mark.parametrize( + "kind", + [ + "file", + pytest.param("bytes", marks=pytest.mark.slowtest), + ], +) +@pytest.mark.parametrize( + "preload", + [ + True, + pytest.param(str, marks=pytest.mark.slowtest), + ], +) +@pytest.mark.parametrize( + "split", + [ + False, + pytest.param(True, marks=pytest.mark.slowtest), + ], +) def test_file_like(kind, preload, split, tmp_path): """Test handling with file-like objects.""" if split: - fname = tmp_path / 'test_raw.fif' - read_raw_fif(test_fif_fname).save(fname, split_size='5MB') + fname = tmp_path / "test_raw.fif" + read_raw_fif(test_fif_fname).save(fname, split_size="5MB") assert fname.is_file() - assert Path(str(fname)[:-4] + '-1.fif').is_file() + assert Path(str(fname)[:-4] + "-1.fif").is_file() else: fname = test_fif_fname if preload is str: - preload = str(tmp_path / 'memmap') - with open(str(fname), 'rb') as file_fid: - fid = BytesIO(file_fid.read()) if kind == 'bytes' else file_fid + preload = str(tmp_path / "memmap") + with open(str(fname), "rb") as file_fid: + fid = BytesIO(file_fid.read()) if kind == "bytes" else file_fid assert not fid.closed assert not file_fid.closed - with pytest.raises(ValueError, match='preload must be used with file'): + with pytest.raises(ValueError, match="preload must be used with file"): read_raw_fif(fid) assert not fid.closed assert not file_fid.closed # Use test_preloading=False but explicitly pass the preload type # so that we don't bother testing preload=False - kwargs = dict(fname=fid, preload=preload, on_split_missing='ignore', - test_preloading=False, test_kwargs=False) + kwargs = dict( + fname=fid, + preload=preload, + on_split_missing="ignore", + test_preloading=False, + test_kwargs=False, + ) _test_raw_reader(read_raw_fif, **kwargs) assert not fid.closed assert not file_fid.closed @@ -1855,17 +1990,20 @@ def test_str_like(): assert_allclose(raw_path._data, raw_str._data) -@pytest.mark.parametrize('fname', [ - test_fif_fname, - testing._pytest_param(fif_fname), - testing._pytest_param(ms_fname), -]) +@pytest.mark.parametrize( + "fname", + [ + test_fif_fname, + testing._pytest_param(fif_fname), + testing._pytest_param(ms_fname), + ], +) def test_bad_acq(fname): """Test handling of acquisition errors.""" # see gh-7844 - raw = read_raw_fif(fname, allow_maxshield='yes').load_data() - with open(fname, 'rb') as fid: - for ent in raw._raw_extras[0]['ent']: + raw = read_raw_fif(fname, allow_maxshield="yes").load_data() + with open(fname, "rb") as fid: + for ent in raw._raw_extras[0]["ent"]: fid.seek(ent.pos, 0) tag = _read_tag_header(fid) # hack these, others (kind, type) should be correct @@ -1874,22 +2012,23 @@ def test_bad_acq(fname): @testing.requires_testing_data -@pytest.mark.skipif(sys.platform not in ('darwin', 'linux'), - reason='Needs proper symlinking') +@pytest.mark.skipif( + sys.platform not in ("darwin", "linux"), reason="Needs proper symlinking" +) def test_split_symlink(tmp_path): """Test split files with symlinks.""" # 
regression test for gh-9221 - (tmp_path / 'first').mkdir() - first = tmp_path / 'first' / 'test_raw.fif' - raw = read_raw_fif(fif_fname).pick('meg').load_data() - raw.save(first, buffer_size_sec=1, split_size='10MB', verbose=True) - second = Path(str(first)[:-4] + '-1.fif') + (tmp_path / "first").mkdir() + first = tmp_path / "first" / "test_raw.fif" + raw = read_raw_fif(fif_fname).pick("meg").load_data() + raw.save(first, buffer_size_sec=1, split_size="10MB", verbose=True) + second = Path(str(first)[:-4] + "-1.fif") assert second.is_file() - assert not Path(str(first)[:-4] + '-2.fif').is_file() - (tmp_path / 'a').mkdir() - (tmp_path / 'b').mkdir() - new_first = tmp_path / 'a' / 'test_raw.fif' - new_second = tmp_path / 'b' / 'test_raw-1.fif' + assert not Path(str(first)[:-4] + "-2.fif").is_file() + (tmp_path / "a").mkdir() + (tmp_path / "b").mkdir() + new_first = tmp_path / "a" / "test_raw.fif" + new_second = tmp_path / "b" / "test_raw-1.fif" shutil.move(first, new_first) shutil.move(second, new_second) os.symlink(new_first, first) @@ -1904,17 +2043,17 @@ def test_corrupted(tmp_path): # Must be a file written by Neuromag, not us, since we don't write the dir # at the end, so use the skip one (straight from acq). raw = read_raw_fif(skip_fname) - with open(skip_fname, 'rb') as fid: + with open(skip_fname, "rb") as fid: tag = read_tag_info(fid) tag = read_tag(fid) dirpos = int(tag.data.item()) assert dirpos == 12641532 fid.seek(0) data = fid.read(dirpos) - bad_fname = tmp_path / 'test_raw.fif' - with open(bad_fname, 'wb') as fid: + bad_fname = tmp_path / "test_raw.fif" + with open(bad_fname, "wb") as fid: fid.write(data) - with pytest.warns(RuntimeWarning, match='.*tag directory.*corrupt.*'): + with pytest.warns(RuntimeWarning, match=".*tag directory.*corrupt.*"): raw_bad = read_raw_fif(bad_fname) assert_allclose(raw.get_data(), raw_bad.get_data()) @@ -1922,17 +2061,14 @@ def test_corrupted(tmp_path): @testing.requires_testing_data def test_expand_user(tmp_path, monkeypatch): """Test that we're expanding `~` before reading and writing.""" - monkeypatch.setenv('HOME', str(tmp_path)) - monkeypatch.setenv('USERPROFILE', str(tmp_path)) # Windows + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("USERPROFILE", str(tmp_path)) # Windows path_in = Path(fif_fname) path_out = tmp_path / path_in.name - path_home = Path('~') / path_in.name + path_home = Path("~") / path_in.name - shutil.copyfile( - src=path_in, - dst=path_out - ) + shutil.copyfile(src=path_in, dst=path_out) raw = read_raw_fif(fname=path_home, preload=True) raw.save(fname=path_home, overwrite=True) diff --git a/mne/io/fil/__init__.py b/mne/io/fil/__init__.py index e07e1fc0bab..3e9acc6c329 100644 --- a/mne/io/fil/__init__.py +++ b/mne/io/fil/__init__.py @@ -2,4 +2,4 @@ # # License: BSD-3-Clause -from .fil import read_raw_fil \ No newline at end of file +from .fil import read_raw_fil diff --git a/mne/io/fil/fil.py b/mne/io/fil/fil.py index 29ed794d7ac..b5b3a96d9c3 100644 --- a/mne/io/fil/fil.py +++ b/mne/io/fil/fil.py @@ -16,12 +16,16 @@ from ...transforms import get_ras_to_neuromag_trans, apply_trans, Transform from ...utils import warn, fill_doc, verbose, _check_fname -from .sensors import (_refine_sensor_orientation, _get_pos_units, - _size2units, _get_plane_vectors) +from .sensors import ( + _refine_sensor_orientation, + _get_pos_units, + _size2units, + _get_plane_vectors, +) @verbose -def read_raw_fil(binfile, precision='single', preload=False, *, verbose=None): +def read_raw_fil(binfile, precision="single", 
preload=False, *, verbose=None): """Raw object from FIL-OPMEG formatted data. Parameters @@ -71,67 +75,67 @@ class RawFIL(BaseRaw): mne.io.Raw : Documentation of attributes and methods of RawFIL. """ - def __init__(self, binfile, precision='single', preload=False): - - if precision == 'single': - dt = np.dtype('>f') + def __init__(self, binfile, precision="single", preload=False): + if precision == "single": + dt = np.dtype(">f") bps = 4 else: - dt = np.dtype('>d') + dt = np.dtype(">d") bps = 8 sample_info = dict() - sample_info['dt'] = dt - sample_info['bps'] = bps + sample_info["dt"] = dt + sample_info["bps"] = bps files = _get_file_names(binfile) - chans = _from_tsv(files['chans']) - chanpos = _from_tsv(files['positions']) - nchans = len(chans['name']) - nlocs = len(chanpos['name']) - nsamples = _determine_nsamples(files['bin'], nchans, precision) - 1 - sample_info['nsamples'] = nsamples + chans = _from_tsv(files["chans"]) + chanpos = _from_tsv(files["positions"]) + nchans = len(chans["name"]) + nlocs = len(chanpos["name"]) + nsamples = _determine_nsamples(files["bin"], nchans, precision) - 1 + sample_info["nsamples"] = nsamples raw_extras = list() raw_extras.append(sample_info) - chans['pos'] = [None] * nchans - chans['ori'] = [None] * nchans + chans["pos"] = [None] * nchans + chans["ori"] = [None] * nchans for ii in range(0, nlocs): - idx = chans['name'].index(chanpos['name'][ii]) - tmp = np.array([chanpos['Px'][ii], - chanpos['Py'][ii], - chanpos['Pz'][ii]]) - chans['pos'][idx] = tmp.astype(np.float64) - tmp = np.array([chanpos['Ox'][ii], - chanpos['Oy'][ii], - chanpos['Oz'][ii]]) - chans['ori'][idx] = tmp.astype(np.float64) - - with open(files['meg'], 'r') as fid: + idx = chans["name"].index(chanpos["name"][ii]) + tmp = np.array([chanpos["Px"][ii], chanpos["Py"][ii], chanpos["Pz"][ii]]) + chans["pos"][idx] = tmp.astype(np.float64) + tmp = np.array([chanpos["Ox"][ii], chanpos["Oy"][ii], chanpos["Oz"][ii]]) + chans["ori"][idx] = tmp.astype(np.float64) + + with open(files["meg"], "r") as fid: meg = json.load(fid) info = _compose_meas_info(meg, chans) super(RawFIL, self).__init__( - info, preload, filenames=[files['bin']], raw_extras=raw_extras, - last_samps=[nsamples], orig_format=precision) - - if files['coordsystem'].is_file(): - with open(files['coordsystem'], 'r') as fid: + info, + preload, + filenames=[files["bin"]], + raw_extras=raw_extras, + last_samps=[nsamples], + orig_format=precision, + ) + + if files["coordsystem"].is_file(): + with open(files["coordsystem"], "r") as fid: csys = json.load(fid) - hc = csys['HeadCoilCoordinates'] + hc = csys["HeadCoilCoordinates"] for key in hc: - if key.lower() == 'lpa': + if key.lower() == "lpa": lpa = np.asarray(hc[key]) - elif key.lower() == 'rpa': + elif key.lower() == "rpa": rpa = np.asarray(hc[key]) - elif key.lower().startswith('nas'): + elif key.lower().startswith("nas"): nas = np.asarray(hc[key]) else: - warn(f'{key} is not a valid fiducial name!') + warn(f"{key} is not a valid fiducial name!") size = np.linalg.norm(nas - rpa) unit, sf = _size2units(size) @@ -149,79 +153,102 @@ def __init__(self, binfile, precision='single', preload=False): rpa = apply_trans(t, rpa) with self.info._unlock(): - self.info['dig'] = _make_dig_points(nasion=nas, - lpa=lpa, - rpa=rpa, - coord_frame='meg') + self.info["dig"] = _make_dig_points( + nasion=nas, lpa=lpa, rpa=rpa, coord_frame="meg" + ) else: warn( - 'No fiducials found in files, defaulting sensor array to ' - 'FIFFV_COORD_DEVICE, this may cause problems later!') + "No fiducials found in files, 
defaulting sensor array to " + "FIFFV_COORD_DEVICE, this may cause problems later!" + ) t = np.eye(4) with self.info._unlock(): - self.info['dev_head_t'] = \ - Transform(FIFF.FIFFV_COORD_DEVICE, - FIFF.FIFFV_COORD_HEAD, t) + self.info["dev_head_t"] = Transform( + FIFF.FIFFV_COORD_DEVICE, FIFF.FIFFV_COORD_HEAD, t + ) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of raw data.""" si = self._raw_extras[fi] _read_segments_file( - self, data, idx, fi, start, stop, cals, mult, dtype=si['dt']) + self, data, idx, fi, start, stop, cals, mult, dtype=si["dt"] + ) def _convert_channel_info(chans): """Convert the imported _channels.tsv into the chs element of raw.info.""" nmeg = nstim = nmisc = nref = 0 - units, sf = _get_pos_units(chans['pos']) + units, sf = _get_pos_units(chans["pos"]) chs = list() - for ii in range(len(chans['name'])): - ch = dict(scanno=ii + 1, range=1., cal=1., loc=np.full(12, np.nan), - unit_mul=FIFF.FIFF_UNITM_NONE, ch_name=chans['name'][ii], - coil_type=FIFF.FIFFV_COIL_NONE) + for ii in range(len(chans["name"])): + ch = dict( + scanno=ii + 1, + range=1.0, + cal=1.0, + loc=np.full(12, np.nan), + unit_mul=FIFF.FIFF_UNITM_NONE, + ch_name=chans["name"][ii], + coil_type=FIFF.FIFFV_COIL_NONE, + ) chs.append(ch) # create the channel information - if chans['pos'][ii] is not None: - r0 = chans['pos'][ii].copy() / sf # mm to m - ez = chans['ori'][ii].copy() + if chans["pos"][ii] is not None: + r0 = chans["pos"][ii].copy() / sf # mm to m + ez = chans["ori"][ii].copy() ez = ez / np.linalg.norm(ez) ex, ey = _get_plane_vectors(ez) - ch['loc'] = np.concatenate([r0, ex, ey, ez]) + ch["loc"] = np.concatenate([r0, ex, ey, ez]) - if chans['type'][ii] == 'MEGMAG': + if chans["type"][ii] == "MEGMAG": nmeg += 1 - ch.update(logno=nmeg, coord_frame=FIFF.FIFFV_COORD_DEVICE, - kind=FIFF.FIFFV_MEG_CH, unit=FIFF.FIFF_UNIT_T, - coil_type=FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG2) - elif chans['type'][ii] == 'MEGREFMAG': + ch.update( + logno=nmeg, + coord_frame=FIFF.FIFFV_COORD_DEVICE, + kind=FIFF.FIFFV_MEG_CH, + unit=FIFF.FIFF_UNIT_T, + coil_type=FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG2, + ) + elif chans["type"][ii] == "MEGREFMAG": nref += 1 - ch.update(logno=nref, coord_frame=FIFF.FIFFV_COORD_UNKNOWN, - kind=FIFF.FIFFV_REF_MEG_CH, unit=FIFF.FIFF_UNIT_T, - coil_type=FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG2) - elif chans['type'][ii] == 'TRIG': + ch.update( + logno=nref, + coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + kind=FIFF.FIFFV_REF_MEG_CH, + unit=FIFF.FIFF_UNIT_T, + coil_type=FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG2, + ) + elif chans["type"][ii] == "TRIG": nstim += 1 - ch.update(logno=nstim, coord_frame=FIFF.FIFFV_COORD_UNKNOWN, - kind=FIFF.FIFFV_STIM_CH, unit=FIFF.FIFF_UNIT_V) + ch.update( + logno=nstim, + coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + kind=FIFF.FIFFV_STIM_CH, + unit=FIFF.FIFF_UNIT_V, + ) else: nmisc += 1 - ch.update(logno=nmisc, coord_frame=FIFF.FIFFV_COORD_UNKNOWN, - kind=FIFF.FIFFV_MISC_CH, unit=FIFF.FIFF_UNIT_NONE) + ch.update( + logno=nmisc, + coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + kind=FIFF.FIFFV_MISC_CH, + unit=FIFF.FIFF_UNIT_NONE, + ) # set the calibration based on the units - MNE expects T units for meg # and V for eeg - if chans['units'][ii] == 'fT': + if chans["units"][ii] == "fT": ch.update(cal=1e-15) - elif chans['units'][ii] == 'pT': + elif chans["units"][ii] == "pT": ch.update(cal=1e-12) - elif chans['units'][ii] == 'nT': + elif chans["units"][ii] == "nT": ch.update(cal=1e-9) - elif chans['units'][ii] == 'mV': + elif chans["units"][ii] == "mV": ch.update(cal=1e3) - elif 
chans['units'][ii] == 'uV': + elif chans["units"][ii] == "uV": ch.update(cal=1e6) return chs @@ -229,15 +256,15 @@ def _convert_channel_info(chans): def _compose_meas_info(meg, chans): """Create info structure.""" - info = _empty_info(meg['SamplingFrequency']) + info = _empty_info(meg["SamplingFrequency"]) # Collect all the necessary data from the structures read - info['meas_id'] = get_new_file_id() + info["meas_id"] = get_new_file_id() tmp = _convert_channel_info(chans) - info['chs'] = _refine_sensor_orientation(tmp) + info["chs"] = _refine_sensor_orientation(tmp) # info['chs'] = _convert_channel_info(chans) - info['line_freq'] = meg['PowerLineFrequency'] - info['bads'] = _read_bad_channels(chans) + info["line_freq"] = meg["PowerLineFrequency"] + info["bads"] = _read_bad_channels(chans) info._unlocked = False info._update_redundant() return info @@ -246,7 +273,7 @@ def _compose_meas_info(meg, chans): def _determine_nsamples(bin_fname, nchans, precision): """Identify how many temporal samples in a dataset.""" bsize = bin_fname.stat().st_size - if precision == 'single': + if precision == "single": bps = 4 else: bps = 8 @@ -257,16 +284,17 @@ def _determine_nsamples(bin_fname, nchans, precision): def _read_bad_channels(chans): """Check _channels.tsv file to look for premarked bad channels.""" bads = list() - for ii in range(0, len(chans['status'])): - if chans['status'][ii] == 'bad': - bads.append(chans['name'][ii]) + for ii in range(0, len(chans["status"])): + if chans["status"][ii] == "bad": + bads.append(chans["name"][ii]) return bads def _from_tsv(fname, dtypes=None): """Read a tsv file into a dict (which we know is ordered).""" - data = np.loadtxt(fname, dtype=str, delimiter='\t', ndmin=2, - comments=None, encoding='utf-8-sig') + data = np.loadtxt( + fname, dtype=str, delimiter="\t", ndmin=2, comments=None, encoding="utf-8-sig" + ) column_names = data[0, :] info = data[1:, :] data_dict = dict() @@ -275,8 +303,10 @@ def _from_tsv(fname, dtypes=None): if not isinstance(dtypes, (list, tuple)): dtypes = [dtypes] * info.shape[1] if not len(dtypes) == info.shape[1]: - raise ValueError('dtypes length mismatch. Provided: {0}, ' - 'Expected: {1}'.format(len(dtypes), info.shape[1])) + raise ValueError( + "dtypes length mismatch. 
Provided: {0}, " + "Expected: {1}".format(len(dtypes), info.shape[1]) + ) for i, name in enumerate(column_names): data_dict[name] = info[:, i].astype(dtypes[i]).tolist() return data_dict @@ -285,16 +315,16 @@ def _from_tsv(fname, dtypes=None): def _get_file_names(binfile): """Guess the filenames based on predicted suffixes.""" binfile = pathlib.Path( - _check_fname(binfile, overwrite='read', must_exist=True, name='fname')) - if not (binfile.suffix == '.bin' and binfile.stem.endswith('_meg')): - raise ValueError( - f'File must be a filename ending in _meg.bin, got {binfile}') + _check_fname(binfile, overwrite="read", must_exist=True, name="fname") + ) + if not (binfile.suffix == ".bin" and binfile.stem.endswith("_meg")): + raise ValueError(f"File must be a filename ending in _meg.bin, got {binfile}") files = dict() dir_ = binfile.parent root = binfile.stem[:-4] # no _meg - files['bin'] = dir_ / (root + '_meg.bin') - files['meg'] = dir_ / (root + '_meg.json') - files['chans'] = dir_ / (root + '_channels.tsv') - files['positions'] = dir_ / (root + '_positions.tsv') - files['coordsystem'] = dir_ / (root + '_coordsystem.json') + files["bin"] = dir_ / (root + "_meg.bin") + files["meg"] = dir_ / (root + "_meg.json") + files["chans"] = dir_ / (root + "_channels.tsv") + files["positions"] = dir_ / (root + "_positions.tsv") + files["coordsystem"] = dir_ / (root + "_coordsystem.json") return files diff --git a/mne/io/fil/sensors.py b/mne/io/fil/sensors.py index 942057787d0..ab94e65ccbc 100644 --- a/mne/io/fil/sensors.py +++ b/mne/io/fil/sensors.py @@ -32,10 +32,9 @@ def _refine_sensor_orientation(chanin): if np.isnan(targetloc.sum()) is False: targetloc = targetloc.reshape(3, 4, order="F") tmploc[:, 2] = targetloc[:, 3] - tmploc[:, 1] = flipFlag * np.cross(tmploc[:, 2], - tmploc[:, 3]) + tmploc[:, 1] = flipFlag * np.cross(tmploc[:, 2], tmploc[:, 3]) chanout[ii]["loc"] = tmploc.reshape(12, order="F") - logger.info('[done]') + logger.info("[done]") return chanout diff --git a/mne/io/fil/tests/test_fil.py b/mne/io/fil/tests/test_fil.py index 87017e04567..0788fd17667 100644 --- a/mne/io/fil/tests/test_fil.py +++ b/mne/io/fil/tests/test_fil.py @@ -15,31 +15,33 @@ import scipy.io -fil_path = testing.data_path(download=False) / 'FIL' +fil_path = testing.data_path(download=False) / "FIL" # TODO: Ignore this warning in all these tests until we deal with this properly pytestmark = pytest.mark.filterwarnings( - 'ignore:.*problems later!:RuntimeWarning', + "ignore:.*problems later!:RuntimeWarning", ) def unpack_mat(matin): """Extract relevant entries from unstructured readmat.""" - data = matin['data'] - grad = data[0][0]['grad'] + data = matin["data"] + grad = data[0][0]["grad"] label = list() coil_label = list() - for ii in range(len(data[0][0]['label'])): - label.append(str(data[0][0]['label'][ii][0][0])) - for ii in range(len(grad[0][0]['label'])): - coil_label.append(str(grad[0][0]['label'][ii][0][0])) - - matout = {'label': label, - 'trial': data['trial'][0][0][0][0], - 'coil_label': coil_label, - 'coil_pos': grad[0][0]['coilpos'], - 'coil_ori': grad[0][0]['coilori']} + for ii in range(len(data[0][0]["label"])): + label.append(str(data[0][0]["label"][ii][0][0])) + for ii in range(len(grad[0][0]["label"])): + coil_label.append(str(grad[0][0]["label"][ii][0][0])) + + matout = { + "label": label, + "trial": data["trial"][0][0][0][0], + "coil_label": coil_label, + "coil_pos": grad[0][0]["coilpos"], + "coil_ori": grad[0][0]["coilori"], + } return matout @@ -65,8 +67,7 @@ def _get_channels_with_positions(info): 
def _fil_megmag(raw_test, raw_mat): """Test the magnetometer channels.""" - test_inds = pick_types(raw_test.info, meg="mag", - ref_meg=False, exclude="bads") + test_inds = pick_types(raw_test.info, meg="mag", ref_meg=False, exclude="bads") test_list = list(raw_test.info["ch_names"][i] for i in test_inds) mat_list = raw_mat["label"] mat_inds = _match_str(test_list, mat_list) @@ -129,9 +130,7 @@ def _fil_sensorpos(raw_test, raw_mat): def test_fil_all(): """Test FIL reader, match to known answers from .mat file.""" binname = fil_path / "sub-noise_ses-001_task-noise220622_run-001_meg.bin" - matname = ( - fil_path / "sub-noise_ses-001_task-noise220622_run-001_fieldtrip.mat" - ) + matname = fil_path / "sub-noise_ses-001_task-noise220622_run-001_fieldtrip.mat" raw = read_raw_fil(binname) raw.load_data(verbose=False) diff --git a/mne/io/hitachi/hitachi.py b/mne/io/hitachi/hitachi.py index 892e4d33c72..51027642442 100644 --- a/mne/io/hitachi/hitachi.py +++ b/mne/io/hitachi/hitachi.py @@ -12,8 +12,7 @@ from ..meas_info import create_info, _merge_info from ..nirx.nirx import _read_csv_rows_cols from ..utils import _mult_cal_one -from ...utils import (logger, verbose, fill_doc, warn, _check_fname, - _check_option) +from ...utils import logger, verbose, fill_doc, warn, _check_fname, _check_option @fill_doc @@ -45,7 +44,7 @@ def read_raw_hitachi(fname, preload=False, verbose=None): def _check_bad(cond, msg): if cond: - raise RuntimeError(f'Could not parse file: {msg}') + raise RuntimeError(f"Could not parse file: {msg}") @fill_doc @@ -73,18 +72,17 @@ def __init__(self, fname, preload=False, *, verbose=None): fname = [fname] fname = list(fname) # our own list that we can modify for fi, this_fname in enumerate(fname): - fname[fi] = str( - _check_fname(this_fname, "read", True, f"fname[{fi}]") - ) + fname[fi] = str(_check_fname(this_fname, "read", True, f"fname[{fi}]")) infos = list() probes = list() last_samps = list() S_offset = D_offset = 0 - ignore_names = ['Time'] + ignore_names = ["Time"] for this_fname in fname: info, extra, last_samp, offsets = self._get_hitachi_info( - this_fname, S_offset, D_offset, ignore_names) - ignore_names = list(set(ignore_names + info['ch_names'])) + this_fname, S_offset, D_offset, ignore_names + ) + ignore_names = list(set(ignore_names + info["ch_names"])) S_offset += offsets[0] D_offset += offsets[1] infos.append(info) @@ -96,92 +94,98 @@ def __init__(self, fname, preload=False, *, verbose=None): else: info = infos[0] if len(set(last_samps)) != 1: - raise RuntimeError('All files must have the same number of ' - 'samples, got: {last_samps}') + raise RuntimeError( + f"All files must have the same number of samples, got: {last_samps}" + ) last_samps = [last_samps[0]] raw_extras = [dict(probes=probes)] # One representative filename is good enough here # (additional filenames indicate temporal concat, not ch concat) filenames = [fname[0]] super().__init__( - info, preload, filenames=filenames, last_samps=last_samps, - raw_extras=raw_extras, verbose=verbose) + info, + preload, + filenames=filenames, + last_samps=last_samps, + raw_extras=raw_extras, + verbose=verbose, + ) # This could be a function, but for the sake of indentation, let's make it # a method instead def _get_hitachi_info(self, fname, S_offset, D_offset, ignore_names): - logger.info('Loading %s' % fname) + logger.info("Loading %s" % fname) raw_extra = dict(fname=fname) info_extra = dict() subject_info = dict() ch_wavelengths = dict() fnirs_wavelengths = [None, None] meas_date = age = ch_names = sfreq = None - 
with open(fname, 'rb') as fid: + with open(fname, "rb") as fid: lines = fid.read() - lines = lines.decode('latin-1').rstrip('\r\n') + lines = lines.decode("latin-1").rstrip("\r\n") oldlen = len(lines) assert len(lines) == oldlen bounds = [0] - end = '\n' if '\n' in lines else '\r' + end = "\n" if "\n" in lines else "\r" bounds.extend(a.end() for a in re.finditer(end, lines)) bounds.append(len(lines)) lines = lines.split(end) assert len(bounds) == len(lines) + 1 - line = lines[0].rstrip(',\r\n') - _check_bad(line != 'Header', 'no header found') + line = lines[0].rstrip(",\r\n") + _check_bad(line != "Header", "no header found") li = 0 mode = None for li, line in enumerate(lines[1:], 1): # Newer format has some blank lines if len(line) == 0: continue - parts = line.rstrip(',\r\n').split(',') + parts = line.rstrip(",\r\n").split(",") if len(parts) == 0: # some header lines are blank continue kind, parts = parts[0], parts[1:] if len(parts) == 0: - parts = [''] # some fields (e.g., Comment) meaningfully blank - if kind == 'File Version': - logger.info(f'Reading Hitachi fNIRS file version {parts[0]}') - elif kind == 'AnalyzeMode': - _check_bad( - parts != ['Continuous'], f'not continuous data ({parts})') - elif kind == 'Sampling Period[s]': + parts = [""] # some fields (e.g., Comment) meaningfully blank + if kind == "File Version": + logger.info(f"Reading Hitachi fNIRS file version {parts[0]}") + elif kind == "AnalyzeMode": + _check_bad(parts != ["Continuous"], f"not continuous data ({parts})") + elif kind == "Sampling Period[s]": sfreq = 1 / float(parts[0]) - elif kind == 'Exception': + elif kind == "Exception": raise NotImplementedError(kind) - elif kind == 'Comment': - info_extra['description'] = parts[0] - elif kind == 'ID': - subject_info['his_id'] = parts[0] - elif kind == 'Name': + elif kind == "Comment": + info_extra["description"] = parts[0] + elif kind == "ID": + subject_info["his_id"] = parts[0] + elif kind == "Name": if len(parts): - name = parts[0].split(' ') + name = parts[0].split(" ") if len(name): - subject_info['first_name'] = name[0] - subject_info['last_name'] = ' '.join(name[1:]) - elif kind == 'Age': - age = int(parts[0].rstrip('y')) - elif kind == 'Mode': + subject_info["first_name"] = name[0] + subject_info["last_name"] = " ".join(name[1:]) + elif kind == "Age": + age = int(parts[0].rstrip("y")) + elif kind == "Mode": mode = parts[0] - elif kind in ('HPF[Hz]', 'LPF[Hz]'): + elif kind in ("HPF[Hz]", "LPF[Hz]"): try: freq = float(parts[0]) except ValueError: pass else: - info_extra[{'HPF[Hz]': 'highpass', - 'LPF[Hz]': 'lowpass'}[kind]] = freq - elif kind == 'Date': + info_extra[ + {"HPF[Hz]": "highpass", "LPF[Hz]": "lowpass"}[kind] + ] = freq + elif kind == "Date": # 5/17/04 5:14 try: - mdy, HM = parts[0].split(' ') - H, M = HM.split(':') + mdy, HM = parts[0].split(" ") + H, M = HM.split(":") if len(H) == 1: - H = f'0{H}' - mdyHM = ' '.join([mdy, ':'.join([H, M])]) - for fmt in ('%m/%d/%y %H:%M', '%Y/%m/%d %H:%M'): + H = f"0{H}" + mdyHM = " ".join([mdy, ":".join([H, M])]) + for fmt in ("%m/%d/%y %H:%M", "%Y/%m/%d %H:%M"): try: meas_date = dt.datetime.strptime(mdyHM, fmt) except Exception: @@ -191,59 +195,66 @@ def _get_hitachi_info(self, fname, S_offset, D_offset, ignore_names): else: raise RuntimeError # unknown format except Exception: - warn('Extraction of measurement date failed. ' - 'Please report this as a github issue. 
' - 'The date is being set to January 1st, 2000, ' - f'instead of {repr(parts[0])}') - elif kind == 'Sex': + warn( + "Extraction of measurement date failed. " + "Please report this as a github issue. " + "The date is being set to January 1st, 2000, " + f"instead of {repr(parts[0])}" + ) + elif kind == "Sex": try: - subject_info['sex'] = dict( - female=FIFF.FIFFV_SUBJ_SEX_FEMALE, - male=FIFF.FIFFV_SUBJ_SEX_MALE)[parts[0].lower()] + subject_info["sex"] = dict( + female=FIFF.FIFFV_SUBJ_SEX_FEMALE, male=FIFF.FIFFV_SUBJ_SEX_MALE + )[parts[0].lower()] except KeyError: pass - elif kind == 'Wave[nm]': + elif kind == "Wave[nm]": fnirs_wavelengths[:] = [int(part) for part in parts] - elif kind == 'Wave Length': - ch_regex = re.compile(r'^(.*)\(([0-9\.]+)\)$') + elif kind == "Wave Length": + ch_regex = re.compile(r"^(.*)\(([0-9\.]+)\)$") for ent in parts: _, v = ch_regex.match(ent).groups() ch_wavelengths[ent] = float(v) - elif kind == 'Data': + elif kind == "Data": break fnirs_wavelengths = np.array(fnirs_wavelengths, int) assert len(fnirs_wavelengths) == 2 - ch_names = lines[li + 1].rstrip(',\r\n').split(',') + ch_names = lines[li + 1].rstrip(",\r\n").split(",") # cull to correct ones - raw_extra['keep_mask'] = ~np.in1d(ch_names, list(ignore_names)) + raw_extra["keep_mask"] = ~np.in1d(ch_names, list(ignore_names)) for ci, ch_name in enumerate(ch_names): - if re.match('Probe[0-9]+', ch_name): - raw_extra['keep_mask'][ci] = False + if re.match("Probe[0-9]+", ch_name): + raw_extra["keep_mask"][ci] = False # set types - ch_names = [ch_name for ci, ch_name in enumerate(ch_names) - if raw_extra['keep_mask'][ci]] - ch_types = ['fnirs_cw_amplitude' if ch_name.startswith('CH') - else 'stim' - for ch_name in ch_names] + ch_names = [ + ch_name for ci, ch_name in enumerate(ch_names) if raw_extra["keep_mask"][ci] + ] + ch_types = [ + "fnirs_cw_amplitude" if ch_name.startswith("CH") else "stim" + for ch_name in ch_names + ] # get locations - nirs_names = [ch_name for ch_name, ch_type in zip(ch_names, ch_types) - if ch_type == 'fnirs_cw_amplitude'] + nirs_names = [ + ch_name + for ch_name, ch_type in zip(ch_names, ch_types) + if ch_type == "fnirs_cw_amplitude" + ] n_nirs = len(nirs_names) assert n_nirs % 2 == 0 names = { - '3x3': 'ETG-100', - '3x5': 'ETG-7000', - '4x4': 'ETG-7000', - '3x11': 'ETG-4000', + "3x3": "ETG-100", + "3x5": "ETG-7000", + "4x4": "ETG-7000", + "3x11": "ETG-4000", } - _check_option('Hitachi mode', mode, sorted(names)) - n_row, n_col = [int(x) for x in mode.split('x')] - logger.info(f'Constructing pairing matrix for {names[mode]} ({mode})') - pairs = _compute_pairs(n_row, n_col, n=1 + (mode == '3x3')) + _check_option("Hitachi mode", mode, sorted(names)) + n_row, n_col = [int(x) for x in mode.split("x")] + logger.info(f"Constructing pairing matrix for {names[mode]} ({mode})") + pairs = _compute_pairs(n_row, n_col, n=1 + (mode == "3x3")) assert n_nirs == len(pairs) * 2 locs = np.zeros((len(ch_names), 12)) locs[:, :9] = np.nan - idxs = np.where(np.array(ch_types, 'U') == 'fnirs_cw_amplitude')[0] + idxs = np.where(np.array(ch_types, "U") == "fnirs_cw_amplitude")[0] for ii, idx in enumerate(idxs): ch_name = ch_names[idx] # Use the actual/accurate wavelength in loc @@ -252,51 +263,57 @@ def _get_hitachi_info(self, fname, S_offset, D_offset, ignore_names): # Rename channel based on standard naming scheme, using the # nominal wavelength sidx, didx = pairs[ii // 2] - nom_freq = fnirs_wavelengths[np.argmin(np.abs( - acc_freq - fnirs_wavelengths))] + nom_freq = fnirs_wavelengths[ + 
np.argmin(np.abs(acc_freq - fnirs_wavelengths)) + ] ch_names[idx] = ( - f'S{S_offset + sidx + 1}_' - f'D{D_offset + didx + 1} ' - f'{nom_freq}' + f"S{S_offset + sidx + 1}_" f"D{D_offset + didx + 1} " f"{nom_freq}" ) offsets = np.array(pairs, int).max(axis=0) + 1 # figure out bounds - bounds = raw_extra['bounds'] = bounds[li + 2:] + bounds = raw_extra["bounds"] = bounds[li + 2 :] last_samp = len(bounds) - 2 if age is not None and meas_date is not None: - subject_info['birthday'] = (meas_date.year - age, - meas_date.month, - meas_date.day) + subject_info["birthday"] = ( + meas_date.year - age, + meas_date.month, + meas_date.day, + ) if meas_date is None: meas_date = dt.datetime(2000, 1, 1, 0, 0, 0) meas_date = meas_date.replace(tzinfo=dt.timezone.utc) if subject_info: - info_extra['subject_info'] = subject_info + info_extra["subject_info"] = subject_info # Create mne structure info = create_info(ch_names, sfreq, ch_types=ch_types) with info._unlock(): info.update(info_extra) - info['meas_date'] = meas_date + info["meas_date"] = meas_date for li, loc in enumerate(locs): - info['chs'][li]['loc'][:] = loc + info["chs"][li]["loc"][:] = loc return info, raw_extra, last_samp, offsets def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a segment of data from a file.""" this_data = list() - for this_probe in self._raw_extras[fi]['probes']: - this_data.append(_read_csv_rows_cols( - this_probe['fname'], - start, stop, this_probe['keep_mask'], - this_probe['bounds'], sep=',', - replace=lambda x: - x.replace('\r', '\n') - .replace('\n\n', '\n') - .replace('\n', ',') - .replace(':', '')).T) + for this_probe in self._raw_extras[fi]["probes"]: + this_data.append( + _read_csv_rows_cols( + this_probe["fname"], + start, + stop, + this_probe["keep_mask"], + this_probe["bounds"], + sep=",", + replace=lambda x: x.replace("\r", "\n") + .replace("\n\n", "\n") + .replace("\n", ",") + .replace(":", ""), + ).T + ) this_data = np.concatenate(this_data, axis=0) _mult_cal_one(data, this_data, idx, cals, mult) return data diff --git a/mne/io/hitachi/tests/test_hitachi.py b/mne/io/hitachi/tests/test_hitachi.py index d04218b1eb0..89b94fd7a06 100644 --- a/mne/io/hitachi/tests/test_hitachi.py +++ b/mne/io/hitachi/tests/test_hitachi.py @@ -12,13 +12,19 @@ from mne.io import read_raw_hitachi from mne.io.hitachi.hitachi import _compute_pairs from mne.io.tests.test_raw import _test_raw_reader -from mne.preprocessing.nirs import (source_detector_distances, - optical_density, tddr, beer_lambert_law, - scalp_coupling_index) +from mne.preprocessing.nirs import ( + source_detector_distances, + optical_density, + tddr, + beer_lambert_law, + scalp_coupling_index, +) CONTENTS = dict() -CONTENTS['1.18'] = b"""\ +CONTENTS[ + "1.18" +] = b"""\ Header,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, File Version,1.18,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, Patient Information,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, @@ -123,7 +129,9 @@ """ # noqa: E501 -CONTENTS['1.25'] = b"""\ +CONTENTS[ + "1.25" +] = b"""\ Header File Version,1.25 Patient Information @@ -178,17 +186,21 @@ """ # noqa: E501 -@pytest.mark.parametrize('preload', (True, False)) -@pytest.mark.parametrize('version, n_ch, n_times, lowpass, sex, date, end', [ - ('1.18', 48, 60, 0.1, 2, (2004, 5, 17, 5, 14, 0, 0), None), - ('1.25', 108, 10, 5., 1, (2020, 2, 2, 11, 20, 0, 0), b'\r'), - ('1.25', 108, 10, 5., 1, (2020, 2, 2, 11, 20, 0, 0), b'\n'), - ('1.25', 108, 10, 5., 1, (2020, 2, 2, 11, 20, 0, 0), b'\r\n'), - # Fake a dual-probe file - 
(['1.18', '1.18'], 92, 60, 0.1, 2, (2004, 5, 17, 5, 14, 0, 0), None), -]) -def test_hitachi_basic(preload, version, n_ch, n_times, lowpass, sex, date, - end, tmp_path): +@pytest.mark.parametrize("preload", (True, False)) +@pytest.mark.parametrize( + "version, n_ch, n_times, lowpass, sex, date, end", + [ + ("1.18", 48, 60, 0.1, 2, (2004, 5, 17, 5, 14, 0, 0), None), + ("1.25", 108, 10, 5.0, 1, (2020, 2, 2, 11, 20, 0, 0), b"\r"), + ("1.25", 108, 10, 5.0, 1, (2020, 2, 2, 11, 20, 0, 0), b"\n"), + ("1.25", 108, 10, 5.0, 1, (2020, 2, 2, 11, 20, 0, 0), b"\r\n"), + # Fake a dual-probe file + (["1.18", "1.18"], 92, 60, 0.1, 2, (2004, 5, 17, 5, 14, 0, 0), None), + ], +) +def test_hitachi_basic( + preload, version, n_ch, n_times, lowpass, sex, date, end, tmp_path +): """Test NIRSport1 file with no saturation.""" if not isinstance(version, list): versions = [version] @@ -197,63 +209,68 @@ def test_hitachi_basic(preload, version, n_ch, n_times, lowpass, sex, date, del version fnames = list() for vi, v in enumerate(versions, 1): - fname = tmp_path / f'test{vi}.csv' - contents = CONTENTS[v].replace( - f'Probe{vi - 1}'.encode(), - f'Probe{vi}'.encode()) + fname = tmp_path / f"test{vi}.csv" + contents = CONTENTS[v].replace(f"Probe{vi - 1}".encode(), f"Probe{vi}".encode()) if end is not None: - contents = contents.replace(b'\r', b'\n').replace(b'\n\n', b'\n') - contents = contents.replace(b'\n', end) - with open(fname, 'wb') as fid: + contents = contents.replace(b"\r", b"\n").replace(b"\n\n", b"\n") + contents = contents.replace(b"\n", end) + with open(fname, "wb") as fid: fid.write(contents) fnames.append(fname) del fname raw = read_raw_hitachi(fnames, preload=preload, verbose=True) data = raw.get_data() assert data.shape == (n_ch, n_times) - assert raw.info['sfreq'] == 10 - assert raw.info['lowpass'] == lowpass - assert raw.info['subject_info']['sex'] == sex + assert raw.info["sfreq"] == 10 + assert raw.info["lowpass"] == lowpass + assert raw.info["subject_info"]["sex"] == sex assert np.isfinite(raw.get_data()).all() - assert raw.info['meas_date'] == dt.datetime(*date, tzinfo=dt.timezone.utc) + assert raw.info["meas_date"] == dt.datetime(*date, tzinfo=dt.timezone.utc) # bad distances (zero) distances = source_detector_distances(raw.info) want = [np.nan] * (n_ch - 4) - assert_allclose(distances, want, atol=0.) 
+ assert_allclose(distances, want, atol=0.0) raw_od_bad = optical_density(raw) - with pytest.warns(RuntimeWarning, match='will be zero'): + with pytest.warns(RuntimeWarning, match="will be zero"): beer_lambert_law(raw_od_bad, ppf=6) # bad distances (too big) - if versions[0] == '1.18' and len(fnames) == 1: - need = sum(([f'S{ii}', f'D{ii}'] for ii in range(1, 9)), [])[:-1] - have = 'P7 FC3 C3 CP3 P3 F5 FC5 C5 CP5 P5 F7 FT7 T7 TP7 F3'.split() + if versions[0] == "1.18" and len(fnames) == 1: + need = sum(([f"S{ii}", f"D{ii}"] for ii in range(1, 9)), [])[:-1] + have = "P7 FC3 C3 CP3 P3 F5 FC5 C5 CP5 P5 F7 FT7 T7 TP7 F3".split() assert len(need) == len(have) - mon = make_standard_montage('standard_1020') + mon = make_standard_montage("standard_1020") mon.rename_channels(dict(zip(have, need))) raw.set_montage(mon) raw_od_bad = optical_density(raw) - with pytest.warns(RuntimeWarning, match='greater than 10 cm'): + with pytest.warns(RuntimeWarning, match="greater than 10 cm"): beer_lambert_law(raw_od_bad, ppf=6) # good distances - mon = make_standard_montage('standard_1020') - if versions[0] == '1.18': + mon = make_standard_montage("standard_1020") + if versions[0] == "1.18": assert len(fnames) in (1, 2) - need = sum(([f'S{ii}', f'D{ii}'] for ii in range(1, 9)), [])[:-1] - have = 'F3 FC3 C3 CP3 P3 F5 FC5 C5 CP5 P5 F7 FT7 T7 TP7 P7'.split() + need = sum(([f"S{ii}", f"D{ii}"] for ii in range(1, 9)), [])[:-1] + have = "F3 FC3 C3 CP3 P3 F5 FC5 C5 CP5 P5 F7 FT7 T7 TP7 P7".split() assert len(need) == 15 if len(fnames) == 2: - need.extend(sum(( - [f'S{ii}', f'D{jj}'] - for ii, jj in zip(range(9, 17), range(8, 16))), [])[:-1]) - have.extend( - 'F4 FC4 C4 CP4 P4 F6 FC6 C6 CP6 P6 F8 FT8 T8 TP8 P8'.split()) + need.extend( + sum( + ( + [f"S{ii}", f"D{jj}"] + for ii, jj in zip(range(9, 17), range(8, 16)) + ), + [], + )[:-1] + ) + have.extend("F4 FC4 C4 CP4 P4 F6 FC6 C6 CP6 P6 F8 FT8 T8 TP8 P8".split()) assert len(need) == 30 else: assert len(fnames) == 1 - need = sum(([f'S{ii}', f'D{ii}'] for ii in range(1, 18)), [])[:-1] - have = ('FT9 FT7 FC5 FC3 FC1 FCz FC2 FC4 FC6 FT8 FT10 ' - 'T9 T7 C5 C3 C1 Cz C2 C4 C6 T8 T10 ' - 'TP9 TP7 CP5 CP3 CP1 CPz CP2 CP4 CP6 TP8 TP10').split() + need = sum(([f"S{ii}", f"D{ii}"] for ii in range(1, 18)), [])[:-1] + have = ( + "FT9 FT7 FC5 FC3 FC1 FCz FC2 FC4 FC6 FT8 FT10 " + "T9 T7 C5 C3 C1 Cz C2 C4 C6 T8 T10 " + "TP9 TP7 CP5 CP3 CP1 CPz CP2 CP4 CP6 TP8 TP10" + ).split() assert len(need) == 33 assert len(need) == len(have) for h in have: @@ -265,9 +282,10 @@ def test_hitachi_basic(preload, version, n_ch, n_times, lowpass, sex, date, distances = source_detector_distances(raw.info) want = [0.03] * (n_ch - 4) assert_allclose(distances, want, atol=0.01) - test_rank = 'less' if n_times < n_ch else True - _test_raw_reader(read_raw_hitachi, fname=fnames, - boundary_decimal=1, test_rank=test_rank) # low fs + test_rank = "less" if n_times < n_ch else True + _test_raw_reader( + read_raw_hitachi, fname=fnames, boundary_decimal=1, test_rank=test_rank + ) # low fs # TODO: eventually we should refactor these to be in # mne/io/tests/test_raw.py and run them for all fNIRS readers @@ -275,9 +293,9 @@ def test_hitachi_basic(preload, version, n_ch, n_times, lowpass, sex, date, # OD raw_od = optical_density(raw) assert np.isfinite(raw_od.get_data()).all() - sci = scalp_coupling_index(raw_od, verbose='error') + sci = scalp_coupling_index(raw_od, verbose="error") lo, mi, hi = np.percentile(sci, [5, 50, 95]) - if versions[0] == '1.18': + if versions[0] == "1.18": assert -0.1 < lo < 0.1 # not great 
assert 0.4 < mi < 0.5 assert 0.8 < hi < 0.9 @@ -285,58 +303,158 @@ def test_hitachi_basic(preload, version, n_ch, n_times, lowpass, sex, date, assert 0.99 <= lo <= hi <= 1 # TDDR raw_tddr = tddr(raw_od) - data = raw_tddr.get_data('fnirs') + data = raw_tddr.get_data("fnirs") assert np.isfinite(data.all()) peaks = np.ptp(data, axis=-1) - assert_array_less(1e-4, peaks, err_msg='TDDR too small') - assert_array_less(peaks, 1, err_msg='TDDR too big') + assert_array_less(1e-4, peaks, err_msg="TDDR too small") + assert_array_less(peaks, 1, err_msg="TDDR too big") # HbO/HbR raw_tddr.set_montage(mon) raw_h = beer_lambert_law(raw_tddr, ppf=6) - data = raw_h.get_data('fnirs') + data = raw_h.get_data("fnirs") assert np.isfinite(data).all() assert data.shape == (n_ch - 4, n_times) peaks = np.ptp(data, axis=-1) - assert_array_less(1e-10, peaks, err_msg='Beer-Lambert too small') - assert_array_less(peaks, 1e-5, err_msg='Beer-Lambert too big') + assert_array_less(1e-10, peaks, err_msg="Beer-Lambert too small") + assert_array_less(peaks, 1e-5, err_msg="Beer-Lambert too big") # From Hitachi 2 Homer KNOWN_PAIRS = { (3, 3, 2): ( - (0, 0), (1, 0), (0, 1), (2, 0), (1, 2), - (2, 1), (2, 2), (3, 1), (2, 3), (4, 2), - (3, 3), (4, 3), (5, 4), (6, 4), (5, 5), - (7, 4), (6, 6), (7, 5), (7, 6), (8, 5), - (7, 7), (9, 6), (8, 7), (9, 7)), + (0, 0), + (1, 0), + (0, 1), + (2, 0), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (2, 3), + (4, 2), + (3, 3), + (4, 3), + (5, 4), + (6, 4), + (5, 5), + (7, 4), + (6, 6), + (7, 5), + (7, 6), + (8, 5), + (7, 7), + (9, 6), + (8, 7), + (9, 7), + ), (3, 5, 1): ( - (0, 0), (1, 0), (1, 1), (2, 1), (0, 2), - (3, 0), (1, 3), (4, 1), (2, 4), (3, 2), - (3, 3), (4, 3), (4, 4), (5, 2), (3, 5), - (6, 3), (4, 6), (7, 4), (5, 5), (6, 5), - (6, 6), (7, 6)), + (0, 0), + (1, 0), + (1, 1), + (2, 1), + (0, 2), + (3, 0), + (1, 3), + (4, 1), + (2, 4), + (3, 2), + (3, 3), + (4, 3), + (4, 4), + (5, 2), + (3, 5), + (6, 3), + (4, 6), + (7, 4), + (5, 5), + (6, 5), + (6, 6), + (7, 6), + ), (4, 4, 1): ( - (0, 0), (1, 0), (1, 1), (0, 2), (2, 0), - (1, 3), (3, 1), (2, 2), (2, 3), (3, 3), - (4, 2), (2, 4), (5, 3), (3, 5), (4, 4), - (5, 4), (5, 5), (4, 6), (6, 4), (5, 7), - (7, 5), (6, 6), (6, 7), (7, 7)), + (0, 0), + (1, 0), + (1, 1), + (0, 2), + (2, 0), + (1, 3), + (3, 1), + (2, 2), + (2, 3), + (3, 3), + (4, 2), + (2, 4), + (5, 3), + (3, 5), + (4, 4), + (5, 4), + (5, 5), + (4, 6), + (6, 4), + (5, 7), + (7, 5), + (6, 6), + (6, 7), + (7, 7), + ), (3, 11, 1): ( - (0, 0), (1, 0), (1, 1), (2, 1), (2, 2), - (3, 2), (3, 3), (4, 3), (4, 4), (5, 4), - (0, 5), (6, 0), (1, 6), (7, 1), (2, 7), - (8, 2), (3, 8), (9, 3), (4, 9), (10, 4), - (5, 10), (6, 5), (6, 6), (7, 6), (7, 7), - (8, 7), (8, 8), (9, 8), (9, 9), (10, 9), - (10, 10), (11, 5), (6, 11), (12, 6), (7, 12), - (13, 7), (8, 13), (14, 8), (9, 14), (15, 9), - (10, 15), (16, 10), (11, 11), (12, 11), (12, 12), - (13, 12), (13, 13), (14, 13), (14, 14), (15, 14), - (15, 15), (16, 15)), + (0, 0), + (1, 0), + (1, 1), + (2, 1), + (2, 2), + (3, 2), + (3, 3), + (4, 3), + (4, 4), + (5, 4), + (0, 5), + (6, 0), + (1, 6), + (7, 1), + (2, 7), + (8, 2), + (3, 8), + (9, 3), + (4, 9), + (10, 4), + (5, 10), + (6, 5), + (6, 6), + (7, 6), + (7, 7), + (8, 7), + (8, 8), + (9, 8), + (9, 9), + (10, 9), + (10, 10), + (11, 5), + (6, 11), + (12, 6), + (7, 12), + (13, 7), + (8, 13), + (14, 8), + (9, 14), + (15, 9), + (10, 15), + (16, 10), + (11, 11), + (12, 11), + (12, 12), + (13, 12), + (13, 13), + (14, 13), + (14, 14), + (15, 14), + (15, 15), + (16, 15), + ), } -@pytest.mark.parametrize('n_rows, 
n_cols, n', list(KNOWN_PAIRS)) +@pytest.mark.parametrize("n_rows, n_cols, n", list(KNOWN_PAIRS)) def test_compute_pairs(n_rows, n_cols, n): """Test computation of S-D pairings.""" want = KNOWN_PAIRS[(n_rows, n_cols, n)] diff --git a/mne/io/kit/constants.py b/mne/io/kit/constants.py index 144dc584816..27e5fbe8630 100644 --- a/mne/io/kit/constants.py +++ b/mne/io/kit/constants.py @@ -18,7 +18,7 @@ # channel parameters KIT.CALIB_FACTOR = 1.0 # mne_manual p.272 -KIT.RANGE = 1. # mne_manual p.272 +KIT.RANGE = 1.0 # mne_manual p.272 KIT.UNIT_MUL = FIFF.FIFF_UNITM_NONE # default is 0 mne_manual p.273 KIT.GAINS = [1, 2, 5, 10, 20, 50, 100, 200] @@ -129,11 +129,11 @@ KIT.CHANNEL_NULL: FIFF.FIFFV_MISC_CH, } KIT.CH_LABEL = { - KIT.CHANNEL_TRIGGER: 'TRIGGER', - KIT.CHANNEL_EEG: 'EEG', - KIT.CHANNEL_ECG: 'ECG', - KIT.CHANNEL_ETC: 'MISC', - KIT.CHANNEL_NULL: 'MISC', + KIT.CHANNEL_TRIGGER: "TRIGGER", + KIT.CHANNEL_EEG: "EEG", + KIT.CHANNEL_ECG: "ECG", + KIT.CHANNEL_ETC: "MISC", + KIT.CHANNEL_NULL: "MISC", } # Acquisition modes @@ -170,19 +170,19 @@ # Sensor layouts for plotting KIT_LAYOUT = { KIT.SYSTEM_AS: None, - KIT.SYSTEM_AS_2008: 'KIT-AS-2008', - KIT.SYSTEM_MQ_ADULT: 'KIT-160', - KIT.SYSTEM_MQ_CHILD: 'KIT-125', - KIT.SYSTEM_NYU_2008: 'KIT-157', - KIT.SYSTEM_NYU_2009: 'KIT-157', - KIT.SYSTEM_NYU_2010: 'KIT-157', + KIT.SYSTEM_AS_2008: "KIT-AS-2008", + KIT.SYSTEM_MQ_ADULT: "KIT-160", + KIT.SYSTEM_MQ_CHILD: "KIT-125", + KIT.SYSTEM_NYU_2008: "KIT-157", + KIT.SYSTEM_NYU_2009: "KIT-157", + KIT.SYSTEM_NYU_2010: "KIT-157", KIT.SYSTEM_NYU_2019: None, - KIT.SYSTEM_NYUAD_2011: 'KIT-AD', - KIT.SYSTEM_NYUAD_2012: 'KIT-AD', - KIT.SYSTEM_NYUAD_2014: 'KIT-AD', + KIT.SYSTEM_NYUAD_2011: "KIT-AD", + KIT.SYSTEM_NYUAD_2012: "KIT-AD", + KIT.SYSTEM_NYUAD_2014: "KIT-AD", KIT.SYSTEM_UMD_2004: None, KIT.SYSTEM_UMD_2014_07: None, - KIT.SYSTEM_UMD_2014_12: 'KIT-UMD-3', + KIT.SYSTEM_UMD_2014_12: "KIT-UMD-3", KIT.SYSTEM_UMD_2019_09: None, KIT.SYSTEM_YOKOGAWA_2017_01: None, KIT.SYSTEM_YOKOGAWA_2018_01: None, @@ -195,17 +195,17 @@ KIT.SYSTEM_AS_2008: None, KIT.SYSTEM_MQ_ADULT: None, KIT.SYSTEM_MQ_CHILD: None, - KIT.SYSTEM_NYU_2008: 'KIT-157', - KIT.SYSTEM_NYU_2009: 'KIT-157', - KIT.SYSTEM_NYU_2010: 'KIT-157', - KIT.SYSTEM_NYU_2019: 'KIT-NYU-2019', - KIT.SYSTEM_NYUAD_2011: 'KIT-208', - KIT.SYSTEM_NYUAD_2012: 'KIT-208', - KIT.SYSTEM_NYUAD_2014: 'KIT-208', - KIT.SYSTEM_UMD_2004: 'KIT-UMD-1', - KIT.SYSTEM_UMD_2014_07: 'KIT-UMD-2', - KIT.SYSTEM_UMD_2014_12: 'KIT-UMD-3', - KIT.SYSTEM_UMD_2019_09: 'KIT-UMD-4', + KIT.SYSTEM_NYU_2008: "KIT-157", + KIT.SYSTEM_NYU_2009: "KIT-157", + KIT.SYSTEM_NYU_2010: "KIT-157", + KIT.SYSTEM_NYU_2019: "KIT-NYU-2019", + KIT.SYSTEM_NYUAD_2011: "KIT-208", + KIT.SYSTEM_NYUAD_2012: "KIT-208", + KIT.SYSTEM_NYUAD_2014: "KIT-208", + KIT.SYSTEM_UMD_2004: "KIT-UMD-1", + KIT.SYSTEM_UMD_2014_07: "KIT-UMD-2", + KIT.SYSTEM_UMD_2014_12: "KIT-UMD-3", + KIT.SYSTEM_UMD_2019_09: "KIT-UMD-4", KIT.SYSTEM_YOKOGAWA_2017_01: None, KIT.SYSTEM_YOKOGAWA_2018_01: None, KIT.SYSTEM_YOKOGAWA_2020_08: None, @@ -213,31 +213,31 @@ } # Names displayed in the info dict description KIT_SYSNAMES = { - KIT.SYSTEM_MQ_ADULT: 'Macquarie Dept of Cognitive Science (Adult), 2006-', - KIT.SYSTEM_MQ_CHILD: 'Macquarie Dept of Cognitive Science (Child), 2006-', - KIT.SYSTEM_AS: 'Academia Sinica, -2008', - KIT.SYSTEM_AS_2008: 'Academia Sinica, 2008-', - KIT.SYSTEM_NYU_2008: 'NYU New York, 2008-9', - KIT.SYSTEM_NYU_2009: 'NYU New York, 2009-10', - KIT.SYSTEM_NYU_2010: 'NYU New York, 2010-', - KIT.SYSTEM_NYUAD_2011: 'New York University Abu Dhabi, 2011-12', 
- KIT.SYSTEM_NYUAD_2012: 'New York University Abu Dhabi, 2012-14', - KIT.SYSTEM_NYUAD_2014: 'New York University Abu Dhabi, 2014-', - KIT.SYSTEM_UMD_2004: 'University of Maryland, 2004-14', - KIT.SYSTEM_UMD_2014_07: 'University of Maryland, 2014', - KIT.SYSTEM_UMD_2014_12: 'University of Maryland, 2014-', - KIT.SYSTEM_UMD_2019_09: 'University of Maryland, 2019-', - KIT.SYSTEM_YOKOGAWA_2017_01: 'Yokogawa of Kanazawa (until 2017)', - KIT.SYSTEM_YOKOGAWA_2018_01: 'Yokogawa of Kanazawa (since 2018)', - KIT.SYSTEM_YOKOGAWA_2020_08: 'Yokogawa of Kanazawa (since August 2020)', - KIT.SYSTEM_EAGLE_TECHNOLOGY_PTB_2008: 'Eagle Technology MEG (KIT/Yokogawa style) at PTB (since 2008, software upgrade in 2018)', # noqa: E501 + KIT.SYSTEM_MQ_ADULT: "Macquarie Dept of Cognitive Science (Adult), 2006-", + KIT.SYSTEM_MQ_CHILD: "Macquarie Dept of Cognitive Science (Child), 2006-", + KIT.SYSTEM_AS: "Academia Sinica, -2008", + KIT.SYSTEM_AS_2008: "Academia Sinica, 2008-", + KIT.SYSTEM_NYU_2008: "NYU New York, 2008-9", + KIT.SYSTEM_NYU_2009: "NYU New York, 2009-10", + KIT.SYSTEM_NYU_2010: "NYU New York, 2010-", + KIT.SYSTEM_NYUAD_2011: "New York University Abu Dhabi, 2011-12", + KIT.SYSTEM_NYUAD_2012: "New York University Abu Dhabi, 2012-14", + KIT.SYSTEM_NYUAD_2014: "New York University Abu Dhabi, 2014-", + KIT.SYSTEM_UMD_2004: "University of Maryland, 2004-14", + KIT.SYSTEM_UMD_2014_07: "University of Maryland, 2014", + KIT.SYSTEM_UMD_2014_12: "University of Maryland, 2014-", + KIT.SYSTEM_UMD_2019_09: "University of Maryland, 2019-", + KIT.SYSTEM_YOKOGAWA_2017_01: "Yokogawa of Kanazawa (until 2017)", + KIT.SYSTEM_YOKOGAWA_2018_01: "Yokogawa of Kanazawa (since 2018)", + KIT.SYSTEM_YOKOGAWA_2020_08: "Yokogawa of Kanazawa (since August 2020)", + KIT.SYSTEM_EAGLE_TECHNOLOGY_PTB_2008: "Eagle Technology MEG (KIT/Yokogawa style) at PTB (since 2008, software upgrade in 2018)", # noqa: E501 } LEGACY_AMP_PARAMS = { - KIT.SYSTEM_NYU_2008: (5., 11.), - KIT.SYSTEM_NYU_2009: (5., 11.), - KIT.SYSTEM_NYU_2010: (5., 11.), - KIT.SYSTEM_UMD_2004: (5., 11.), + KIT.SYSTEM_NYU_2008: (5.0, 11.0), + KIT.SYSTEM_NYU_2009: (5.0, 11.0), + KIT.SYSTEM_NYU_2010: (5.0, 11.0), + KIT.SYSTEM_UMD_2004: (5.0, 11.0), } # Ones that we don't use are commented out diff --git a/mne/io/kit/coreg.py b/mne/io/kit/coreg.py index c3dda423a97..6ea823913d7 100644 --- a/mne/io/kit/coreg.py +++ b/mne/io/kit/coreg.py @@ -14,13 +14,17 @@ from .constants import KIT, FIFF from .._digitization import _make_dig_points -from ...transforms import (Transform, apply_trans, get_ras_to_neuromag_trans, - als_ras_trans) +from ...transforms import ( + Transform, + apply_trans, + get_ras_to_neuromag_trans, + als_ras_trans, +) from ...utils import warn, _check_option, _check_fname -INT32 = ' KIT.DIG_POINTS: hsp = _decimate_points(hsp, res=0.005) n_new = len(hsp) - warn("The selected head shape contained {n_in} points, which is " - "more than recommended ({n_rec}), and was automatically " - "downsampled to {n_new} points. The preferred way to " - "downsample is using FastScan.".format( - n_in=n_pts, n_rec=KIT.DIG_POINTS, n_new=n_new)) + warn( + "The selected head shape contained {n_in} points, which is " + "more than recommended ({n_rec}), and was automatically " + "downsampled to {n_new} points. 
The preferred way to " + "downsample is using FastScan.".format( + n_in=n_pts, n_rec=KIT.DIG_POINTS, n_new=n_new + ) + ) if isinstance(elp, (str, Path, PathLike)): elp_points = _read_dig_kit(elp) if len(elp_points) != 8: - raise ValueError("File %r should contain 8 points; got shape " - "%s." % (elp, elp_points.shape)) + raise ValueError( + "File %r should contain 8 points; got shape " + "%s." % (elp, elp_points.shape) + ) elp = elp_points elif len(elp) not in (6, 7, 8): - raise ValueError("ELP should contain 6 ~ 8 points; got shape " - "%s." % (elp.shape,)) + raise ValueError( + "ELP should contain 6 ~ 8 points; got shape " "%s." % (elp.shape,) + ) if isinstance(mrk, (str, Path, PathLike)): mrk = read_mrk(mrk) @@ -166,55 +175,64 @@ def _set_dig_kit(mrk, elp, hsp, eeg): eeg = OrderedDict((k, apply_trans(nmtrans, p)) for k, p in eeg.items()) # device head transform - trans = fit_matched_points(tgt_pts=elp[3:], src_pts=mrk, out='trans') + trans = fit_matched_points(tgt_pts=elp[3:], src_pts=mrk, out="trans") nasion, lpa, rpa = elp[:3] elp = elp[3:] dig_points = _make_dig_points(nasion, lpa, rpa, elp, hsp, dig_ch_pos=eeg) - dev_head_t = Transform('meg', 'head', trans) - - hpi_results = [dict(dig_points=[ - dict(ident=ci, r=r, kind=FIFF.FIFFV_POINT_HPI, - coord_frame=FIFF.FIFFV_COORD_UNKNOWN) - for ci, r in enumerate(mrk)], coord_trans=dev_head_t)] + dev_head_t = Transform("meg", "head", trans) + + hpi_results = [ + dict( + dig_points=[ + dict( + ident=ci, + r=r, + kind=FIFF.FIFFV_POINT_HPI, + coord_frame=FIFF.FIFFV_COORD_UNKNOWN, + ) + for ci, r in enumerate(mrk) + ], + coord_trans=dev_head_t, + ) + ] return dig_points, dev_head_t, hpi_results -def _read_dig_kit(fname, unit='auto'): +def _read_dig_kit(fname, unit="auto"): # Read dig points from a file and return ndarray, using FastSCAN for .txt from ...channels.montage import ( - read_polhemus_fastscan, read_dig_polhemus_isotrak, read_custom_montage, - _check_dig_shape) - - fname = _check_fname( - fname, "read", must_exist=True, name="hsp or elp file" - ) - assert unit in ('auto', 'm', 'mm') - _check_option( - "file extension", fname.suffix, (".hsp", ".elp", ".mat", ".txt") + read_polhemus_fastscan, + read_dig_polhemus_isotrak, + read_custom_montage, + _check_dig_shape, ) + + fname = _check_fname(fname, "read", must_exist=True, name="hsp or elp file") + assert unit in ("auto", "m", "mm") + _check_option("file extension", fname.suffix, (".hsp", ".elp", ".mat", ".txt")) if fname.suffix == ".txt": - unit = 'mm' if unit == 'auto' else unit - out = read_polhemus_fastscan(fname, unit=unit, - on_header_missing='ignore') + unit = "mm" if unit == "auto" else unit + out = read_polhemus_fastscan(fname, unit=unit, on_header_missing="ignore") elif fname.suffix in (".hsp", ".elp"): - unit = 'm' if unit == 'auto' else unit + unit = "m" if unit == "auto" else unit mon = read_dig_polhemus_isotrak(fname, unit=unit) if fname.suffix == ".hsp": - dig = [d['r'] for d in mon.dig - if d['kind'] != FIFF.FIFFV_POINT_CARDINAL] + dig = [d["r"] for d in mon.dig if d["kind"] != FIFF.FIFFV_POINT_CARDINAL] else: - dig = [d['r'] for d in mon.dig] - if dig and \ - mon.dig[0]['kind'] == FIFF.FIFFV_POINT_CARDINAL and \ - mon.dig[0]['ident'] == FIFF.FIFFV_POINT_LPA: + dig = [d["r"] for d in mon.dig] + if ( + dig + and mon.dig[0]["kind"] == FIFF.FIFFV_POINT_CARDINAL + and mon.dig[0]["ident"] == FIFF.FIFFV_POINT_LPA + ): # LPA, Nasion, RPA -> NLR dig[:3] = [dig[1], dig[0], dig[2]] out = np.array(dig, float) else: assert fname.suffix == ".mat" - out = np.array([d['r'] for d in 
read_custom_montage(fname).dig])
+        out = np.array([d["r"] for d in read_custom_montage(fname).dig])
     _check_dig_shape(out)
     return out
diff --git a/mne/io/kit/kit.py b/mne/io/kit/kit.py
index 0a433f6203c..082202707fb 100644
--- a/mne/io/kit/kit.py
+++ b/mne/io/kit/kit.py
@@ -17,8 +17,15 @@ import numpy as np
 from ..pick import pick_types
-from ...utils import (verbose, logger, warn, fill_doc, _check_option,
-                      _stamp_to_dt, _check_fname)
+from ...utils import (
+    verbose,
+    logger,
+    warn,
+    fill_doc,
+    _check_option,
+    _stamp_to_dt,
+    _check_fname,
+)
 from ...transforms import apply_trans, als_ras_trans
 from ..base import BaseRaw
 from ..utils import _mult_cal_one
@@ -30,32 +37,36 @@ from ...event import read_events

-FLOAT64 = '<f8'
-UINT32 = '<u4'
-INT32 = '<i4'
+FLOAT64 = "<f8"
+UINT32 = "<u4"
+INT32 = "<i4"
-        elif stim == '>':
+        elif stim == ">":
             stim = picks
         else:
-            raise ValueError("stim needs to be list of int, '>' or "
-                             "'<', not %r" % str(stim))
+            raise ValueError(
+                "stim needs to be list of int, '>' or " "'<', not %r" % str(stim)
+            )
     else:
         stim = np.asarray(stim, int)
-        if stim.max() >= self._raw_extras[0]['nchan']:
+        if stim.max() >= self._raw_extras[0]["nchan"]:
             raise ValueError(
-                'Got stim=%s, but sqd file only has %i channels' %
-                (stim, self._raw_extras[0]['nchan']))
+                "Got stim=%s, but sqd file only has %i channels"
+                % (stim, self._raw_extras[0]["nchan"])
+            )

         # modify info
-        nchan = self._raw_extras[0]['nchan'] + 1
-        info['chs'].append(dict(
-            cal=KIT.CALIB_FACTOR, logno=nchan, scanno=nchan, range=1.0,
-            unit=FIFF.FIFF_UNIT_NONE, unit_mul=FIFF.FIFF_UNITM_NONE,
-            ch_name='STI 014',
-            coil_type=FIFF.FIFFV_COIL_NONE, loc=np.full(12, np.nan),
-            kind=FIFF.FIFFV_STIM_CH, coord_frame=FIFF.FIFFV_COORD_UNKNOWN))
+        nchan = self._raw_extras[0]["nchan"] + 1
+        info["chs"].append(
+            dict(
+                cal=KIT.CALIB_FACTOR,
+                logno=nchan,
+                scanno=nchan,
+                range=1.0,
+                unit=FIFF.FIFF_UNIT_NONE,
+                unit_mul=FIFF.FIFF_UNITM_NONE,
+                ch_name="STI 014",
+                coil_type=FIFF.FIFFV_COIL_NONE,
+                loc=np.full(12, np.nan),
+                kind=FIFF.FIFFV_STIM_CH,
+                coord_frame=FIFF.FIFFV_COORD_UNKNOWN,
+            )
+        )
         info._update_redundant()

-        self._raw_extras[0]['stim'] = stim
-        self._raw_extras[0]['stim_code'] = stim_code
+        self._raw_extras[0]["stim"] = stim
+        self._raw_extras[0]["stim_code"] = stim_code

     def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
         """Read a chunk of raw data."""
         sqd = self._raw_extras[fi]
-        nchan = sqd['nchan']
+        nchan = sqd["nchan"]
         data_left = (stop - start) * nchan
-        conv_factor = sqd['conv_factor']
+        conv_factor = sqd["conv_factor"]

-        n_bytes = sqd['dtype'].itemsize
+        n_bytes = sqd["dtype"].itemsize
         assert n_bytes in (2, 4)
         # Read up to 100 MB of data at a time.
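        # Illustrative arithmetic (example values only, not taken from any
        # file): with nchan = 192 channels and n_bytes = 2, blk_size below
        # becomes (100000000 // 2 // 192) * 192 = 49999872 values per read
        # (~100 MB of int16 data), rounded down to a whole number of
        # samples so every block starts on a channel boundary.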
blk_size = min(data_left, (100000000 // n_bytes // nchan) * nchan) - with open(self._filenames[fi], 'rb', buffering=0) as fid: + with open(self._filenames[fi], "rb", buffering=0) as fid: # extract data pointer = start * nchan * n_bytes - fid.seek(sqd['dirs'][KIT.DIR_INDEX_RAW_DATA]['offset'] + pointer) - stim = sqd['stim'] + fid.seek(sqd["dirs"][KIT.DIR_INDEX_RAW_DATA]["offset"] + pointer) + stim = sqd["stim"] for blk_start in np.arange(0, data_left, blk_size) // nchan: blk_size = min(blk_size, data_left - blk_start * nchan) - block = np.fromfile(fid, dtype=sqd['dtype'], count=blk_size) - block = block.reshape(nchan, -1, order='F').astype(float) + block = np.fromfile(fid, dtype=sqd["dtype"], count=blk_size) + block = block.reshape(nchan, -1, order="F").astype(float) blk_stop = blk_start + block.shape[1] data_view = data[:, blk_start:blk_stop] block *= conv_factor @@ -245,8 +283,12 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): # Create a synthetic stim channel if stim is not None: stim_ch = _make_stim_channel( - block[stim, :], sqd['slope'], sqd['stimthresh'], - sqd['stim_code'], stim) + block[stim, :], + sqd["slope"], + sqd["stimthresh"], + sqd["stim_code"], + stim, + ) block = np.vstack((block, stim_ch)) _mult_cal_one(data_view, block, idx, cals, mult) @@ -255,25 +297,24 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): def _default_stim_chs(info): """Return default stim channels for SQD files.""" - return pick_types(info, meg=False, ref_meg=False, misc=True, - exclude=[])[:8] + return pick_types(info, meg=False, ref_meg=False, misc=True, exclude=[])[:8] -def _make_stim_channel(trigger_chs, slope, threshold, stim_code, - trigger_values): +def _make_stim_channel(trigger_chs, slope, threshold, stim_code, trigger_values): """Create synthetic stim channel from multiple trigger channels.""" - if slope == '+': + if slope == "+": trig_chs_bin = trigger_chs > threshold - elif slope == '-': + elif slope == "-": trig_chs_bin = trigger_chs < threshold else: raise ValueError("slope needs to be '+' or '-'") # trigger value - if stim_code == 'binary': + if stim_code == "binary": trigger_values = 2 ** np.arange(len(trigger_chs)) - elif stim_code != 'channel': - raise ValueError("stim_code must be 'binary' or 'channel', got %s" % - repr(stim_code)) + elif stim_code != "channel": + raise ValueError( + "stim_code must be 'binary' or 'channel', got %s" % repr(stim_code) + ) trig_chs = trig_chs_bin * trigger_values[:, np.newaxis] return np.array(trig_chs.sum(axis=0), ndmin=2) @@ -322,53 +363,82 @@ class EpochsKIT(BaseEpochs): """ @verbose - def __init__(self, input_fname, events, event_id=None, tmin=0, - baseline=None, reject=None, flat=None, reject_tmin=None, - reject_tmax=None, mrk=None, elp=None, hsp=None, - allow_unknown_format=False, standardize_names=None, - verbose=None): # noqa: D102 - + def __init__( + self, + input_fname, + events, + event_id=None, + tmin=0, + baseline=None, + reject=None, + flat=None, + reject_tmin=None, + reject_tmax=None, + mrk=None, + elp=None, + hsp=None, + allow_unknown_format=False, + standardize_names=None, + verbose=None, + ): # noqa: D102 if isinstance(events, (str, PathLike, Path)): events = read_events(events) input_fname = str( _check_fname(fname=input_fname, must_exist=True, overwrite="read") ) - logger.info('Extracting KIT Parameters from %s...' % input_fname) + logger.info("Extracting KIT Parameters from %s..." 
% input_fname) self.info, kit_info = get_kit_info( - input_fname, allow_unknown_format, standardize_names) + input_fname, allow_unknown_format, standardize_names + ) kit_info.update(input_fname=input_fname) self._raw_extras = [kit_info] self._filenames = [] - if len(events) != self._raw_extras[0]['n_epochs']: - raise ValueError('Event list does not match number of epochs.') + if len(events) != self._raw_extras[0]["n_epochs"]: + raise ValueError("Event list does not match number of epochs.") - if self._raw_extras[0]['acq_type'] == KIT.EPOCHS: - self._raw_extras[0]['data_length'] = KIT.INT + if self._raw_extras[0]["acq_type"] == KIT.EPOCHS: + self._raw_extras[0]["data_length"] = KIT.INT else: - raise TypeError('SQD file contains raw data, not epochs or ' - 'average. Wrong reader.') + raise TypeError( + "SQD file contains raw data, not epochs or " "average. Wrong reader." + ) if event_id is None: # convert to int to make typing-checks happy event_id = {str(e): int(e) for e in np.unique(events[:, 2])} for key, val in event_id.items(): if val not in events[:, 2]: - raise ValueError('No matching events found for %s ' - '(event id %i)' % (key, val)) + raise ValueError( + "No matching events found for %s " "(event id %i)" % (key, val) + ) data = self._read_kit_data() - assert data.shape == (self._raw_extras[0]['n_epochs'], - self.info['nchan'], - self._raw_extras[0]['frame_length']) - tmax = ((data.shape[2] - 1) / self.info['sfreq']) + tmin + assert data.shape == ( + self._raw_extras[0]["n_epochs"], + self.info["nchan"], + self._raw_extras[0]["frame_length"], + ) + tmax = ((data.shape[2] - 1) / self.info["sfreq"]) + tmin super(EpochsKIT, self).__init__( - self.info, data, events, event_id, tmin, tmax, baseline, - reject=reject, flat=flat, reject_tmin=reject_tmin, - reject_tmax=reject_tmax, filename=input_fname, verbose=verbose) + self.info, + data, + events, + event_id, + tmin, + tmax, + baseline, + reject=reject, + flat=flat, + reject_tmin=reject_tmin, + reject_tmax=reject_tmax, + filename=input_fname, + verbose=verbose, + ) self.info = _call_digitization( - info=self.info, mrk=mrk, elp=elp, hsp=hsp, kit_info=kit_info) - logger.info('Ready.') + info=self.info, mrk=mrk, elp=elp, hsp=hsp, kit_info=kit_info + ) + logger.info("Ready.") def _read_kit_data(self): """Read epochs data. @@ -381,19 +451,19 @@ def _read_kit_data(self): returns the time values corresponding to the samples. 
""" info = self._raw_extras[0] - epoch_length = info['frame_length'] - n_epochs = info['n_epochs'] - n_samples = info['n_samples'] - input_fname = info['input_fname'] - dtype = info['dtype'] - nchan = info['nchan'] - - with open(input_fname, 'rb', buffering=0) as fid: - fid.seek(info['dirs'][KIT.DIR_INDEX_RAW_DATA]['offset']) + epoch_length = info["frame_length"] + n_epochs = info["n_epochs"] + n_samples = info["n_samples"] + input_fname = info["input_fname"] + dtype = info["dtype"] + nchan = info["nchan"] + + with open(input_fname, "rb", buffering=0) as fid: + fid.seek(info["dirs"][KIT.DIR_INDEX_RAW_DATA]["offset"]) count = n_samples * nchan data = np.fromfile(fid, dtype=dtype, count=count) data = data.reshape((n_samples, nchan)).T - data = data * info['conv_factor'] + data = data * info["conv_factor"] data = data.reshape((nchan, n_epochs, epoch_length)) data = data.transpose((1, 0, 2)) @@ -401,26 +471,27 @@ def _read_kit_data(self): def _read_dir(fid): - return dict(offset=np.fromfile(fid, UINT32, 1)[0], - size=np.fromfile(fid, INT32, 1)[0], - max_count=np.fromfile(fid, INT32, 1)[0], - count=np.fromfile(fid, INT32, 1)[0]) + return dict( + offset=np.fromfile(fid, UINT32, 1)[0], + size=np.fromfile(fid, INT32, 1)[0], + max_count=np.fromfile(fid, INT32, 1)[0], + count=np.fromfile(fid, INT32, 1)[0], + ) @verbose def _read_dirs(fid, verbose=None): dirs = list() dirs.append(_read_dir(fid)) - for ii in range(dirs[0]['count'] - 1): - logger.debug(f' KIT dir entry {ii} @ {fid.tell()}') + for ii in range(dirs[0]["count"] - 1): + logger.debug(f" KIT dir entry {ii} @ {fid.tell()}") dirs.append(_read_dir(fid)) - assert len(dirs) == dirs[KIT.DIR_INDEX_DIR]['count'] + assert len(dirs) == dirs[KIT.DIR_INDEX_DIR]["count"] return dirs @verbose -def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, - verbose=None): +def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, verbose=None): """Extract all the information from the sqd/con file. Parameters @@ -440,18 +511,18 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, A dict containing all the sqd parameter settings. """ sqd = dict() - sqd['rawfile'] = rawfile + sqd["rawfile"] = rawfile unsupported_format = False - with open(rawfile, 'rb', buffering=0) as fid: # buffering=0 for np bug + with open(rawfile, "rb", buffering=0) as fid: # buffering=0 for np bug # # directories (0) # - sqd['dirs'] = dirs = _read_dirs(fid) + sqd["dirs"] = dirs = _read_dirs(fid) # # system (1) # - fid.seek(dirs[KIT.DIR_INDEX_SYSTEM]['offset']) + fid.seek(dirs[KIT.DIR_INDEX_SYSTEM]["offset"]) # check file format version version, revision = np.fromfile(fid, INT32, 2) if version < 2 or (version == 2 and revision < 3): @@ -463,8 +534,9 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, raise UnsupportedKITFormat( version_string, "SQD file format %s is not officially supported. " - "Set allow_unknown_format=True to load it anyways." % - (version_string,)) + "Set allow_unknown_format=True to load it anyways." 
+                % (version_string,),
+            )

         sysid = np.fromfile(fid, INT32, 1)[0]
         # basic info
@@ -472,7 +544,7 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None,
         # model name
         model_name = _read_name(fid, n=128)
         # channels
-        sqd['nchan'] = channel_count = int(np.fromfile(fid, INT32, 1)[0])
+        sqd["nchan"] = channel_count = int(np.fromfile(fid, INT32, 1)[0])
         comment = _read_name(fid, n=256)
         create_time, last_modified_time = np.fromfile(fid, INT32, 2)
         fid.seek(KIT.INT * 3, SEEK_CUR)  # reserved
@@ -490,12 +562,12 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None,
         else:
             adc_range = np.fromfile(fid, FLOAT64, 1)[0]
         adc_polarity, adc_allocated, adc_stored = np.fromfile(fid, INT32, 3)
-        system_name = system_name.replace('\x00', '')
-        system_name = system_name.strip().replace('\n', '/')
-        model_name = model_name.replace('\x00', '')
-        model_name = model_name.strip().replace('\n', '/')
+        system_name = system_name.replace("\x00", "")
+        system_name = system_name.strip().replace("\n", "/")
+        model_name = model_name.replace("\x00", "")
+        model_name = model_name.strip().replace("\n", "/")

-        full_version = f'V{version:d}R{revision:03d}'
+        full_version = f"V{version:d}R{revision:03d}"
         logger.debug("SQD file basic information:")
         logger.debug("Meg160 version = %s", full_version)
         logger.debug("System ID      = %i", sysid)
@@ -507,37 +579,37 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None,
         logger.debug("FLL type       = %i", fll_type)
         logger.debug("Trigger type   = %i", trigger_type)
         logger.debug("A/D board type = %i", adboard_type)
-        logger.debug("ADC range      = +/-%s[V]", adc_range / 2.)
+        logger.debug("ADC range      = +/-%s[V]", adc_range / 2.0)
         logger.debug("ADC allocate   = %i[bit]", adc_allocated)
         logger.debug("ADC bit        = %i[bit]", adc_stored)
         # MGH description: 'acquisition (megacq) VectorView system at NMR-MGH'
-        description = \
-            f'{system_name} ({sysid}) {full_version} {model_name}'
+        description = f"{system_name} ({sysid}) {full_version} {model_name}"
         assert adc_allocated % 8 == 0
-        sqd['dtype'] = np.dtype(f'<i{adc_allocated // 8}')
+        sqd["dtype"] = np.dtype(f"<i{adc_allocated // 8}")

         # check that we can read this file
         if fll_type not in KIT.FLL_SETTINGS:
             fll_types = sorted(KIT.FLL_SETTINGS.keys())
-            use_fll_type = fll_types[
-                np.searchsorted(fll_types, fll_type) - 1]
-            warn('Unknown site filter settings (FLL) for system '
-                 '"%s" model "%s" (ID %s), will assume FLL %d->%d, check '
-                 'your data for correctness, including channel scales and '
-                 'filter settings!'
-                 % (system_name, model_name, sysid, fll_type, use_fll_type))
+            use_fll_type = fll_types[np.searchsorted(fll_types, fll_type) - 1]
+            warn(
+                "Unknown site filter settings (FLL) for system "
+                '"%s" model "%s" (ID %s), will assume FLL %d->%d, check '
+                "your data for correctness, including channel scales and "
+                "filter settings!"
+                % (system_name, model_name, sysid, fll_type, use_fll_type)
+            )
             fll_type = use_fll_type

         #
         # channel information (4)
         #
         chan_dir = dirs[KIT.DIR_INDEX_CHANNELS]
-        chan_offset, chan_size = chan_dir['offset'], chan_dir['size']
-        sqd['channels'] = channels = []
+        chan_offset, chan_size = chan_dir["offset"], chan_dir["size"]
+        sqd["channels"] = channels = []
         exg_gains = list()
         for i in range(channel_count):
             fid.seek(chan_offset + chan_size * i)
-            channel_type, = np.fromfile(fid, INT32, 1)
+            (channel_type,) = np.fromfile(fid, INT32, 1)
             # System 52 mislabeled reference channels as NULL. This was fixed
             # in system 53; not sure about 51...
             if sysid == 52 and i < 160 and channel_type == KIT.CHANNEL_NULL:
@@ -547,31 +619,36 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None,
             if channel_type not in KIT.CH_TO_FIFF_COIL:
                 raise NotImplementedError(
                     "KIT channel type %i can not be read. Please contact "
-                    "the mne-python developers." % channel_type)
-            channels.append({
-                'type': channel_type,
-                # (x, y, z, theta, phi) for all MEG channels. 
Some channel - # types have additional information which we're not using. - 'loc': np.fromfile(fid, dtype=FLOAT64, count=5), - }) + "the mne-python developers." % channel_type + ) + channels.append( + { + "type": channel_type, + # (x, y, z, theta, phi) for all MEG channels. Some channel + # types have additional information which we're not using. + "loc": np.fromfile(fid, dtype=FLOAT64, count=5), + } + ) if channel_type in KIT.CHANNEL_NAME_NCHAR: fid.seek(16, SEEK_CUR) # misc fields - channels[-1]['name'] = _read_name(fid, channel_type) + channels[-1]["name"] = _read_name(fid, channel_type) elif channel_type in KIT.CHANNELS_MISC: - channel_no, = np.fromfile(fid, INT32, 1) + (channel_no,) = np.fromfile(fid, INT32, 1) fid.seek(4, SEEK_CUR) name = _read_name(fid, channel_type) - channels.append({ - 'type': channel_type, - 'no': channel_no, - 'name': name, - }) + channels.append( + { + "type": channel_type, + "no": channel_no, + "name": name, + } + ) if channel_type in (KIT.CHANNEL_EEG, KIT.CHANNEL_ECG): offset = 6 if channel_type == KIT.CHANNEL_EEG else 8 fid.seek(offset, SEEK_CUR) exg_gains.append(np.fromfile(fid, FLOAT64, 1)[0]) elif channel_type == KIT.CHANNEL_NULL: - channels.append({'type': channel_type}) + channels.append({"type": channel_type}) else: raise OSError("Unknown KIT channel type: %i" % channel_type) exg_gains = np.array(exg_gains) @@ -583,7 +660,7 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, # only sensor channels requires gain. the additional misc channels # (trigger channels, audio and voice channels) are passed # through unaffected - fid.seek(dirs[KIT.DIR_INDEX_CALIBRATION]['offset']) + fid.seek(dirs[KIT.DIR_INDEX_CALIBRATION]["offset"]) # (offset [Volt], gain [Tesla/Volt]) for each channel sensitivity = np.fromfile(fid, dtype=FLOAT64, count=channel_count * 2) sensitivity.shape = (channel_count, 2) @@ -593,14 +670,14 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, # # amplifier gain (7) # - fid.seek(dirs[KIT.DIR_INDEX_AMP_FILTER]['offset']) + fid.seek(dirs[KIT.DIR_INDEX_AMP_FILTER]["offset"]) amp_data = np.fromfile(fid, INT32, 1)[0] if fll_type >= 100: # Kapper Type # gain: mask bit gain1 = (amp_data & 0x00007000) >> 12 gain2 = (amp_data & 0x70000000) >> 28 gain3 = (amp_data & 0x07000000) >> 24 - amp_gain = (KIT.GAINS[gain1] * KIT.GAINS[gain2] * KIT.GAINS[gain3]) + amp_gain = KIT.GAINS[gain1] * KIT.GAINS[gain2] * KIT.GAINS[gain3] # filter settings hpf = (amp_data & 0x00000700) >> 8 lpf = (amp_data & 0x00070000) >> 16 @@ -613,34 +690,36 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, # filter settings hpf = (amp_data & 0x007) >> 4 lpf = (amp_data & 0x0700) >> 8 - bef = (amp_data & 0xc000) >> 14 + bef = (amp_data & 0xC000) >> 14 hpf_options, lpf_options, bef_options = KIT.FLL_SETTINGS[fll_type] - sqd['highpass'] = KIT.HPFS[hpf_options][hpf] - sqd['lowpass'] = KIT.LPFS[lpf_options][lpf] - sqd['notch'] = KIT.BEFS[bef_options][bef] + sqd["highpass"] = KIT.HPFS[hpf_options][hpf] + sqd["lowpass"] = KIT.LPFS[lpf_options][lpf] + sqd["notch"] = KIT.BEFS[bef_options][bef] # # Acquisition Parameters (8) # - fid.seek(dirs[KIT.DIR_INDEX_ACQ_COND]['offset']) - sqd['acq_type'], = acq_type, = np.fromfile(fid, INT32, 1) - sqd['sfreq'], = np.fromfile(fid, FLOAT64, 1) + fid.seek(dirs[KIT.DIR_INDEX_ACQ_COND]["offset"]) + (sqd["acq_type"],) = (acq_type,) = np.fromfile(fid, INT32, 1) + (sqd["sfreq"],) = np.fromfile(fid, FLOAT64, 1) if acq_type == KIT.CONTINUOUS: # samples_count, = np.fromfile(fid, INT32, 1) 
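            # The commented-out read above documents the field being skipped
            # here: the sample count occupies one KIT.INT (4 bytes), so
            # seeking past it leaves the file pointer at the n_samples field.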
fid.seek(KIT.INT, SEEK_CUR) - sqd['n_samples'], = np.fromfile(fid, INT32, 1) + (sqd["n_samples"],) = np.fromfile(fid, INT32, 1) elif acq_type == KIT.EVOKED or acq_type == KIT.EPOCHS: - sqd['frame_length'], = np.fromfile(fid, INT32, 1) - sqd['pretrigger_length'], = np.fromfile(fid, INT32, 1) - sqd['average_count'], = np.fromfile(fid, INT32, 1) - sqd['n_epochs'], = np.fromfile(fid, INT32, 1) + (sqd["frame_length"],) = np.fromfile(fid, INT32, 1) + (sqd["pretrigger_length"],) = np.fromfile(fid, INT32, 1) + (sqd["average_count"],) = np.fromfile(fid, INT32, 1) + (sqd["n_epochs"],) = np.fromfile(fid, INT32, 1) if acq_type == KIT.EVOKED: - sqd['n_samples'] = sqd['frame_length'] + sqd["n_samples"] = sqd["frame_length"] else: - sqd['n_samples'] = sqd['frame_length'] * sqd['n_epochs'] + sqd["n_samples"] = sqd["frame_length"] * sqd["n_epochs"] else: - raise OSError("Invalid acquisition type: %i. Your file is neither " - "continuous nor epoched data." % (acq_type,)) + raise OSError( + "Invalid acquisition type: %i. Your file is neither " + "continuous nor epoched data." % (acq_type,) + ) # # digitization information (12 and 26) @@ -649,10 +728,10 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, cor_dir = dirs[KIT.DIR_INDEX_COREG] dig = dict() hsp = list() - if dig_dir['count'] > 0 and cor_dir['count'] > 0: + if dig_dir["count"] > 0 and cor_dir["count"] > 0: # directories (0) - fid.seek(dig_dir['offset']) - for _ in range(dig_dir['count']): + fid.seek(dig_dir["offset"]) + for _ in range(dig_dir["count"]): name = _read_name(fid, n=8).strip() # Sometimes there are mismatches (e.g., AFz vs AFZ) between # the channel name and its digitized, name, so let's be case @@ -668,38 +747,48 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, # nasion, lpa, rpa, HPI in native space elp = [] for key in ( - 'fidnz', 'fidt9', 'fidt10', - 'hpi_1', 'hpi_2', 'hpi_3', 'hpi_4', 'hpi_5'): + "fidnz", + "fidt9", + "fidt10", + "hpi_1", + "hpi_2", + "hpi_3", + "hpi_4", + "hpi_5", + ): if key in dig and np.isfinite(dig[key]).all(): elp.append(dig.pop(key)) elp = np.array(elp) hsp = np.array(hsp, float).reshape(-1, 3) if elp.shape not in ((6, 3), (7, 3), (8, 3)): - raise RuntimeError( - f'Fewer than 3 HPI coils found, got {len(elp) - 3}') + raise RuntimeError(f"Fewer than 3 HPI coils found, got {len(elp) - 3}") # coregistration - fid.seek(cor_dir['offset']) + fid.seek(cor_dir["offset"]) mrk = np.zeros((elp.shape[0] - 3, 3)) meg_done = [True] * 5 - for _ in range(cor_dir['count']): + for _ in range(cor_dir["count"]): done = np.fromfile(fid, INT32, 1)[0] - fid.seek(16 * KIT.DOUBLE + # meg_to_mri - 16 * KIT.DOUBLE, # mri_to_meg - SEEK_CUR) + fid.seek( + 16 * KIT.DOUBLE + 16 * KIT.DOUBLE, # meg_to_mri # mri_to_meg + SEEK_CUR, + ) marker_count = np.fromfile(fid, INT32, 1)[0] if not done: continue assert marker_count >= len(mrk) for mi in range(len(mrk)): - mri_type, meg_type, mri_done, this_meg_done = \ - np.fromfile(fid, INT32, 4) + mri_type, meg_type, mri_done, this_meg_done = np.fromfile( + fid, INT32, 4 + ) meg_done[mi] = bool(this_meg_done) fid.seek(3 * KIT.DOUBLE, SEEK_CUR) # mri_pos mrk[mi] = np.fromfile(fid, FLOAT64, 3) fid.seek(256, SEEK_CUR) # marker_file (char) if not all(meg_done): - logger.info(f'Keeping {sum(meg_done)}/{len(meg_done)} HPI ' - 'coils that were digitized') + logger.info( + f"Keeping {sum(meg_done)}/{len(meg_done)} HPI " + "coils that were digitized" + ) elp = elp[[True] * 3 + meg_done] mrk = mrk[meg_done] sqd.update(hsp=hsp, elp=elp, mrk=mrk) @@ -707,11 
+796,10 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, # precompute conversion factor for reading data if unsupported_format: if sysid not in LEGACY_AMP_PARAMS: - raise OSError("Legacy parameters for system ID %i unavailable" % - (sysid,)) + raise OSError("Legacy parameters for system ID %i unavailable" % (sysid,)) adc_range, adc_stored = LEGACY_AMP_PARAMS[sysid] - is_meg = np.array([ch['type'] in KIT.CHANNELS_MEG for ch in channels]) - ad_to_volt = adc_range / (2. ** adc_stored) + is_meg = np.array([ch["type"] in KIT.CHANNELS_MEG for ch in channels]) + ad_to_volt = adc_range / (2.0**adc_stored) ad_to_tesla = ad_to_volt / amp_gain * channel_gain conv_factor = np.where(is_meg, ad_to_tesla, ad_to_volt) # XXX this is a bit of a hack. Should probably do this more cleanly at @@ -719,33 +807,35 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, # the test files with known amplitudes. The conv_factors need to be # replaced by these values otherwise we're off by a factor off 5000.0 # for the EEG data. - is_exg = [ch['type'] in (KIT.CHANNEL_EEG, KIT.CHANNEL_ECG) - for ch in channels] - exg_gains /= 2. ** (adc_stored - 14) + is_exg = [ch["type"] in (KIT.CHANNEL_EEG, KIT.CHANNEL_ECG) for ch in channels] + exg_gains /= 2.0 ** (adc_stored - 14) exg_gains[exg_gains == 0] = ad_to_volt conv_factor[is_exg] = exg_gains - sqd['conv_factor'] = conv_factor[:, np.newaxis] + sqd["conv_factor"] = conv_factor[:, np.newaxis] # Create raw.info dict for raw fif object with SQD data - info = _empty_info(float(sqd['sfreq'])) - info.update(meas_date=_stamp_to_dt((create_time, 0)), - lowpass=sqd['lowpass'], - highpass=sqd['highpass'], kit_system_id=sysid, - description=description) + info = _empty_info(float(sqd["sfreq"])) + info.update( + meas_date=_stamp_to_dt((create_time, 0)), + lowpass=sqd["lowpass"], + highpass=sqd["highpass"], + kit_system_id=sysid, + description=description, + ) # Creates a list of dicts of meg channels for raw.info - logger.info('Setting channel info structure...') - info['chs'] = fiff_channels = [] + logger.info("Setting channel info structure...") + info["chs"] = fiff_channels = [] channel_index = defaultdict(lambda: 0) - sqd['eeg_dig'] = OrderedDict() + sqd["eeg_dig"] = OrderedDict() for idx, ch in enumerate(channels, 1): - if ch['type'] in KIT.CHANNELS_MEG: - ch_name = ch.get('name', '') - if ch_name == '' or standardize_names: - ch_name = 'MEG %03d' % idx + if ch["type"] in KIT.CHANNELS_MEG: + ch_name = ch.get("name", "") + if ch_name == "" or standardize_names: + ch_name = "MEG %03d" % idx # create three orthogonal vector # ch_angles[0]: theta, ch_angles[1]: phi - theta, phi = np.radians(ch['loc'][3:]) + theta, phi = np.radians(ch["loc"][3:]) x = sin(theta) * cos(phi) y = sin(theta) * sin(phi) z = cos(theta) @@ -765,29 +855,38 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, vec_x /= np.linalg.norm(vec_x) vec_y = np.cross(vec_z, vec_x) # transform to Neuromag like coordinate space - vecs = np.vstack((ch['loc'][:3], vec_x, vec_y, vec_z)) + vecs = np.vstack((ch["loc"][:3], vec_x, vec_y, vec_z)) vecs = apply_trans(als_ras_trans, vecs) unit = FIFF.FIFF_UNIT_T loc = vecs.ravel() else: - ch_type_label = KIT.CH_LABEL[ch['type']] + ch_type_label = KIT.CH_LABEL[ch["type"]] channel_index[ch_type_label] += 1 ch_type_index = channel_index[ch_type_label] - ch_name = ch.get('name', '') + ch_name = ch.get("name", "") eeg_name = ch_name.lower() # some files have all EEG labeled as EEG - if ch_name in ('', 'EEG') or 
standardize_names: - ch_name = '%s %03i' % (ch_type_label, ch_type_index) + if ch_name in ("", "EEG") or standardize_names: + ch_name = "%s %03i" % (ch_type_label, ch_type_index) unit = FIFF.FIFF_UNIT_V loc = np.zeros(12) if eeg_name and eeg_name in dig: - loc[:3] = sqd['eeg_dig'][eeg_name] = dig[eeg_name] - fiff_channels.append(dict( - cal=KIT.CALIB_FACTOR, logno=idx, scanno=idx, range=KIT.RANGE, - unit=unit, unit_mul=KIT.UNIT_MUL, ch_name=ch_name, - coord_frame=FIFF.FIFFV_COORD_DEVICE, - coil_type=KIT.CH_TO_FIFF_COIL[ch['type']], - kind=KIT.CH_TO_FIFF_KIND[ch['type']], loc=loc)) + loc[:3] = sqd["eeg_dig"][eeg_name] = dig[eeg_name] + fiff_channels.append( + dict( + cal=KIT.CALIB_FACTOR, + logno=idx, + scanno=idx, + range=KIT.RANGE, + unit=unit, + unit_mul=KIT.UNIT_MUL, + ch_name=ch_name, + coord_frame=FIFF.FIFFV_COORD_DEVICE, + coil_type=KIT.CH_TO_FIFF_COIL[ch["type"]], + kind=KIT.CH_TO_FIFF_KIND[ch["type"]], + loc=loc, + ) + ) info._unlocked = False info._update_redundant() return info, sqd @@ -795,14 +894,24 @@ def get_kit_info(rawfile, allow_unknown_format, standardize_names=None, def _read_name(fid, ch_type=None, n=None): n = n if ch_type is None else KIT.CHANNEL_NAME_NCHAR[ch_type] - return fid.read(n).split(b'\x00')[0].decode('utf-8') + return fid.read(n).split(b"\x00")[0].decode("utf-8") @fill_doc -def read_raw_kit(input_fname, mrk=None, elp=None, hsp=None, stim='>', - slope='-', stimthresh=1, preload=False, stim_code='binary', - allow_unknown_format=False, standardize_names=False, - verbose=None): +def read_raw_kit( + input_fname, + mrk=None, + elp=None, + hsp=None, + stim=">", + slope="-", + stimthresh=1, + preload=False, + stim_code="binary", + allow_unknown_format=False, + standardize_names=False, + verbose=None, +): r"""Reader function for Ricoh/KIT conversion to FIF. Parameters @@ -843,17 +952,34 @@ def read_raw_kit(input_fname, mrk=None, elp=None, hsp=None, stim='>', If ``mrk``\, ``hsp`` or ``elp`` are :term:`array_like` inputs, then the numbers in xyz coordinates should be in units of meters. """ - return RawKIT(input_fname=input_fname, mrk=mrk, elp=elp, hsp=hsp, - stim=stim, slope=slope, stimthresh=stimthresh, - preload=preload, stim_code=stim_code, - allow_unknown_format=allow_unknown_format, - standardize_names=standardize_names, verbose=verbose) + return RawKIT( + input_fname=input_fname, + mrk=mrk, + elp=elp, + hsp=hsp, + stim=stim, + slope=slope, + stimthresh=stimthresh, + preload=preload, + stim_code=stim_code, + allow_unknown_format=allow_unknown_format, + standardize_names=standardize_names, + verbose=verbose, + ) @fill_doc -def read_epochs_kit(input_fname, events, event_id=None, mrk=None, elp=None, - hsp=None, allow_unknown_format=False, - standardize_names=False, verbose=None): +def read_epochs_kit( + input_fname, + events, + event_id=None, + mrk=None, + elp=None, + hsp=None, + allow_unknown_format=False, + standardize_names=False, + verbose=None, +): """Reader function for Ricoh/KIT epochs files. Parameters @@ -886,9 +1012,15 @@ def read_epochs_kit(input_fname, events, event_id=None, mrk=None, elp=None, ----- .. 
versionadded:: 0.9.0 """ - epochs = EpochsKIT(input_fname=input_fname, events=events, - event_id=event_id, mrk=mrk, elp=elp, hsp=hsp, - allow_unknown_format=allow_unknown_format, - standardize_names=standardize_names, - verbose=verbose) + epochs = EpochsKIT( + input_fname=input_fname, + events=events, + event_id=event_id, + mrk=mrk, + elp=elp, + hsp=hsp, + allow_unknown_format=allow_unknown_format, + standardize_names=standardize_names, + verbose=verbose, + ) return epochs diff --git a/mne/io/kit/tests/test_coreg.py b/mne/io/kit/tests/test_coreg.py index 7c30a401507..f36a9a2be63 100644 --- a/mne/io/kit/tests/test_coreg.py +++ b/mne/io/kit/tests/test_coreg.py @@ -26,11 +26,11 @@ def test_io_mrk(tmp_path): # pickle fname = tmp_path / "mrk.pickled" - with open(fname, 'wb') as fid: + with open(fname, "wb") as fid: pickle.dump(dict(mrk=pts), fid) pts_2 = read_mrk(fname) assert_array_equal(pts_2, pts, "pickle mrk") - with open(fname, 'wb') as fid: + with open(fname, "wb") as fid: pickle.dump(dict(), fid) pytest.raises(ValueError, read_mrk, fname) diff --git a/mne/io/kit/tests/test_kit.py b/mne/io/kit/tests/test_kit.py index 696d10a83da..77a779a98ca 100644 --- a/mne/io/kit/tests/test_kit.py +++ b/mne/io/kit/tests/test_kit.py @@ -5,8 +5,12 @@ from pathlib import Path import numpy as np -from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_equal, assert_allclose) +from numpy.testing import ( + assert_array_almost_equal, + assert_array_equal, + assert_equal, + assert_allclose, +) import pytest from scipy import linalg import scipy.io @@ -41,9 +45,7 @@ sqd_as_path = data_path / "KIT" / "test_as-raw.con" yokogawa_path = data_path / "KIT" / "ArtificalSignalData_Yokogawa_1khz.con" ricoh_path = data_path / "KIT" / "ArtificalSignalData_RICOH_1khz.con" -ricoh_systems_paths = [ - data_path / "KIT" / "Example_PQA160C_1001-export_anonymyze.con" -] +ricoh_systems_paths = [data_path / "KIT" / "Example_PQA160C_1001-export_anonymyze.con"] ricoh_systems_paths += [ data_path / "KIT" / "Example_RICOH160-1_10020-export_anonymyze.con" ] @@ -62,95 +64,134 @@ def test_data(tmp_path): pytest.raises(TypeError, read_raw_kit, epochs_path) pytest.raises(TypeError, read_epochs_kit, sqd_path) pytest.raises(ValueError, read_raw_kit, sqd_path, mrk_path, elp_txt_path) - pytest.raises(ValueError, read_raw_kit, sqd_path, None, None, None, - list(range(200, 190, -1))) - pytest.raises(ValueError, read_raw_kit, sqd_path, None, None, None, - list(range(167, 159, -1)), '*', 1, True) + pytest.raises( + ValueError, read_raw_kit, sqd_path, None, None, None, list(range(200, 190, -1)) + ) + pytest.raises( + ValueError, + read_raw_kit, + sqd_path, + None, + None, + None, + list(range(167, 159, -1)), + "*", + 1, + True, + ) # check functionality - raw_mrk = read_raw_kit(sqd_path, [mrk2_path, mrk3_path], elp_txt_path, - hsp_txt_path) + raw_mrk = read_raw_kit(sqd_path, [mrk2_path, mrk3_path], elp_txt_path, hsp_txt_path) assert ( - raw_mrk.info['description'] == 'NYU 160ch System since Jan24 2009 (34) V2R004 EQ1160C' # noqa: E501 + raw_mrk.info["description"] + == "NYU 160ch System since Jan24 2009 (34) V2R004 EQ1160C" # noqa: E501 + ) + raw_py = _test_raw_reader( + read_raw_kit, + input_fname=sqd_path, + mrk=mrk_path, + elp=elp_txt_path, + hsp=hsp_txt_path, + stim=list(range(167, 159, -1)), + slope="+", + stimthresh=1, ) - raw_py = _test_raw_reader(read_raw_kit, input_fname=sqd_path, mrk=mrk_path, - elp=elp_txt_path, hsp=hsp_txt_path, - stim=list(range(167, 159, -1)), slope='+', - stimthresh=1) - assert 'RawKIT' in 
repr(raw_py) - assert_equal(raw_mrk.info['kit_system_id'], KIT.SYSTEM_NYU_2010) + assert "RawKIT" in repr(raw_py) + assert_equal(raw_mrk.info["kit_system_id"], KIT.SYSTEM_NYU_2010) # check number/kind of channels - assert_equal(len(raw_py.info['chs']), 193) - kit_channels = (('kind', {FIFF.FIFFV_MEG_CH: 157, FIFF.FIFFV_REF_MEG_CH: 3, - FIFF.FIFFV_MISC_CH: 32, FIFF.FIFFV_STIM_CH: 1}), - ('coil_type', {FIFF.FIFFV_COIL_KIT_GRAD: 157, - FIFF.FIFFV_COIL_KIT_REF_MAG: 3, - FIFF.FIFFV_COIL_NONE: 33})) + assert_equal(len(raw_py.info["chs"]), 193) + kit_channels = ( + ( + "kind", + { + FIFF.FIFFV_MEG_CH: 157, + FIFF.FIFFV_REF_MEG_CH: 3, + FIFF.FIFFV_MISC_CH: 32, + FIFF.FIFFV_STIM_CH: 1, + }, + ), + ( + "coil_type", + { + FIFF.FIFFV_COIL_KIT_GRAD: 157, + FIFF.FIFFV_COIL_KIT_REF_MAG: 3, + FIFF.FIFFV_COIL_NONE: 33, + }, + ), + ) for label, target in kit_channels: - actual = {id_: sum(ch[label] == id_ for ch in raw_py.info['chs']) for - id_ in target.keys()} + actual = { + id_: sum(ch[label] == id_ for ch in raw_py.info["chs"]) + for id_ in target.keys() + } assert_equal(actual, target) # Test stim channel - raw_stim = read_raw_kit(sqd_path, mrk_path, elp_txt_path, hsp_txt_path, - stim='<', preload=False) + raw_stim = read_raw_kit( + sqd_path, mrk_path, elp_txt_path, hsp_txt_path, stim="<", preload=False + ) for raw in [raw_py, raw_stim, raw_mrk]: - stim_pick = pick_types(raw.info, meg=False, ref_meg=False, - stim=True, exclude='bads') + stim_pick = pick_types( + raw.info, meg=False, ref_meg=False, stim=True, exclude="bads" + ) stim1, _ = raw[stim_pick] stim2 = np.array(raw.read_stim_ch(), ndmin=2) assert_array_equal(stim1, stim2) # Binary file only stores the sensor channels - py_picks = pick_types(raw_py.info, meg=True, exclude='bads') + py_picks = pick_types(raw_py.info, meg=True, exclude="bads") raw_bin = data_dir / "test_bin_raw.fif" raw_bin = read_raw_fif(raw_bin, preload=True) - bin_picks = pick_types(raw_bin.info, meg=True, stim=True, exclude='bads') + bin_picks = pick_types(raw_bin.info, meg=True, stim=True, exclude="bads") data_bin, _ = raw_bin[bin_picks] data_py, _ = raw_py[py_picks] # this .mat was generated using the Yokogawa MEG Reader data_Ykgw = data_dir / "test_Ykgw.mat" - data_Ykgw = scipy.io.loadmat(data_Ykgw)['data'] + data_Ykgw = scipy.io.loadmat(data_Ykgw)["data"] data_Ykgw = data_Ykgw[py_picks] assert_array_almost_equal(data_py, data_Ykgw) - py_picks = pick_types(raw_py.info, meg=True, stim=True, ref_meg=False, - exclude='bads') + py_picks = pick_types( + raw_py.info, meg=True, stim=True, ref_meg=False, exclude="bads" + ) data_py, _ = raw_py[py_picks] assert_array_almost_equal(data_py, data_bin) # KIT-UMD data - _test_raw_reader(read_raw_kit, input_fname=sqd_umd_path, test_rank='less') + _test_raw_reader(read_raw_kit, input_fname=sqd_umd_path, test_rank="less") raw = read_raw_kit(sqd_umd_path) assert ( - raw.info['description'] == 'University of Maryland/Kanazawa Institute of Technology/160-channel MEG System (53) V2R004 PQ1160R' # noqa: E501 + raw.info["description"] + == "University of Maryland/Kanazawa Institute of Technology/160-channel MEG System (53) V2R004 PQ1160R" # noqa: E501 ) - assert_equal(raw.info['kit_system_id'], KIT.SYSTEM_UMD_2014_12) + assert_equal(raw.info["kit_system_id"], KIT.SYSTEM_UMD_2014_12) # check number/kind of channels - assert_equal(len(raw.info['chs']), 193) + assert_equal(len(raw.info["chs"]), 193) for label, target in kit_channels: - actual = {id_: sum(ch[label] == id_ for ch in raw.info['chs']) for - id_ in target.keys()} + actual = { + id_: 
sum(ch[label] == id_ for ch in raw.info["chs"])
+        for id_ in target.keys()
+    }
     assert_equal(actual, target)

     # KIT Academia Sinica
-    raw = read_raw_kit(sqd_as_path, slope='+')
+    raw = read_raw_kit(sqd_as_path, slope="+")
     assert (
-        raw.info['description'] == 'Academia Sinica/Institute of Linguistics//Magnetoencephalograph System (261) V2R004 PQ1160R-N2'  # noqa: E501
+        raw.info["description"]
+        == "Academia Sinica/Institute of Linguistics//Magnetoencephalograph System (261) V2R004 PQ1160R-N2"  # noqa: E501
     )
-    assert_equal(raw.info['kit_system_id'], KIT.SYSTEM_AS_2008)
-    assert_equal(raw.info['chs'][100]['ch_name'], 'MEG 101')
-    assert_equal(raw.info['chs'][100]['kind'], FIFF.FIFFV_MEG_CH)
-    assert_equal(raw.info['chs'][100]['coil_type'], FIFF.FIFFV_COIL_KIT_GRAD)
-    assert_equal(raw.info['chs'][157]['ch_name'], 'MEG 158')
-    assert_equal(raw.info['chs'][157]['kind'], FIFF.FIFFV_REF_MEG_CH)
-    assert_equal(raw.info['chs'][157]['coil_type'],
-                 FIFF.FIFFV_COIL_KIT_REF_MAG)
-    assert_equal(raw.info['chs'][160]['ch_name'], 'EEG 001')
-    assert_equal(raw.info['chs'][160]['kind'], FIFF.FIFFV_EEG_CH)
-    assert_equal(raw.info['chs'][160]['coil_type'], FIFF.FIFFV_COIL_EEG)
+    assert_equal(raw.info["kit_system_id"], KIT.SYSTEM_AS_2008)
+    assert_equal(raw.info["chs"][100]["ch_name"], "MEG 101")
+    assert_equal(raw.info["chs"][100]["kind"], FIFF.FIFFV_MEG_CH)
+    assert_equal(raw.info["chs"][100]["coil_type"], FIFF.FIFFV_COIL_KIT_GRAD)
+    assert_equal(raw.info["chs"][157]["ch_name"], "MEG 158")
+    assert_equal(raw.info["chs"][157]["kind"], FIFF.FIFFV_REF_MEG_CH)
+    assert_equal(raw.info["chs"][157]["coil_type"], FIFF.FIFFV_COIL_KIT_REF_MAG)
+    assert_equal(raw.info["chs"][160]["ch_name"], "EEG 001")
+    assert_equal(raw.info["chs"][160]["kind"], FIFF.FIFFV_EEG_CH)
+    assert_equal(raw.info["chs"][160]["coil_type"], FIFF.FIFFV_COIL_EEG)

     assert_array_equal(find_events(raw), [[91, 0, 2]])
@@ -159,21 +200,21 @@ def test_unknown_format(tmp_path):
     """Test our warning about an unknown format."""
     fname = tmp_path / ricoh_path.name
     _, kit_info = get_kit_info(ricoh_path, allow_unknown_format=False)
-    n_before = kit_info['dirs'][KIT.DIR_INDEX_SYSTEM]['offset']
-    with open(fname, 'wb') as fout:
-        with open(ricoh_path, 'rb') as fin:
+    n_before = kit_info["dirs"][KIT.DIR_INDEX_SYSTEM]["offset"]
+    with open(fname, "wb") as fout:
+        with open(ricoh_path, "rb") as fin:
             fout.write(fin.read(n_before))
-            version, revision = np.fromfile(fin, '<i4', 2)
+            version, revision = np.fromfile(fin, "<i4", 2)
             assert version > 2  # good
             version = 1  # bad
-            np.array([version, revision], '<i4').tofile(fout)
+            np.array([version, revision], "<i4").tofile(fout)
             fout.write(fin.read())

    assert len(hsp_dec) > 5000

    # should have similar size, distance from center
-    dist = np.sqrt(np.sum((hsp_m - np.mean(hsp_m, axis=0))**2, axis=1))
-    dist_dec = np.sqrt(np.sum((hsp_dec - np.mean(hsp_dec, axis=0))**2, axis=1))
+    dist = np.sqrt(np.sum((hsp_m - np.mean(hsp_m, axis=0)) ** 2, axis=1))
+    dist_dec = np.sqrt(np.sum((hsp_dec - np.mean(hsp_dec, axis=0)) ** 2, axis=1))
     hsp_rad = np.mean(dist)
     hsp_dec_rad = np.mean(dist_dec)
     assert_array_almost_equal(hsp_rad, hsp_dec_rad, decimal=3)


 @requires_testing_data
-@pytest.mark.parametrize('fname, desc, system_id', [
-    (ricoh_systems_paths[0],
-     'Meg160/Analysis (1001) V2R004 PQA160C', 1001),
-    (ricoh_systems_paths[1],
-     'RICOH MEG System (10020) V3R000 RICOH160-1', 10020),
-    (ricoh_systems_paths[2],
-     'RICOH MEG System (10021) V3R000 RICOH160-1', 10021),
-    (ricoh_systems_paths[3],
-     'Yokogawa Electric Corporation/MEG device for infants/151-channel MEG '
-     'System (903) V2R004 PQ1151R', 903),
-])
+@pytest.mark.parametrize(
+    "fname, desc, system_id",
+    [
+        (ricoh_systems_paths[0], "Meg160/Analysis (1001) V2R004 PQA160C", 
1001), + (ricoh_systems_paths[1], "RICOH MEG System (10020) V3R000 RICOH160-1", 10020), + (ricoh_systems_paths[2], "RICOH MEG System (10021) V3R000 RICOH160-1", 10021), + ( + ricoh_systems_paths[3], + "Yokogawa Electric Corporation/MEG device for infants/151-channel MEG " + "System (903) V2R004 PQ1151R", + 903, + ), + ], +) def test_raw_system_id(fname, desc, system_id): """Test reading basics and system IDs.""" raw = _test_raw_reader(read_raw_kit, input_fname=fname) - assert raw.info['description'] == desc - assert raw.info['kit_system_id'] == system_id + assert raw.info["description"] == desc + assert raw.info["kit_system_id"] == system_id @requires_testing_data @@ -378,14 +438,15 @@ def test_berlin(): # gh-8535 raw = read_raw_kit(berlin_path) assert ( - raw.info['description'] == 'Physikalisch Technische Bundesanstalt, Berlin/128-channel MEG System (124) V2R004 PQ1128R-N2' # noqa: E501 + raw.info["description"] + == "Physikalisch Technische Bundesanstalt, Berlin/128-channel MEG System (124) V2R004 PQ1128R-N2" # noqa: E501 ) - assert raw.info['kit_system_id'] == 124 - assert raw.info['highpass'] == 0. - assert raw.info['lowpass'] == 200. - assert raw.info['sfreq'] == 500. - n = int(round(28.77 * raw.info['sfreq'])) - meg = raw.get_data('MEG 003', n, n + 1)[0, 0] + assert raw.info["kit_system_id"] == 124 + assert raw.info["highpass"] == 0.0 + assert raw.info["lowpass"] == 200.0 + assert raw.info["sfreq"] == 500.0 + n = int(round(28.77 * raw.info["sfreq"])) + meg = raw.get_data("MEG 003", n, n + 1)[0, 0] assert_allclose(meg, -8.89e-12, rtol=1e-3) - eeg = raw.get_data('E14', n, n + 1)[0, 0] + eeg = raw.get_data("E14", n, n + 1)[0, 0] assert_allclose(eeg, -2.55, rtol=1e-3) diff --git a/mne/io/matrix.py b/mne/io/matrix.py index 4da12b8506f..3699278d2de 100644 --- a/mne/io/matrix.py +++ b/mne/io/matrix.py @@ -5,19 +5,24 @@ from .constants import FIFF from .tag import find_tag, has_tag -from .write import (write_int, start_block, end_block, write_float_matrix, - write_name_list) +from .write import ( + write_int, + start_block, + end_block, + write_float_matrix, + write_name_list, +) from ..utils import logger def _transpose_named_matrix(mat): """Transpose mat inplace (no copy).""" - mat['nrow'], mat['ncol'] = mat['ncol'], mat['nrow'] - mat['row_names'], mat['col_names'] = mat['col_names'], mat['row_names'] - mat['data'] = mat['data'].T + mat["nrow"], mat["ncol"] = mat["ncol"], mat["nrow"] + mat["row_names"], mat["col_names"] = mat["col_names"], mat["row_names"] + mat["data"] = mat["data"].T -def _read_named_matrix(fid, node, matkind, indent=' ', transpose=False): +def _read_named_matrix(fid, node, matkind, indent=" ", transpose=False): """Read named matrix from the given node. 
Parameters @@ -38,48 +43,53 @@ def _read_named_matrix(fid, node, matkind, indent=' ', transpose=False): The matrix data """ # Descend one level if necessary - if node['block'] != FIFF.FIFFB_MNE_NAMED_MATRIX: - for k in range(node['nchild']): - if node['children'][k]['block'] == FIFF.FIFFB_MNE_NAMED_MATRIX: - if has_tag(node['children'][k], matkind): - node = node['children'][k] + if node["block"] != FIFF.FIFFB_MNE_NAMED_MATRIX: + for k in range(node["nchild"]): + if node["children"][k]["block"] == FIFF.FIFFB_MNE_NAMED_MATRIX: + if has_tag(node["children"][k], matkind): + node = node["children"][k] break else: - logger.info(indent + 'Desired named matrix (kind = %d) not ' - 'available' % matkind) + logger.info( + indent + "Desired named matrix (kind = %d) not " "available" % matkind + ) return None else: if not has_tag(node, matkind): - logger.info(indent + 'Desired named matrix (kind = %d) not ' - 'available' % matkind) + logger.info( + indent + "Desired named matrix (kind = %d) not " "available" % matkind + ) return None # Read everything we need tag = find_tag(fid, node, matkind) if tag is None: - raise ValueError('Matrix data missing') + raise ValueError("Matrix data missing") else: data = tag.data nrow, ncol = data.shape tag = find_tag(fid, node, FIFF.FIFF_MNE_NROW) if tag is not None and tag.data != nrow: - raise ValueError('Number of rows in matrix data and FIFF_MNE_NROW ' - 'tag do not match') + raise ValueError( + "Number of rows in matrix data and FIFF_MNE_NROW " "tag do not match" + ) tag = find_tag(fid, node, FIFF.FIFF_MNE_NCOL) if tag is not None and tag.data != ncol: - raise ValueError('Number of columns in matrix data and ' - 'FIFF_MNE_NCOL tag do not match') + raise ValueError( + "Number of columns in matrix data and " "FIFF_MNE_NCOL tag do not match" + ) tag = find_tag(fid, node, FIFF.FIFF_MNE_ROW_NAMES) - row_names = tag.data.split(':') if tag is not None else [] + row_names = tag.data.split(":") if tag is not None else [] tag = find_tag(fid, node, FIFF.FIFF_MNE_COL_NAMES) - col_names = tag.data.split(':') if tag is not None else [] + col_names = tag.data.split(":") if tag is not None else [] - mat = dict(nrow=nrow, ncol=ncol, row_names=row_names, col_names=col_names, - data=data) + mat = dict( + nrow=nrow, ncol=ncol, row_names=row_names, col_names=col_names, data=data + ) if transpose: _transpose_named_matrix(mat) return mat @@ -98,31 +108,32 @@ def write_named_matrix(fid, kind, mat): The type of matrix. 
""" # let's save ourselves from disaster - n_tot = mat['nrow'] * mat['ncol'] - if mat['data'].size != n_tot: - ratio = n_tot / float(mat['data'].size) - if n_tot < mat['data'].size and ratio > 0: + n_tot = mat["nrow"] * mat["ncol"] + if mat["data"].size != n_tot: + ratio = n_tot / float(mat["data"].size) + if n_tot < mat["data"].size and ratio > 0: ratio = 1 / ratio - raise ValueError('Cannot write matrix: row (%i) and column (%i) ' - 'total element (%i) mismatch with data size (%i), ' - 'appears to be off by a factor of %gx' - % (mat['nrow'], mat['ncol'], n_tot, - mat['data'].size, ratio)) + raise ValueError( + "Cannot write matrix: row (%i) and column (%i) " + "total element (%i) mismatch with data size (%i), " + "appears to be off by a factor of %gx" + % (mat["nrow"], mat["ncol"], n_tot, mat["data"].size, ratio) + ) start_block(fid, FIFF.FIFFB_MNE_NAMED_MATRIX) - write_int(fid, FIFF.FIFF_MNE_NROW, mat['nrow']) - write_int(fid, FIFF.FIFF_MNE_NCOL, mat['ncol']) + write_int(fid, FIFF.FIFF_MNE_NROW, mat["nrow"]) + write_int(fid, FIFF.FIFF_MNE_NCOL, mat["ncol"]) - if len(mat['row_names']) > 0: + if len(mat["row_names"]) > 0: # let's prevent unintentional stupidity - if len(mat['row_names']) != mat['nrow']: + if len(mat["row_names"]) != mat["nrow"]: raise ValueError('len(mat["row_names"]) != mat["nrow"]') - write_name_list(fid, FIFF.FIFF_MNE_ROW_NAMES, mat['row_names']) + write_name_list(fid, FIFF.FIFF_MNE_ROW_NAMES, mat["row_names"]) - if len(mat['col_names']) > 0: + if len(mat["col_names"]) > 0: # let's prevent unintentional stupidity - if len(mat['col_names']) != mat['ncol']: + if len(mat["col_names"]) != mat["ncol"]: raise ValueError('len(mat["col_names"]) != mat["ncol"]') - write_name_list(fid, FIFF.FIFF_MNE_COL_NAMES, mat['col_names']) + write_name_list(fid, FIFF.FIFF_MNE_COL_NAMES, mat["col_names"]) - write_float_matrix(fid, kind, mat['data']) + write_float_matrix(fid, kind, mat["data"]) end_block(fid, FIFF.FIFFB_MNE_NAMED_MATRIX) diff --git a/mne/io/meas_info.py b/mne/io/meas_info.py index 3e6eb62c4a6..9f08a4cd720 100644 --- a/mne/io/meas_info.py +++ b/mne/io/meas_info.py @@ -17,30 +17,84 @@ import numpy as np -from .pick import (channel_type, _get_channel_types, - get_channel_type_constants, pick_types, _contains_ch_type) +from .pick import ( + channel_type, + _get_channel_types, + get_channel_type_constants, + pick_types, + _contains_ch_type, +) from .constants import FIFF, _coord_frame_named from .open import fiff_open from .tree import dir_tree_find -from .tag import (read_tag, find_tag, _ch_coord_dict, _update_ch_info_named, - _rename_list, _int_item, _float_item) -from .proj import (_read_proj, _write_proj, _uniquify_projs, _normalize_proj, - _proj_equal, Projection) +from .tag import ( + read_tag, + find_tag, + _ch_coord_dict, + _update_ch_info_named, + _rename_list, + _int_item, + _float_item, +) +from .proj import ( + _read_proj, + _write_proj, + _uniquify_projs, + _normalize_proj, + _proj_equal, + Projection, +) from .ctf_comp import _read_ctf_comp, write_ctf_comp -from .write import (start_and_end_file, start_block, end_block, - write_string, write_dig_points, write_float, write_int, - write_coord_trans, write_ch_info, - write_julian, write_float_matrix, write_id, DATE_NONE, - _safe_name_list, write_name_list_sanitized) +from .write import ( + start_and_end_file, + start_block, + end_block, + write_string, + write_dig_points, + write_float, + write_int, + write_coord_trans, + write_ch_info, + write_julian, + write_float_matrix, + write_id, + DATE_NONE, + _safe_name_list, + 
write_name_list_sanitized, +) from .proc_history import _read_proc_history, _write_proc_history -from ..transforms import (invert_transform, Transform, _coord_frame_name, - _ensure_trans, _frame_to_str) -from ..utils import (logger, verbose, warn, object_diff, _validate_type, - _stamp_to_dt, _dt_to_stamp, _pl, _is_numeric, - _check_option, _on_missing, _check_on_missing, fill_doc, - _check_fname, repr_html) -from ._digitization import (_format_dig_points, _dig_kind_proper, DigPoint, - _dig_kind_rev, _dig_kind_ints, _read_dig_fif) +from ..transforms import ( + invert_transform, + Transform, + _coord_frame_name, + _ensure_trans, + _frame_to_str, +) +from ..utils import ( + logger, + verbose, + warn, + object_diff, + _validate_type, + _stamp_to_dt, + _dt_to_stamp, + _pl, + _is_numeric, + _check_option, + _on_missing, + _check_on_missing, + fill_doc, + _check_fname, + repr_html, +) +from ._digitization import ( + _format_dig_points, + _dig_kind_proper, + DigPoint, + _dig_kind_rev, + _dig_kind_ints, + _read_dig_fif, +) from ._digitization import write_dig, _get_data_as_dict_from_dig from .compensator import get_current_comp from ..defaults import _handle_default @@ -48,11 +102,20 @@ b = bytes # alias -_SCALAR_CH_KEYS = ('scanno', 'logno', 'kind', 'range', 'cal', 'coil_type', - 'unit', 'unit_mul', 'coord_frame') -_ALL_CH_KEYS_SET = set(_SCALAR_CH_KEYS + ('loc', 'ch_name')) +_SCALAR_CH_KEYS = ( + "scanno", + "logno", + "kind", + "range", + "cal", + "coil_type", + "unit", + "unit_mul", + "coord_frame", +) +_ALL_CH_KEYS_SET = set(_SCALAR_CH_KEYS + ("loc", "ch_name")) # XXX we need to require these except when doing simplify_info -_MIN_CH_KEYS_SET = set(('kind', 'cal', 'unit', 'loc', 'ch_name')) +_MIN_CH_KEYS_SET = set(("kind", "cal", "unit", "loc", "ch_name")) def _get_valid_units(): @@ -67,35 +130,137 @@ def _get_valid_units(): ---------- .. 
footbibliography:: """ - valid_prefix_names = ['yocto', 'zepto', 'atto', 'femto', 'pico', 'nano', - 'micro', 'milli', 'centi', 'deci', 'deca', 'hecto', - 'kilo', 'mega', 'giga', 'tera', 'peta', 'exa', - 'zetta', 'yotta'] - valid_prefix_symbols = ['y', 'z', 'a', 'f', 'p', 'n', 'µ', 'm', 'c', 'd', - 'da', 'h', 'k', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'] - valid_unit_names = ['metre', 'kilogram', 'second', 'ampere', 'kelvin', - 'mole', 'candela', 'radian', 'steradian', 'hertz', - 'newton', 'pascal', 'joule', 'watt', 'coulomb', 'volt', - 'farad', 'ohm', 'siemens', 'weber', 'tesla', 'henry', - 'degree Celsius', 'lumen', 'lux', 'becquerel', 'gray', - 'sievert', 'katal'] - valid_unit_symbols = ['m', 'kg', 's', 'A', 'K', 'mol', 'cd', 'rad', 'sr', - 'Hz', 'N', 'Pa', 'J', 'W', 'C', 'V', 'F', 'Ω', 'S', - 'Wb', 'T', 'H', '°C', 'lm', 'lx', 'Bq', 'Gy', 'Sv', - 'kat'] + valid_prefix_names = [ + "yocto", + "zepto", + "atto", + "femto", + "pico", + "nano", + "micro", + "milli", + "centi", + "deci", + "deca", + "hecto", + "kilo", + "mega", + "giga", + "tera", + "peta", + "exa", + "zetta", + "yotta", + ] + valid_prefix_symbols = [ + "y", + "z", + "a", + "f", + "p", + "n", + "µ", + "m", + "c", + "d", + "da", + "h", + "k", + "M", + "G", + "T", + "P", + "E", + "Z", + "Y", + ] + valid_unit_names = [ + "metre", + "kilogram", + "second", + "ampere", + "kelvin", + "mole", + "candela", + "radian", + "steradian", + "hertz", + "newton", + "pascal", + "joule", + "watt", + "coulomb", + "volt", + "farad", + "ohm", + "siemens", + "weber", + "tesla", + "henry", + "degree Celsius", + "lumen", + "lux", + "becquerel", + "gray", + "sievert", + "katal", + ] + valid_unit_symbols = [ + "m", + "kg", + "s", + "A", + "K", + "mol", + "cd", + "rad", + "sr", + "Hz", + "N", + "Pa", + "J", + "W", + "C", + "V", + "F", + "Ω", + "S", + "Wb", + "T", + "H", + "°C", + "lm", + "lx", + "Bq", + "Gy", + "Sv", + "kat", + ] # Valid units are all possible combinations of either prefix name or prefix # symbol together with either unit name or unit symbol. E.g., nV for # nanovolt valid_units = [] - valid_units += ([''.join([prefix, unit]) for prefix in valid_prefix_names - for unit in valid_unit_names]) - valid_units += ([''.join([prefix, unit]) for prefix in valid_prefix_names - for unit in valid_unit_symbols]) - valid_units += ([''.join([prefix, unit]) for prefix in valid_prefix_symbols - for unit in valid_unit_names]) - valid_units += ([''.join([prefix, unit]) for prefix in valid_prefix_symbols - for unit in valid_unit_symbols]) + valid_units += [ + "".join([prefix, unit]) + for prefix in valid_prefix_names + for unit in valid_unit_names + ] + valid_units += [ + "".join([prefix, unit]) + for prefix in valid_prefix_names + for unit in valid_unit_symbols + ] + valid_units += [ + "".join([prefix, unit]) + for prefix in valid_prefix_symbols + for unit in valid_unit_names + ] + valid_units += [ + "".join([prefix, unit]) + for prefix in valid_prefix_symbols + for unit in valid_unit_symbols + ] # units are also valid without a prefix valid_units += valid_unit_names @@ -116,17 +281,17 @@ def _unique_channel_names(ch_names, max_length=None, verbose=None): ch_names[:] = [name[:max_length] for name in ch_names] unique_ids = np.unique(ch_names, return_index=True)[1] if len(unique_ids) != len(ch_names): - dups = {ch_names[x] - for x in np.setdiff1d(range(len(ch_names)), unique_ids)} - warn('Channel names are not unique, found duplicates for: ' - '%s. Applying running numbers for duplicates.' 
% dups) + dups = {ch_names[x] for x in np.setdiff1d(range(len(ch_names)), unique_ids)} + warn( + "Channel names are not unique, found duplicates for: " + "%s. Applying running numbers for duplicates." % dups + ) for ch_stem in dups: overlaps = np.where(np.array(ch_names) == ch_stem)[0] # We need an extra character since we append '-'. # np.ceil(...) is the maximum number of appended digits. if max_length is not None: - n_keep = ( - max_length - 1 - int(np.ceil(np.log10(len(overlaps))))) + n_keep = max_length - 1 - int(np.ceil(np.log10(len(overlaps)))) else: n_keep = np.inf n_keep = min(len(ch_stem), n_keep) @@ -134,15 +299,17 @@ def _unique_channel_names(ch_names, max_length=None, verbose=None): for idx, ch_idx in enumerate(overlaps): # try idx first, then loop through lower case chars for suffix in (idx,) + suffixes: - ch_name = ch_stem + '-%s' % suffix + ch_name = ch_stem + "-%s" % suffix if ch_name not in ch_names: break if ch_name not in ch_names: ch_names[ch_idx] = ch_name else: - raise ValueError('Adding a single alphanumeric for a ' - 'duplicate resulted in another ' - 'duplicate name %s' % ch_name) + raise ValueError( + "Adding a single alphanumeric for a " + "duplicate resulted in another " + "duplicate name %s" % ch_name + ) return ch_names @@ -158,22 +325,31 @@ def get_montage(self): %(montage)s """ from ..channels.montage import make_dig_montage + info = self if isinstance(self, Info) else self.info - if info['dig'] is None: + if info["dig"] is None: return None # obtain coord_frame, and landmark coords # (nasion, lpa, rpa, hsp, hpi) from DigPoints - montage_bunch = _get_data_as_dict_from_dig(info['dig']) + montage_bunch = _get_data_as_dict_from_dig(info["dig"]) coord_frame = _frame_to_str.get(montage_bunch.coord_frame) # get the channel names and chs data structure - ch_names, chs = info['ch_names'], info['chs'] - picks = pick_types(info, meg=False, eeg=True, seeg=True, - ecog=True, dbs=True, fnirs=True, exclude=[]) + ch_names, chs = info["ch_names"], info["chs"] + picks = pick_types( + info, + meg=False, + eeg=True, + seeg=True, + ecog=True, + dbs=True, + fnirs=True, + exclude=[], + ) # channel positions from dig do not match ch_names one to one, # so use loc[:3] instead - ch_pos = {ch_names[ii]: chs[ii]['loc'][:3] for ii in picks} + ch_pos = {ch_names[ii]: chs[ii]["loc"][:3] for ii in picks} # fNIRS uses multiple channels for the same sensors, we use # a private function to format these for dig montage. @@ -181,10 +357,12 @@ def get_montage(self): if len(ch_pos) == len(fnirs_picks): ch_pos = _get_fnirs_ch_pos(info) elif len(fnirs_picks) > 0: - raise ValueError("MNE does not support getting the montage " - "for a mix of fNIRS and other data types. " - "Please raise a GitHub issue if you " - "require this feature.") + raise ValueError( + "MNE does not support getting the montage " + "for a mix of fNIRS and other data types. " + "Please raise a GitHub issue if you " + "require this feature." + ) # create montage montage = make_dig_montage( @@ -199,8 +377,14 @@ def get_montage(self): return montage @verbose - def set_montage(self, montage, match_case=True, match_alias=False, - on_missing='raise', verbose=None): + def set_montage( + self, + montage, + match_case=True, + match_alias=False, + on_missing="raise", + verbose=None, + ): """Set %(montage_types)s channel positions and digitization points. 
Parameters @@ -234,6 +418,7 @@ def set_montage(self, montage, match_case=True, match_alias=False, # https://gist.github.com/massich/f6a9f4799f1fbeb8f5e8f8bc7b07d3df from ..channels.montage import _set_montage + info = self if isinstance(self, Info) else self.info _set_montage(info, montage, match_case, match_alias, on_missing) return self @@ -267,9 +452,10 @@ def __contains__(self, ch_type): """ info = self if isinstance(self, Info) else self.info - if ch_type == 'meg': - has_ch_type = (_contains_ch_type(info, 'mag') or - _contains_ch_type(info, 'grad')) + if ch_type == "meg": + has_ch_type = _contains_ch_type(info, "mag") or _contains_ch_type( + info, "grad" + ) else: has_ch_type = _contains_ch_type(info, ch_type) return has_ch_type @@ -298,8 +484,9 @@ def get_channel_types(self, picks=None, unique=False, only_data_chs=False): The channel types. """ info = self if isinstance(self, Info) else self.info - return _get_channel_types(info, picks=picks, unique=unique, - only_data_chs=only_data_chs) + return _get_channel_types( + info, picks=picks, unique=unique, only_data_chs=only_data_chs + ) def _format_trans(obj, key): @@ -309,25 +496,25 @@ def _format_trans(obj, key): pass else: if t is not None: - obj[key] = Transform(t['from'], t['to'], t['trans']) + obj[key] = Transform(t["from"], t["to"], t["trans"]) def _check_ch_keys(ch, ci, name='info["chs"]', check_min=True): ch_keys = set(ch) bad = sorted(ch_keys.difference(_ALL_CH_KEYS_SET)) if bad: - raise KeyError( - f'key{_pl(bad)} errantly present for {name}[{ci}]: {bad}') + raise KeyError(f"key{_pl(bad)} errantly present for {name}[{ci}]: {bad}") if check_min: bad = sorted(_MIN_CH_KEYS_SET.difference(ch_keys)) if bad: raise KeyError( - f'key{_pl(bad)} missing for {name}[{ci}]: {bad}',) + f"key{_pl(bad)} missing for {name}[{ci}]: {bad}", + ) # As options are added here, test_meas_info.py:test_info_bad should be updated def _check_bads(bads): - _validate_type(bads, list, 'bads') + _validate_type(bads, list, "bads") return bads @@ -339,33 +526,47 @@ def _check_description(description): def _check_dev_head_t(dev_head_t): _validate_type(dev_head_t, (Transform, None), "info['dev_head_t']") if dev_head_t is not None: - dev_head_t = _ensure_trans(dev_head_t, 'meg', 'head') + dev_head_t = _ensure_trans(dev_head_t, "meg", "head") return dev_head_t def _check_experimenter(experimenter): - _validate_type(experimenter, (None, str), 'experimenter') + _validate_type(experimenter, (None, str), "experimenter") return experimenter def _check_line_freq(line_freq): - _validate_type(line_freq, (None, 'numeric'), 'line_freq') + _validate_type(line_freq, (None, "numeric"), "line_freq") line_freq = float(line_freq) if line_freq is not None else line_freq return line_freq def _check_subject_info(subject_info): - _validate_type(subject_info, (None, dict), 'subject_info') + _validate_type(subject_info, (None, dict), "subject_info") return subject_info def _check_device_info(device_info): - _validate_type(device_info, (None, dict, ), 'device_info') + _validate_type( + device_info, + ( + None, + dict, + ), + "device_info", + ) return device_info def _check_helium_info(helium_info): - _validate_type(helium_info, (None, dict, ), 'helium_info') + _validate_type( + helium_info, + ( + None, + dict, + ), + "helium_info", + ) return helium_info @@ -768,116 +969,117 @@ class Info(dict, MontageMixin, ContainsMixin): """ _attributes = { - 'acq_pars': 'acq_pars cannot be set directly. 
' - 'See mne.AcqParserFIF() for details.', - 'acq_stim': 'acq_stim cannot be set directly.', - 'bads': _check_bads, - 'ch_names': 'ch_names cannot be set directly. ' - 'Please use methods inst.add_channels(), ' - 'inst.drop_channels(), inst.pick_channels(), ' - 'inst.rename_channels(), inst.reorder_channels() ' - 'and inst.set_channel_types() instead.', - 'chs': 'chs cannot be set directly. ' - 'Please use methods inst.add_channels(), ' - 'inst.drop_channels(), inst.pick_channels(), ' - 'inst.rename_channels(), inst.reorder_channels() ' - 'and inst.set_channel_types() instead.', - 'command_line': 'command_line cannot be set directly.', - 'comps': 'comps cannot be set directly. ' - 'Please use method Raw.apply_gradient_compensation() ' - 'instead.', - 'ctf_head_t': 'ctf_head_t cannot be set directly.', - 'custom_ref_applied': 'custom_ref_applied cannot be set directly. ' - 'Please use method inst.set_eeg_reference() ' - 'instead.', - 'description': _check_description, - 'dev_ctf_t': 'dev_ctf_t cannot be set directly.', - 'dev_head_t': _check_dev_head_t, - 'device_info': _check_device_info, - 'dig': 'dig cannot be set directly. ' - 'Please use method inst.set_montage() instead.', - 'events': 'events cannot be set directly.', - 'experimenter': _check_experimenter, - 'file_id': 'file_id cannot be set directly.', - 'gantry_angle': 'gantry_angle cannot be set directly.', - 'helium_info': _check_helium_info, - 'highpass': 'highpass cannot be set directly. ' - 'Please use method inst.filter() instead.', - 'hpi_meas': 'hpi_meas can not be set directly.', - 'hpi_results': 'hpi_results cannot be set directly.', - 'hpi_subsystem': 'hpi_subsystem cannot be set directly.', - 'kit_system_id': 'kit_system_id cannot be set directly.', - 'line_freq': _check_line_freq, - 'lowpass': 'lowpass cannot be set directly. ' - 'Please use method inst.filter() instead.', - 'maxshield': 'maxshield cannot be set directly.', - 'meas_date': 'meas_date cannot be set directly. ' - 'Please use method inst.set_meas_date() instead.', - 'meas_file': 'meas_file cannot be set directly.', - 'meas_id': 'meas_id cannot be set directly.', - 'mri_file': 'mri_file cannot be set directly.', - 'mri_head_t': 'mri_head_t cannot be set directly.', - 'mri_id': 'mri_id cannot be set directly.', - 'nchan': 'nchan cannot be set directly. ' - 'Please use methods inst.add_channels(), ' - 'inst.drop_channels(), and inst.pick_channels() instead.', - 'proc_history': 'proc_history cannot be set directly.', - 'proj_id': 'proj_id cannot be set directly.', - 'proj_name': 'proj_name cannot be set directly.', - 'projs': 'projs cannot be set directly. ' - 'Please use methods inst.add_proj() and inst.del_proj() ' - 'instead.', - 'sfreq': 'sfreq cannot be set directly. ' - 'Please use method inst.resample() instead.', - 'subject_info': _check_subject_info, - 'temp': lambda x: x, - 'utc_offset': 'utc_offset cannot be set directly.', - 'working_dir': 'working_dir cannot be set directly.', - 'xplotter_layout': 'xplotter_layout cannot be set directly.' + "acq_pars": "acq_pars cannot be set directly. " + "See mne.AcqParserFIF() for details.", + "acq_stim": "acq_stim cannot be set directly.", + "bads": _check_bads, + "ch_names": "ch_names cannot be set directly. " + "Please use methods inst.add_channels(), " + "inst.drop_channels(), inst.pick_channels(), " + "inst.rename_channels(), inst.reorder_channels() " + "and inst.set_channel_types() instead.", + "chs": "chs cannot be set directly. 
" + "Please use methods inst.add_channels(), " + "inst.drop_channels(), inst.pick_channels(), " + "inst.rename_channels(), inst.reorder_channels() " + "and inst.set_channel_types() instead.", + "command_line": "command_line cannot be set directly.", + "comps": "comps cannot be set directly. " + "Please use method Raw.apply_gradient_compensation() " + "instead.", + "ctf_head_t": "ctf_head_t cannot be set directly.", + "custom_ref_applied": "custom_ref_applied cannot be set directly. " + "Please use method inst.set_eeg_reference() " + "instead.", + "description": _check_description, + "dev_ctf_t": "dev_ctf_t cannot be set directly.", + "dev_head_t": _check_dev_head_t, + "device_info": _check_device_info, + "dig": "dig cannot be set directly. " + "Please use method inst.set_montage() instead.", + "events": "events cannot be set directly.", + "experimenter": _check_experimenter, + "file_id": "file_id cannot be set directly.", + "gantry_angle": "gantry_angle cannot be set directly.", + "helium_info": _check_helium_info, + "highpass": "highpass cannot be set directly. " + "Please use method inst.filter() instead.", + "hpi_meas": "hpi_meas can not be set directly.", + "hpi_results": "hpi_results cannot be set directly.", + "hpi_subsystem": "hpi_subsystem cannot be set directly.", + "kit_system_id": "kit_system_id cannot be set directly.", + "line_freq": _check_line_freq, + "lowpass": "lowpass cannot be set directly. " + "Please use method inst.filter() instead.", + "maxshield": "maxshield cannot be set directly.", + "meas_date": "meas_date cannot be set directly. " + "Please use method inst.set_meas_date() instead.", + "meas_file": "meas_file cannot be set directly.", + "meas_id": "meas_id cannot be set directly.", + "mri_file": "mri_file cannot be set directly.", + "mri_head_t": "mri_head_t cannot be set directly.", + "mri_id": "mri_id cannot be set directly.", + "nchan": "nchan cannot be set directly. " + "Please use methods inst.add_channels(), " + "inst.drop_channels(), and inst.pick_channels() instead.", + "proc_history": "proc_history cannot be set directly.", + "proj_id": "proj_id cannot be set directly.", + "proj_name": "proj_name cannot be set directly.", + "projs": "projs cannot be set directly. " + "Please use methods inst.add_proj() and inst.del_proj() " + "instead.", + "sfreq": "sfreq cannot be set directly. 
" + "Please use method inst.resample() instead.", + "subject_info": _check_subject_info, + "temp": lambda x: x, + "utc_offset": "utc_offset cannot be set directly.", + "working_dir": "working_dir cannot be set directly.", + "xplotter_layout": "xplotter_layout cannot be set directly.", } def __init__(self, *args, **kwargs): self._unlocked = True super().__init__(*args, **kwargs) # Deal with h5io writing things as dict - for key in ('dev_head_t', 'ctf_head_t', 'dev_ctf_t'): + for key in ("dev_head_t", "ctf_head_t", "dev_ctf_t"): _format_trans(self, key) - for res in self.get('hpi_results', []): - _format_trans(res, 'coord_trans') - if self.get('dig', None) is not None and len(self['dig']): - if isinstance(self['dig'], dict): # needs to be unpacked - self['dig'] = _dict_unpack(self['dig'], _DIG_CAST) - if not isinstance(self['dig'][0], DigPoint): - self['dig'] = _format_dig_points(self['dig']) - if isinstance(self.get('chs', None), dict): - self['chs']['ch_name'] = [str(x) for x in np.char.decode( - self['chs']['ch_name'], encoding='utf8')] - self['chs'] = _dict_unpack(self['chs'], _CH_CAST) - for pi, proj in enumerate(self.get('projs', [])): + for res in self.get("hpi_results", []): + _format_trans(res, "coord_trans") + if self.get("dig", None) is not None and len(self["dig"]): + if isinstance(self["dig"], dict): # needs to be unpacked + self["dig"] = _dict_unpack(self["dig"], _DIG_CAST) + if not isinstance(self["dig"][0], DigPoint): + self["dig"] = _format_dig_points(self["dig"]) + if isinstance(self.get("chs", None), dict): + self["chs"]["ch_name"] = [ + str(x) for x in np.char.decode(self["chs"]["ch_name"], encoding="utf8") + ] + self["chs"] = _dict_unpack(self["chs"], _CH_CAST) + for pi, proj in enumerate(self.get("projs", [])): if not isinstance(proj, Projection): - self['projs'][pi] = Projection(**proj) + self["projs"][pi] = Projection(**proj) # Old files could have meas_date as tuple instead of datetime try: - meas_date = self['meas_date'] + meas_date = self["meas_date"] except KeyError: pass else: - self['meas_date'] = _ensure_meas_date_none_or_dt(meas_date) + self["meas_date"] = _ensure_meas_date_none_or_dt(meas_date) self._unlocked = False def __getstate__(self): """Get state (for pickling).""" - return {'_unlocked': self._unlocked} + return {"_unlocked": self._unlocked} def __setstate__(self, state): """Set state (for pickling).""" - self._unlocked = state['_unlocked'] + self._unlocked = state["_unlocked"] def __setitem__(self, key, val): """Attribute setter.""" # During unpickling, the _unlocked attribute has not been set, so # let __setstate__ do it later and act unlocked now - unlocked = getattr(self, '_unlocked', True) + unlocked = getattr(self, "_unlocked", True) if key in self._attributes: if isinstance(self._attributes[key], str): if not unlocked: @@ -888,7 +1090,8 @@ def __setitem__(self, key, val): raise RuntimeError( f"Info does not support directly setting the key {repr(key)}. " "You can set info['temp'] to store temporary objects in an " - "Info instance, but these will not survive an I/O round-trip.") + "Info instance, but these will not survive an I/O round-trip." 
+ ) super().__setitem__(key, val) def update(self, other=None, **kwargs): @@ -904,7 +1107,7 @@ def update(self, other=None, **kwargs): def _unlock(self, *, update_redundant=False, check_after=False): """Context manager unlocking access to attributes.""" # needed for nested _unlock() - state = self._unlocked if hasattr(self, '_unlocked') else False + state = self._unlocked if hasattr(self, "_unlocked") else False self._unlocked = True try: @@ -948,68 +1151,73 @@ def normalize_proj(self): def __repr__(self): """Summarize info instead of printing all.""" MAX_WIDTH = 68 - strs = [' 0: - entr = ('%d item%s (%s)' % (this_len, _pl(this_len), - type(v).__name__)) + entr = "%d item%s (%s)" % ( + this_len, + _pl(this_len), + type(v).__name__, + ) else: - entr = '' - if entr != '': + entr = "" + if entr != "": non_empty += 1 - strs.append('%s: %s' % (k, entr)) - st = '\n '.join(sorted(strs)) - st += '\n>' + strs.append("%s: %s" % (k, entr)) + st = "\n ".join(sorted(strs)) + st += "\n>" st %= non_empty return st @@ -1038,22 +1249,22 @@ def __deepcopy__(self, memodict): result._unlocked = True for k, v in self.items(): # chs is roughly half the time but most are immutable - if k == 'chs': + if k == "chs": # dict shallow copy is fast, so use it then overwrite result[k] = list() for ch in v: ch = ch.copy() # shallow - ch['loc'] = ch['loc'].copy() + ch["loc"] = ch["loc"].copy() result[k].append(ch) - elif k == 'ch_names': + elif k == "ch_names": # we know it's list of str, shallow okay and saves ~100 µs result[k] = v.copy() - elif k == 'hpi_meas': + elif k == "hpi_meas": hms = list() for hm in v: hm = hm.copy() # the only mutable thing here is some entries in coils - hm['hpi_coils'] = [coil.copy() for coil in hm['hpi_coils']] + hm["hpi_coils"] = [coil.copy() for coil in hm["hpi_coils"]] # There is a *tiny* risk here that someone could write # raw.info['hpi_meas'][0]['hpi_coils'][1]['epoch'] = ... 
# and assume that info.copy() will make an actual copy, @@ -1069,111 +1280,129 @@ def __deepcopy__(self, memodict): result._unlocked = False return result - def _check_consistency(self, prepend_error=''): + def _check_consistency(self, prepend_error=""): """Do some self-consistency checks and datatype tweaks.""" - missing = [bad for bad in self['bads'] if bad not in self['ch_names']] + missing = [bad for bad in self["bads"] if bad not in self["ch_names"]] if len(missing) > 0: - msg = '%sbad channel(s) %s marked do not exist in info' - raise RuntimeError(msg % (prepend_error, missing,)) - meas_date = self.get('meas_date') + msg = "%sbad channel(s) %s marked do not exist in info" + raise RuntimeError( + msg + % ( + prepend_error, + missing, + ) + ) + meas_date = self.get("meas_date") if meas_date is not None: - if (not isinstance(self['meas_date'], datetime.datetime) or - self['meas_date'].tzinfo is None or - self['meas_date'].tzinfo is not datetime.timezone.utc): - raise RuntimeError('%sinfo["meas_date"] must be a datetime ' - 'object in UTC or None, got %r' - % (prepend_error, repr(self['meas_date']),)) - - chs = [ch['ch_name'] for ch in self['chs']] - if len(self['ch_names']) != len(chs) or any( - ch_1 != ch_2 for ch_1, ch_2 in zip(self['ch_names'], chs)) or \ - self['nchan'] != len(chs): - raise RuntimeError('%sinfo channel name inconsistency detected, ' - 'please notify mne-python developers' - % (prepend_error,)) + if ( + not isinstance(self["meas_date"], datetime.datetime) + or self["meas_date"].tzinfo is None + or self["meas_date"].tzinfo is not datetime.timezone.utc + ): + raise RuntimeError( + '%sinfo["meas_date"] must be a datetime ' + "object in UTC or None, got %r" + % ( + prepend_error, + repr(self["meas_date"]), + ) + ) + + chs = [ch["ch_name"] for ch in self["chs"]] + if ( + len(self["ch_names"]) != len(chs) + or any(ch_1 != ch_2 for ch_1, ch_2 in zip(self["ch_names"], chs)) + or self["nchan"] != len(chs) + ): + raise RuntimeError( + "%sinfo channel name inconsistency detected, " + "please notify mne-python developers" % (prepend_error,) + ) # make sure we have the proper datatypes with self._unlock(): - for key in ('sfreq', 'highpass', 'lowpass'): + for key in ("sfreq", "highpass", "lowpass"): if self.get(key) is not None: self[key] = float(self[key]) - for pi, proj in enumerate(self.get('projs', [])): + for pi, proj in enumerate(self.get("projs", [])): _validate_type(proj, Projection, f'info["projs"][{pi}]') - for key in ('kind', 'active', 'desc', 'data', 'explained_var'): + for key in ("kind", "active", "desc", "data", "explained_var"): if key not in proj: - raise RuntimeError(f'Projection incomplete, missing {key}') + raise RuntimeError(f"Projection incomplete, missing {key}") # Ensure info['chs'] has immutable entries (copies much faster) - for ci, ch in enumerate(self['chs']): + for ci, ch in enumerate(self["chs"]): _check_ch_keys(ch, ci) - ch_name = ch['ch_name'] + ch_name = ch["ch_name"] if not isinstance(ch_name, str): raise TypeError( 'Bad info: info["chs"][%d]["ch_name"] is not a string, ' - 'got type %s' % (ci, type(ch_name))) + "got type %s" % (ci, type(ch_name)) + ) for key in _SCALAR_CH_KEYS: val = ch.get(key, 1) if not _is_numeric(val): raise TypeError( 'Bad info: info["chs"][%d][%r] = %s is type %s, must ' - 'be float or int' % (ci, key, val, type(val))) - loc = ch['loc'] + "be float or int" % (ci, key, val, type(val)) + ) + loc = ch["loc"] if not (isinstance(loc, np.ndarray) and loc.shape == (12,)): raise TypeError( 'Bad info: info["chs"][%d]["loc"] must be 
ndarray with ' - '12 elements, got %r' % (ci, loc)) + "12 elements, got %r" % (ci, loc) + ) # make sure channel names are unique with self._unlock(): - self['ch_names'] = _unique_channel_names(self['ch_names']) - for idx, ch_name in enumerate(self['ch_names']): - self['chs'][idx]['ch_name'] = ch_name + self["ch_names"] = _unique_channel_names(self["ch_names"]) + for idx, ch_name in enumerate(self["ch_names"]): + self["chs"][idx]["ch_name"] = ch_name def _update_redundant(self): """Update the redundant entries.""" with self._unlock(): - self['ch_names'] = [ch['ch_name'] for ch in self['chs']] - self['nchan'] = len(self['chs']) + self["ch_names"] = [ch["ch_name"] for ch in self["chs"]] + self["nchan"] = len(self["chs"]) @property def ch_names(self): - return self['ch_names'] + return self["ch_names"] def _get_chs_for_repr(self): - titles = _handle_default('titles') + titles = _handle_default("titles") # good channels channels = {} - ch_types = [channel_type(self, idx) for idx in range(len(self['chs']))] + ch_types = [channel_type(self, idx) for idx in range(len(self["chs"]))] ch_counts = Counter(ch_types) for ch_type, count in ch_counts.items(): - if ch_type == 'meg': - channels['mag'] = len(pick_types(self, meg='mag')) - channels['grad'] = len(pick_types(self, meg='grad')) - elif ch_type == 'eog': + if ch_type == "meg": + channels["mag"] = len(pick_types(self, meg="mag")) + channels["grad"] = len(pick_types(self, meg="grad")) + elif ch_type == "eog": pick_eog = pick_types(self, eog=True) - eog = ', '.join( - np.array(self['ch_names'])[pick_eog]) - elif ch_type == 'ecg': + eog = ", ".join(np.array(self["ch_names"])[pick_eog]) + elif ch_type == "ecg": pick_ecg = pick_types(self, ecg=True) - ecg = ', '.join( - np.array(self['ch_names'])[pick_ecg]) + ecg = ", ".join(np.array(self["ch_names"])[pick_ecg]) channels[ch_type] = count - good_channels = ', '.join( - [f'{v} {titles.get(k, k.upper())}' for k, v in channels.items()]) + good_channels = ", ".join( + [f"{v} {titles.get(k, k.upper())}" for k, v in channels.items()] + ) - if 'ecg' not in channels.keys(): - ecg = 'Not available' - if 'eog' not in channels.keys(): - eog = 'Not available' + if "ecg" not in channels.keys(): + ecg = "Not available" + if "eog" not in channels.keys(): + eog = "Not available" # bad channels - if len(self['bads']) > 0: - bad_channels = ', '.join(self['bads']) + if len(self["bads"]) > 0: + bad_channels = ", ".join(self["bads"]) else: - bad_channels = 'None' + bad_channels = "None" return good_channels, bad_channels, ecg, eog @@ -1181,10 +1410,11 @@ def _get_chs_for_repr(self): def _repr_html_(self, caption=None): """Summarize info for HTML representation.""" from ..html_templates import repr_templates_env + if isinstance(caption, str): - html = f'
<h4>{caption}</h4>' + html = f"<h4>{caption}</h4>
    " else: - html = '' + html = "" good_channels, bad_channels, ecg, eog = self._get_chs_for_repr() @@ -1198,20 +1428,19 @@ def _repr_html_(self, caption=None): # repr). # meas date - meas_date = self.get('meas_date') + meas_date = self.get("meas_date") if meas_date is not None: - meas_date = meas_date.strftime("%B %d, %Y %H:%M:%S") + ' GMT' + meas_date = meas_date.strftime("%B %d, %Y %H:%M:%S") + " GMT" - projs = self.get('projs') + projs = self.get("projs") if projs: projs = [ - f'{p["desc"]} : {"on" if p["active"] else "off"}' - for p in self['projs'] + f'{p["desc"]} : {"on" if p["active"] else "off"}' for p in self["projs"] ] else: projs = None - info_template = repr_templates_env.get_template('info.html.jinja') + info_template = repr_templates_env.get_template("info.html.jinja") return html + info_template.render( caption=caption, meas_date=meas_date, @@ -1220,12 +1449,12 @@ def _repr_html_(self, caption=None): eog=eog, good_channels=good_channels, bad_channels=bad_channels, - dig=self.get('dig'), - subject_info=self.get('subject_info'), - lowpass=self.get('lowpass'), - highpass=self.get('highpass'), - sfreq=self.get('sfreq'), - experimenter=self.get('experimenter'), + dig=self.get("dig"), + subject_info=self.get("subject_info"), + lowpass=self.get("lowpass"), + highpass=self.get("highpass"), + sfreq=self.get("sfreq"), + experimenter=self.get("experimenter"), ) def save(self, fname): @@ -1241,12 +1470,17 @@ def save(self, fname): def _simplify_info(info): """Return a simplified info structure to speed up picking.""" - chs = [{key: ch[key] - for key in ('ch_name', 'kind', 'unit', 'coil_type', 'loc', 'cal')} - for ch in info['chs']] - sub_info = Info(chs=chs, bads=info['bads'], comps=info['comps'], - projs=info['projs'], - custom_ref_applied=info['custom_ref_applied']) + chs = [ + {key: ch[key] for key in ("ch_name", "kind", "unit", "coil_type", "loc", "cal")} + for ch in info["chs"] + ] + sub_info = Info( + chs=chs, + bads=info["bads"], + comps=info["comps"], + projs=info["projs"], + custom_ref_applied=info["custom_ref_applied"], + ) sub_info._update_redundant() return sub_info @@ -1269,20 +1503,16 @@ def read_fiducials(fname, verbose=None): The coordinate frame of the points (one of ``mne.io.constants.FIFF.FIFFV_COORD_...``). """ - fname = _check_fname( - fname=fname, - overwrite='read', - must_exist=True - ) + fname = _check_fname(fname=fname, overwrite="read", must_exist=True) fid, tree, _ = fiff_open(fname) with fid: isotrak = dir_tree_find(tree, FIFF.FIFFB_ISOTRAK) isotrak = isotrak[0] pts = [] coord_frame = FIFF.FIFFV_COORD_HEAD - for k in range(isotrak['nent']): - kind = isotrak['directory'][k].kind - pos = isotrak['directory'][k].pos + for k in range(isotrak["nent"]): + kind = isotrak["directory"][k].kind + pos = isotrak["directory"][k].pos if kind == FIFF.FIFF_DIG_POINT: tag = read_tag(fid, pos) pts.append(DigPoint(tag.data)) @@ -1293,14 +1523,15 @@ def read_fiducials(fname, verbose=None): # coord_frame is not stored in the tag for pt in pts: - pt['coord_frame'] = coord_frame + pt["coord_frame"] = coord_frame return pts, coord_frame @verbose -def write_fiducials(fname, pts, coord_frame='unknown', *, overwrite=False, - verbose=None): +def write_fiducials( + fname, pts, coord_frame="unknown", *, overwrite=False, verbose=None +): """Write fiducials to a fiff file. 
Parameters @@ -1372,7 +1603,7 @@ def _read_bad_channels(fid, node, ch_names_mapping): for node in nodes: tag = find_tag(fid, node, FIFF.FIFF_MNE_CH_NAME_LIST) if tag is not None and tag.data is not None: - bads = _safe_name_list(tag.data, 'read', 'bads') + bads = _safe_name_list(tag.data, "read", "bads") bads[:] = _rename_list(bads, ch_names_mapping) return bads @@ -1382,8 +1613,7 @@ def _write_bad_channels(fid, bads, ch_names_mapping): ch_names_mapping = {} if ch_names_mapping is None else ch_names_mapping bads = _rename_list(bads, ch_names_mapping) start_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS) - write_name_list_sanitized( - fid, FIFF.FIFF_MNE_CH_NAME_LIST, bads, 'bads') + write_name_list_sanitized(fid, FIFF.FIFF_MNE_CH_NAME_LIST, bads, "bads") end_block(fid, FIFF.FIFFB_MNE_BAD_CHANNELS) @@ -1412,16 +1642,16 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): # Find the desired blocks meas = dir_tree_find(tree, FIFF.FIFFB_MEAS) if len(meas) == 0: - raise ValueError('Could not find measurement data') + raise ValueError("Could not find measurement data") if len(meas) > 1: - raise ValueError('Cannot read more that 1 measurement data') + raise ValueError("Cannot read more than 1 measurement data") meas = meas[0] meas_info = dir_tree_find(meas, FIFF.FIFFB_MEAS_INFO) if len(meas_info) == 0: - raise ValueError('Could not find measurement info') + raise ValueError("Could not find measurement info") if len(meas_info) > 1: - raise ValueError('Cannot read more that 1 measurement info') + raise ValueError("Cannot read more than 1 measurement info") meas_info = meas_info[0] # Read measurement info @@ -1444,9 +1674,9 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): custom_ref_applied = FIFF.FIFFV_MNE_CUSTOM_REF_OFF xplotter_layout = None kit_system_id = None - for k in range(meas_info['nent']): - kind = meas_info['directory'][k].kind - pos = meas_info['directory'][k].pos + for k in range(meas_info["nent"]): + kind = meas_info["directory"][k].kind + pos = meas_info["directory"][k].pos if kind == FIFF.FIFF_NCHAN: tag = read_tag(fid, pos) nchan = int(tag.data.item()) @@ -1476,18 +1706,26 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): tag = read_tag(fid, pos) cand = tag.data - if cand['from'] == FIFF.FIFFV_COORD_DEVICE and \ cand['to'] == FIFF.FIFFV_COORD_HEAD: + if ( cand["from"] == FIFF.FIFFV_COORD_DEVICE and cand["to"] == FIFF.FIFFV_COORD_HEAD ): dev_head_t = cand - elif cand['from'] == FIFF.FIFFV_COORD_HEAD and \ cand['to'] == FIFF.FIFFV_COORD_DEVICE: + elif ( cand["from"] == FIFF.FIFFV_COORD_HEAD and cand["to"] == FIFF.FIFFV_COORD_DEVICE ): # this reversal can happen with BabyMEG data dev_head_t = invert_transform(cand) - elif cand['from'] == FIFF.FIFFV_MNE_COORD_CTF_HEAD and \ cand['to'] == FIFF.FIFFV_COORD_HEAD: + elif ( cand["from"] == FIFF.FIFFV_MNE_COORD_CTF_HEAD and cand["to"] == FIFF.FIFFV_COORD_HEAD ): ctf_head_t = cand - elif cand['from'] == FIFF.FIFFV_MNE_COORD_CTF_DEVICE and \ cand['to'] == FIFF.FIFFV_MNE_COORD_CTF_HEAD: + elif ( cand["from"] == FIFF.FIFFV_MNE_COORD_CTF_DEVICE and cand["to"] == FIFF.FIFFV_MNE_COORD_CTF_HEAD ): dev_ctf_t = cand elif kind == FIFF.FIFF_EXPERIMENTER: tag = read_tag(fid, pos) @@ -1520,34 +1758,38 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): # Check that we have everything we need if nchan is None: - raise ValueError('Number of channels is not defined') + raise ValueError("Number of channels is not defined") if sfreq is None: - raise ValueError('Sampling 
frequency is not defined') + raise ValueError("Sampling frequency is not defined") if len(chs) == 0: - raise ValueError('Channel information not defined') + raise ValueError("Channel information not defined") if len(chs) != nchan: - raise ValueError('Incorrect number of channel definitions found') + raise ValueError("Incorrect number of channel definitions found") if dev_head_t is None or ctf_head_t is None: hpi_result = dir_tree_find(meas_info, FIFF.FIFFB_HPI_RESULT) if len(hpi_result) == 1: hpi_result = hpi_result[0] - for k in range(hpi_result['nent']): - kind = hpi_result['directory'][k].kind - pos = hpi_result['directory'][k].pos + for k in range(hpi_result["nent"]): + kind = hpi_result["directory"][k].kind + pos = hpi_result["directory"][k].pos if kind == FIFF.FIFF_COORD_TRANS: tag = read_tag(fid, pos) cand = tag.data - if (cand['from'] == FIFF.FIFFV_COORD_DEVICE and - cand['to'] == FIFF.FIFFV_COORD_HEAD and - dev_head_t is None): + if ( + cand["from"] == FIFF.FIFFV_COORD_DEVICE + and cand["to"] == FIFF.FIFFV_COORD_HEAD + and dev_head_t is None + ): dev_head_t = cand - elif (cand['from'] == FIFF.FIFFV_MNE_COORD_CTF_HEAD and - cand['to'] == FIFF.FIFFV_COORD_HEAD and - ctf_head_t is None): + elif ( + cand["from"] == FIFF.FIFFV_MNE_COORD_CTF_HEAD + and cand["to"] == FIFF.FIFFV_COORD_HEAD + and ctf_head_t is None + ): ctf_head_t = cand # Locate the Polhemus data @@ -1559,9 +1801,9 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): acq_stim = None if len(acqpars) == 1: acqpars = acqpars[0] - for k in range(acqpars['nent']): - kind = acqpars['directory'][k].kind - pos = acqpars['directory'][k].pos + for k in range(acqpars["nent"]): + kind = acqpars["directory"][k].kind + pos = acqpars["directory"][k].pos if kind == FIFF.FIFF_DACQ_PARS: tag = read_tag(fid, pos) acq_pars = tag.data @@ -1570,21 +1812,18 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): acq_stim = tag.data # Load the SSP data - projs = _read_proj( - fid, meas_info, ch_names_mapping=ch_names_mapping) + projs = _read_proj(fid, meas_info, ch_names_mapping=ch_names_mapping) # Load the CTF compensation data - comps = _read_ctf_comp( - fid, meas_info, chs, ch_names_mapping=ch_names_mapping) + comps = _read_ctf_comp(fid, meas_info, chs, ch_names_mapping=ch_names_mapping) # Load the bad channel list - bads = _read_bad_channels( - fid, meas_info, ch_names_mapping=ch_names_mapping) + bads = _read_bad_channels(fid, meas_info, ch_names_mapping=ch_names_mapping) # # Put the data together # - info = Info(file_id=tree['id']) + info = Info(file_id=tree["id"]) info._unlocked = True # Locate events list @@ -1592,92 +1831,92 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): evs = list() for event in events: ev = dict() - for k in range(event['nent']): - kind = event['directory'][k].kind - pos = event['directory'][k].pos + for k in range(event["nent"]): + kind = event["directory"][k].kind + pos = event["directory"][k].pos if kind == FIFF.FIFF_EVENT_CHANNELS: - ev['channels'] = read_tag(fid, pos).data + ev["channels"] = read_tag(fid, pos).data elif kind == FIFF.FIFF_EVENT_LIST: - ev['list'] = read_tag(fid, pos).data + ev["list"] = read_tag(fid, pos).data evs.append(ev) - info['events'] = evs + info["events"] = evs # Locate HPI result hpi_results = dir_tree_find(meas_info, FIFF.FIFFB_HPI_RESULT) hrs = list() for hpi_result in hpi_results: hr = dict() - hr['dig_points'] = [] - for k in range(hpi_result['nent']): - kind = hpi_result['directory'][k].kind - pos = hpi_result['directory'][k].pos + 
hr["dig_points"] = [] + for k in range(hpi_result["nent"]): + kind = hpi_result["directory"][k].kind + pos = hpi_result["directory"][k].pos if kind == FIFF.FIFF_DIG_POINT: - hr['dig_points'].append(read_tag(fid, pos).data) + hr["dig_points"].append(read_tag(fid, pos).data) elif kind == FIFF.FIFF_HPI_DIGITIZATION_ORDER: - hr['order'] = read_tag(fid, pos).data + hr["order"] = read_tag(fid, pos).data elif kind == FIFF.FIFF_HPI_COILS_USED: - hr['used'] = read_tag(fid, pos).data + hr["used"] = read_tag(fid, pos).data elif kind == FIFF.FIFF_HPI_COIL_MOMENTS: - hr['moments'] = read_tag(fid, pos).data + hr["moments"] = read_tag(fid, pos).data elif kind == FIFF.FIFF_HPI_FIT_GOODNESS: - hr['goodness'] = read_tag(fid, pos).data + hr["goodness"] = read_tag(fid, pos).data elif kind == FIFF.FIFF_HPI_FIT_GOOD_LIMIT: - hr['good_limit'] = float(read_tag(fid, pos).data.item()) + hr["good_limit"] = float(read_tag(fid, pos).data.item()) elif kind == FIFF.FIFF_HPI_FIT_DIST_LIMIT: - hr['dist_limit'] = float(read_tag(fid, pos).data.item()) + hr["dist_limit"] = float(read_tag(fid, pos).data.item()) elif kind == FIFF.FIFF_HPI_FIT_ACCEPT: - hr['accept'] = int(read_tag(fid, pos).data.item()) + hr["accept"] = int(read_tag(fid, pos).data.item()) elif kind == FIFF.FIFF_COORD_TRANS: - hr['coord_trans'] = read_tag(fid, pos).data + hr["coord_trans"] = read_tag(fid, pos).data hrs.append(hr) - info['hpi_results'] = hrs + info["hpi_results"] = hrs # Locate HPI Measurement hpi_meass = dir_tree_find(meas_info, FIFF.FIFFB_HPI_MEAS) hms = list() for hpi_meas in hpi_meass: hm = dict() - for k in range(hpi_meas['nent']): - kind = hpi_meas['directory'][k].kind - pos = hpi_meas['directory'][k].pos + for k in range(hpi_meas["nent"]): + kind = hpi_meas["directory"][k].kind + pos = hpi_meas["directory"][k].pos if kind == FIFF.FIFF_CREATOR: - hm['creator'] = str(read_tag(fid, pos).data) + hm["creator"] = str(read_tag(fid, pos).data) elif kind == FIFF.FIFF_SFREQ: - hm['sfreq'] = float(read_tag(fid, pos).data.item()) + hm["sfreq"] = float(read_tag(fid, pos).data.item()) elif kind == FIFF.FIFF_NCHAN: - hm['nchan'] = int(read_tag(fid, pos).data.item()) + hm["nchan"] = int(read_tag(fid, pos).data.item()) elif kind == FIFF.FIFF_NAVE: - hm['nave'] = int(read_tag(fid, pos).data.item()) + hm["nave"] = int(read_tag(fid, pos).data.item()) elif kind == FIFF.FIFF_HPI_NCOIL: - hm['ncoil'] = int(read_tag(fid, pos).data.item()) + hm["ncoil"] = int(read_tag(fid, pos).data.item()) elif kind == FIFF.FIFF_FIRST_SAMPLE: - hm['first_samp'] = int(read_tag(fid, pos).data.item()) + hm["first_samp"] = int(read_tag(fid, pos).data.item()) elif kind == FIFF.FIFF_LAST_SAMPLE: - hm['last_samp'] = int(read_tag(fid, pos).data.item()) + hm["last_samp"] = int(read_tag(fid, pos).data.item()) hpi_coils = dir_tree_find(hpi_meas, FIFF.FIFFB_HPI_COIL) hcs = [] for hpi_coil in hpi_coils: hc = dict() - for k in range(hpi_coil['nent']): - kind = hpi_coil['directory'][k].kind - pos = hpi_coil['directory'][k].pos + for k in range(hpi_coil["nent"]): + kind = hpi_coil["directory"][k].kind + pos = hpi_coil["directory"][k].pos if kind == FIFF.FIFF_HPI_COIL_NO: - hc['number'] = int(read_tag(fid, pos).data.item()) + hc["number"] = int(read_tag(fid, pos).data.item()) elif kind == FIFF.FIFF_EPOCH: - hc['epoch'] = read_tag(fid, pos).data - hc['epoch'].flags.writeable = False + hc["epoch"] = read_tag(fid, pos).data + hc["epoch"].flags.writeable = False elif kind == FIFF.FIFF_HPI_SLOPES: - hc['slopes'] = read_tag(fid, pos).data - hc['slopes'].flags.writeable = False + hc["slopes"] = 
read_tag(fid, pos).data + hc["slopes"].flags.writeable = False elif kind == FIFF.FIFF_HPI_CORR_COEFF: - hc['corr_coeff'] = read_tag(fid, pos).data - hc['corr_coeff'].flags.writeable = False + hc["corr_coeff"] = read_tag(fid, pos).data + hc["corr_coeff"].flags.writeable = False elif kind == FIFF.FIFF_HPI_COIL_FREQ: - hc['coil_freq'] = float(read_tag(fid, pos).data.item()) + hc["coil_freq"] = float(read_tag(fid, pos).data.item()) hcs.append(hc) - hm['hpi_coils'] = hcs + hm["hpi_coils"] = hcs hms.append(hm) - info['hpi_meas'] = hms + info["hpi_meas"] = hms del hms subject_info = dir_tree_find(meas_info, FIFF.FIFFB_SUBJECT) @@ -1685,47 +1924,49 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): if len(subject_info) == 1: subject_info = subject_info[0] si = dict() - for k in range(subject_info['nent']): - kind = subject_info['directory'][k].kind - pos = subject_info['directory'][k].pos + for k in range(subject_info["nent"]): + kind = subject_info["directory"][k].kind + pos = subject_info["directory"][k].pos if kind == FIFF.FIFF_SUBJ_ID: tag = read_tag(fid, pos) - si['id'] = int(tag.data.item()) + si["id"] = int(tag.data.item()) elif kind == FIFF.FIFF_SUBJ_HIS_ID: tag = read_tag(fid, pos) - si['his_id'] = str(tag.data) + si["his_id"] = str(tag.data) elif kind == FIFF.FIFF_SUBJ_LAST_NAME: tag = read_tag(fid, pos) - si['last_name'] = str(tag.data) + si["last_name"] = str(tag.data) elif kind == FIFF.FIFF_SUBJ_FIRST_NAME: tag = read_tag(fid, pos) - si['first_name'] = str(tag.data) + si["first_name"] = str(tag.data) elif kind == FIFF.FIFF_SUBJ_MIDDLE_NAME: tag = read_tag(fid, pos) - si['middle_name'] = str(tag.data) + si["middle_name"] = str(tag.data) elif kind == FIFF.FIFF_SUBJ_BIRTH_DAY: try: tag = read_tag(fid, pos) except OverflowError: - warn('Encountered an error while trying to read the ' - 'birthday from the input data. No birthday will be ' - 'set. Please check the integrity of the birthday ' - 'information in the input data.') + warn( + "Encountered an error while trying to read the " + "birthday from the input data. No birthday will be " + "set. Please check the integrity of the birthday " + "information in the input data." 
+ ) continue - si['birthday'] = tag.data + si["birthday"] = tag.data elif kind == FIFF.FIFF_SUBJ_SEX: tag = read_tag(fid, pos) - si['sex'] = int(tag.data.item()) + si["sex"] = int(tag.data.item()) elif kind == FIFF.FIFF_SUBJ_HAND: tag = read_tag(fid, pos) - si['hand'] = int(tag.data.item()) + si["hand"] = int(tag.data.item()) elif kind == FIFF.FIFF_SUBJ_WEIGHT: tag = read_tag(fid, pos) - si['weight'] = tag.data + si["weight"] = tag.data elif kind == FIFF.FIFF_SUBJ_HEIGHT: tag = read_tag(fid, pos) - si['height'] = tag.data - info['subject_info'] = si + si["height"] = tag.data + info["subject_info"] = si del si device_info = dir_tree_find(meas_info, FIFF.FIFFB_DEVICE) @@ -1733,22 +1974,22 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): if len(device_info) == 1: device_info = device_info[0] di = dict() - for k in range(device_info['nent']): - kind = device_info['directory'][k].kind - pos = device_info['directory'][k].pos + for k in range(device_info["nent"]): + kind = device_info["directory"][k].kind + pos = device_info["directory"][k].pos if kind == FIFF.FIFF_DEVICE_TYPE: tag = read_tag(fid, pos) - di['type'] = str(tag.data) + di["type"] = str(tag.data) elif kind == FIFF.FIFF_DEVICE_MODEL: tag = read_tag(fid, pos) - di['model'] = str(tag.data) + di["model"] = str(tag.data) elif kind == FIFF.FIFF_DEVICE_SERIAL: tag = read_tag(fid, pos) - di['serial'] = str(tag.data) + di["serial"] = str(tag.data) elif kind == FIFF.FIFF_DEVICE_SITE: tag = read_tag(fid, pos) - di['site'] = str(tag.data) - info['device_info'] = di + di["site"] = str(tag.data) + info["device_info"] = di del di helium_info = dir_tree_find(meas_info, FIFF.FIFFB_HELIUM) @@ -1756,22 +1997,22 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): if len(helium_info) == 1: helium_info = helium_info[0] hi = dict() - for k in range(helium_info['nent']): - kind = helium_info['directory'][k].kind - pos = helium_info['directory'][k].pos + for k in range(helium_info["nent"]): + kind = helium_info["directory"][k].kind + pos = helium_info["directory"][k].pos if kind == FIFF.FIFF_HE_LEVEL_RAW: tag = read_tag(fid, pos) - hi['he_level_raw'] = float(tag.data.item()) + hi["he_level_raw"] = float(tag.data.item()) elif kind == FIFF.FIFF_HELIUM_LEVEL: tag = read_tag(fid, pos) - hi['helium_level'] = float(tag.data.item()) + hi["helium_level"] = float(tag.data.item()) elif kind == FIFF.FIFF_ORIG_FILE_GUID: tag = read_tag(fid, pos) - hi['orig_file_guid'] = str(tag.data) + hi["orig_file_guid"] = str(tag.data) elif kind == FIFF.FIFF_MEAS_DATE: tag = read_tag(fid, pos) - hi['meas_date'] = tuple(int(t) for t in tag.data) - info['helium_info'] = hi + hi["meas_date"] = tuple(int(t) for t in tag.data) + info["helium_info"] = hi del hi hpi_subsystem = dir_tree_find(meas_info, FIFF.FIFFB_HPI_SUBSYSTEM) @@ -1779,90 +2020,91 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): if len(hpi_subsystem) == 1: hpi_subsystem = hpi_subsystem[0] hs = dict() - for k in range(hpi_subsystem['nent']): - kind = hpi_subsystem['directory'][k].kind - pos = hpi_subsystem['directory'][k].pos + for k in range(hpi_subsystem["nent"]): + kind = hpi_subsystem["directory"][k].kind + pos = hpi_subsystem["directory"][k].pos if kind == FIFF.FIFF_HPI_NCOIL: tag = read_tag(fid, pos) - hs['ncoil'] = int(tag.data.item()) + hs["ncoil"] = int(tag.data.item()) elif kind == FIFF.FIFF_EVENT_CHANNEL: tag = read_tag(fid, pos) - hs['event_channel'] = str(tag.data) + hs["event_channel"] = str(tag.data) hpi_coils = dir_tree_find(hpi_subsystem, FIFF.FIFFB_HPI_COIL) 
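# The reader code around this point follows a single pattern throughout
# read_meas_info(): locate a FIFF block with dir_tree_find(), then walk the
# block's directory and read the tag at each entry's position only for the
# kinds it understands. A minimal, self-contained sketch of that dispatch
# loop follows; TagEntry, KIND_NCHAN, KIND_SFREQ and read_fake_tag are
# hypothetical stand-ins for the real FIFF constants and readers, not MNE API.
from dataclasses import dataclass


@dataclass
class TagEntry:
    kind: int  # numeric tag kind, standing in for FIFF.FIFF_* constants
    pos: int  # byte offset of the tag in the file


KIND_NCHAN, KIND_SFREQ = 200, 201


def read_fake_tag(pos):
    # stand-in for read_tag(fid, pos); returns pre-cooked data by offset
    return {0: 306, 8: 1000.0}[pos]


def read_block(directory):
    """Walk a block's directory, dispatching on each entry's kind."""
    nchan = sfreq = None
    for entry in directory:
        if entry.kind == KIND_NCHAN:
            nchan = int(read_fake_tag(entry.pos))
        elif entry.kind == KIND_SFREQ:
            sfreq = float(read_fake_tag(entry.pos))
        # unknown kinds are silently skipped, as in the reader above
    return nchan, sfreq


assert read_block([TagEntry(200, 0), TagEntry(201, 8)]) == (306, 1000.0)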
hc = [] for coil in hpi_coils: this_coil = dict() - for j in range(coil['nent']): - kind = coil['directory'][j].kind - pos = coil['directory'][j].pos + for j in range(coil["nent"]): + kind = coil["directory"][j].kind + pos = coil["directory"][j].pos if kind == FIFF.FIFF_EVENT_BITS: tag = read_tag(fid, pos) - this_coil['event_bits'] = np.array(tag.data) + this_coil["event_bits"] = np.array(tag.data) hc.append(this_coil) - hs['hpi_coils'] = hc - info['hpi_subsystem'] = hs + hs["hpi_coils"] = hc + info["hpi_subsystem"] = hs # Read processing history - info['proc_history'] = _read_proc_history(fid, tree) + info["proc_history"] = _read_proc_history(fid, tree) # Make the most appropriate selection for the measurement id - if meas_info['parent_id'] is None: - if meas_info['id'] is None: - if meas['id'] is None: - if meas['parent_id'] is None: - info['meas_id'] = info['file_id'] + if meas_info["parent_id"] is None: + if meas_info["id"] is None: + if meas["id"] is None: + if meas["parent_id"] is None: + info["meas_id"] = info["file_id"] else: - info['meas_id'] = meas['parent_id'] + info["meas_id"] = meas["parent_id"] else: - info['meas_id'] = meas['id'] + info["meas_id"] = meas["id"] else: - info['meas_id'] = meas_info['id'] + info["meas_id"] = meas_info["id"] else: - info['meas_id'] = meas_info['parent_id'] - info['experimenter'] = experimenter - info['description'] = description - info['proj_id'] = proj_id - info['proj_name'] = proj_name + info["meas_id"] = meas_info["parent_id"] + info["experimenter"] = experimenter + info["description"] = description + info["proj_id"] = proj_id + info["proj_name"] = proj_name if meas_date is None: - meas_date = (info['meas_id']['secs'], info['meas_id']['usecs']) - info['meas_date'] = _ensure_meas_date_none_or_dt(meas_date) - info['utc_offset'] = utc_offset + meas_date = (info["meas_id"]["secs"], info["meas_id"]["usecs"]) + info["meas_date"] = _ensure_meas_date_none_or_dt(meas_date) + info["utc_offset"] = utc_offset - info['sfreq'] = sfreq - info['highpass'] = highpass if highpass is not None else 0. 
- info['lowpass'] = lowpass if lowpass is not None else info['sfreq'] / 2.0 - info['line_freq'] = line_freq - info['gantry_angle'] = gantry_angle + info["sfreq"] = sfreq + info["highpass"] = highpass if highpass is not None else 0.0 + info["lowpass"] = lowpass if lowpass is not None else info["sfreq"] / 2.0 + info["line_freq"] = line_freq + info["gantry_angle"] = gantry_angle # Add the channel information and make a list of channel names # for convenience - info['chs'] = chs + info["chs"] = chs # # Add the coordinate transformations # - info['dev_head_t'] = dev_head_t - info['ctf_head_t'] = ctf_head_t - info['dev_ctf_t'] = dev_ctf_t + info["dev_head_t"] = dev_head_t + info["ctf_head_t"] = ctf_head_t + info["dev_ctf_t"] = dev_ctf_t if dev_head_t is not None and ctf_head_t is not None and dev_ctf_t is None: from ..transforms import Transform - head_ctf_trans = np.linalg.inv(ctf_head_t['trans']) - dev_ctf_trans = np.dot(head_ctf_trans, info['dev_head_t']['trans']) - info['dev_ctf_t'] = Transform('meg', 'ctf_head', dev_ctf_trans) + + head_ctf_trans = np.linalg.inv(ctf_head_t["trans"]) + dev_ctf_trans = np.dot(head_ctf_trans, info["dev_head_t"]["trans"]) + info["dev_ctf_t"] = Transform("meg", "ctf_head", dev_ctf_trans) # All kinds of auxiliary stuff - info['dig'] = _format_dig_points(dig) - info['bads'] = bads + info["dig"] = _format_dig_points(dig) + info["bads"] = bads info._update_redundant() if clean_bads: - info['bads'] = [b for b in bads if b in info['ch_names']] - info['projs'] = projs - info['comps'] = comps - info['acq_pars'] = acq_pars - info['acq_stim'] = acq_stim - info['custom_ref_applied'] = custom_ref_applied - info['xplotter_layout'] = xplotter_layout - info['kit_system_id'] = kit_system_id + info["bads"] = [b for b in bads if b in info["ch_names"]] + info["projs"] = projs + info["comps"] = comps + info["acq_pars"] = acq_pars + info["acq_stim"] = acq_stim + info["custom_ref_applied"] = custom_ref_applied + info["xplotter_layout"] = xplotter_layout + info["kit_system_id"] = kit_system_id info._check_consistency() info._unlocked = False return info, meas @@ -1872,27 +2114,27 @@ def _read_extended_ch_info(chs, parent, fid): ch_infos = dir_tree_find(parent, FIFF.FIFFB_CH_INFO) if len(ch_infos) == 0: return - _check_option('length of channel infos', len(ch_infos), [len(chs)]) - logger.info(' Reading extended channel information') + _check_option("length of channel infos", len(ch_infos), [len(chs)]) + logger.info(" Reading extended channel information") # Here we assume that ``remap`` is in the same order as the channels # themselves, which is hopefully safe enough. 
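# read_meas_info() above derives a missing device->ctf_head transform by
# inverting ctf_head->head and composing it with device->head, i.e.
# dev->ctf_head = inv(ctf_head->head) @ (dev->head). A short numpy sketch of
# that composition; the two translation-only 4x4 transforms below are
# hypothetical example values, not real data.
import numpy as np


def homogeneous(rot, trans):
    """Build a 4x4 homogeneous transform from a 3x3 rotation and translation."""
    t = np.eye(4)
    t[:3, :3] = rot
    t[:3, 3] = trans
    return t


dev_head = homogeneous(np.eye(3), [0.0, 0.0, 0.04])  # device -> head
ctf_head = homogeneous(np.eye(3), [0.0, 0.02, 0.0])  # ctf_head -> head

head_ctf = np.linalg.inv(ctf_head)  # head -> ctf_head
dev_ctf = head_ctf @ dev_head  # device -> ctf_head, as composed above

assert np.allclose(dev_ctf[:3, 3], [0.0, -0.02, 0.04])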
ch_names_mapping = dict() for new, ch in zip(ch_infos, chs): - for k in range(new['nent']): - kind = new['directory'][k].kind + for k in range(new["nent"]): + kind = new["directory"][k].kind try: key, cast = _CH_READ_MAP[kind] except KeyError: # This shouldn't happen if we're up to date with the FIFF # spec - warn(f'Discarding extra channel information kind {kind}') + warn(f"Discarding extra channel information kind {kind}") continue assert key in ch - data = read_tag(fid, new['directory'][k].pos).data + data = read_tag(fid, new["directory"][k].pos).data if data is not None: data = cast(data) - if key == 'ch_name': + if key == "ch_name": ch_names_mapping[ch[key]] = data ch[key] = data _update_ch_info_named(ch) @@ -1905,8 +2147,8 @@ def _rename_comps(comps, ch_names_mapping): if not (comps and ch_names_mapping): return for comp in comps: - data = comp['data'] - for key in ('row_names', 'col_names'): + data = comp["data"] + for key in ("row_names", "col_names"): data[key][:] = _rename_list(data[key], ch_names_mapping) @@ -1918,38 +2160,53 @@ def _ensure_meas_date_none_or_dt(meas_date): return meas_date -def _check_dates(info, prepend_error=''): +def _check_dates(info, prepend_error=""): """Check dates before writing as fif files. It's needed because of the limited integer precision of the fif standard. """ - for key in ('file_id', 'meas_id'): + for key in ("file_id", "meas_id"): value = info.get(key) if value is not None: - assert 'msecs' not in value - for key_2 in ('secs', 'usecs'): - if (value[key_2] < np.iinfo('>i4').min or - value[key_2] > np.iinfo('>i4').max): - raise RuntimeError('%sinfo[%s][%s] must be between ' - '"%r" and "%r", got "%r"' - % (prepend_error, key, key_2, - np.iinfo('>i4').min, - np.iinfo('>i4').max, - value[key_2]),) - - meas_date = info.get('meas_date') + assert "msecs" not in value + for key_2 in ("secs", "usecs"): + if ( value[key_2] < np.iinfo(">i4").min or value[key_2] > np.iinfo(">i4").max ): + raise RuntimeError( "%sinfo[%s][%s] must be between " '"%r" and "%r", got "%r"' % ( prepend_error, key, key_2, np.iinfo(">i4").min, np.iinfo(">i4").max, value[key_2], ), ) + + meas_date = info.get("meas_date") if meas_date is None: return meas_date_stamp = _dt_to_stamp(meas_date) - if (meas_date_stamp[0] < np.iinfo('>i4').min or - meas_date_stamp[0] > np.iinfo('>i4').max): + if ( meas_date_stamp[0] < np.iinfo(">i4").min or meas_date_stamp[0] > np.iinfo(">i4").max ): raise RuntimeError( '%sinfo["meas_date"] seconds must be between "%r" ' 'and "%r", got "%r"' - % (prepend_error, (np.iinfo('>i4').min, 0), - (np.iinfo('>i4').max, 0), meas_date_stamp[0],)) + % ( prepend_error, (np.iinfo(">i4").min, 0), (np.iinfo(">i4").max, 0), meas_date_stamp[0], ) ) @fill_doc @@ -1979,220 +2236,211 @@ def write_meas_info(fid, info, data_type=None, reset_range=True): start_block(fid, FIFF.FIFFB_MEAS_INFO) # Add measurement id - if info['meas_id'] is not None: - write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info['meas_id']) + if info["meas_id"] is not None: + write_id(fid, FIFF.FIFF_PARENT_BLOCK_ID, info["meas_id"]) - for event in info['events']: + for event in info["events"]: start_block(fid, FIFF.FIFFB_EVENTS) - if event.get('channels') is not None: - write_int(fid, FIFF.FIFF_EVENT_CHANNELS, event['channels']) - if event.get('list') is not None: - write_int(fid, FIFF.FIFF_EVENT_LIST, event['list']) + if event.get("channels") is not None: + write_int(fid, FIFF.FIFF_EVENT_CHANNELS, event["channels"]) + if event.get("list") is not None: + write_int(fid, 
FIFF.FIFF_EVENT_LIST, event["list"]) end_block(fid, FIFF.FIFFB_EVENTS) # HPI Result - for hpi_result in info['hpi_results']: + for hpi_result in info["hpi_results"]: start_block(fid, FIFF.FIFFB_HPI_RESULT) - write_dig_points(fid, hpi_result['dig_points']) - if 'order' in hpi_result: - write_int(fid, FIFF.FIFF_HPI_DIGITIZATION_ORDER, - hpi_result['order']) - if 'used' in hpi_result: - write_int(fid, FIFF.FIFF_HPI_COILS_USED, hpi_result['used']) - if 'moments' in hpi_result: - write_float_matrix(fid, FIFF.FIFF_HPI_COIL_MOMENTS, - hpi_result['moments']) - if 'goodness' in hpi_result: - write_float(fid, FIFF.FIFF_HPI_FIT_GOODNESS, - hpi_result['goodness']) - if 'good_limit' in hpi_result: - write_float(fid, FIFF.FIFF_HPI_FIT_GOOD_LIMIT, - hpi_result['good_limit']) - if 'dist_limit' in hpi_result: - write_float(fid, FIFF.FIFF_HPI_FIT_DIST_LIMIT, - hpi_result['dist_limit']) - if 'accept' in hpi_result: - write_int(fid, FIFF.FIFF_HPI_FIT_ACCEPT, hpi_result['accept']) - if 'coord_trans' in hpi_result: - write_coord_trans(fid, hpi_result['coord_trans']) + write_dig_points(fid, hpi_result["dig_points"]) + if "order" in hpi_result: + write_int(fid, FIFF.FIFF_HPI_DIGITIZATION_ORDER, hpi_result["order"]) + if "used" in hpi_result: + write_int(fid, FIFF.FIFF_HPI_COILS_USED, hpi_result["used"]) + if "moments" in hpi_result: + write_float_matrix(fid, FIFF.FIFF_HPI_COIL_MOMENTS, hpi_result["moments"]) + if "goodness" in hpi_result: + write_float(fid, FIFF.FIFF_HPI_FIT_GOODNESS, hpi_result["goodness"]) + if "good_limit" in hpi_result: + write_float(fid, FIFF.FIFF_HPI_FIT_GOOD_LIMIT, hpi_result["good_limit"]) + if "dist_limit" in hpi_result: + write_float(fid, FIFF.FIFF_HPI_FIT_DIST_LIMIT, hpi_result["dist_limit"]) + if "accept" in hpi_result: + write_int(fid, FIFF.FIFF_HPI_FIT_ACCEPT, hpi_result["accept"]) + if "coord_trans" in hpi_result: + write_coord_trans(fid, hpi_result["coord_trans"]) end_block(fid, FIFF.FIFFB_HPI_RESULT) # HPI Measurement - for hpi_meas in info['hpi_meas']: + for hpi_meas in info["hpi_meas"]: start_block(fid, FIFF.FIFFB_HPI_MEAS) - if hpi_meas.get('creator') is not None: - write_string(fid, FIFF.FIFF_CREATOR, hpi_meas['creator']) - if hpi_meas.get('sfreq') is not None: - write_float(fid, FIFF.FIFF_SFREQ, hpi_meas['sfreq']) - if hpi_meas.get('nchan') is not None: - write_int(fid, FIFF.FIFF_NCHAN, hpi_meas['nchan']) - if hpi_meas.get('nave') is not None: - write_int(fid, FIFF.FIFF_NAVE, hpi_meas['nave']) - if hpi_meas.get('ncoil') is not None: - write_int(fid, FIFF.FIFF_HPI_NCOIL, hpi_meas['ncoil']) - if hpi_meas.get('first_samp') is not None: - write_int(fid, FIFF.FIFF_FIRST_SAMPLE, hpi_meas['first_samp']) - if hpi_meas.get('last_samp') is not None: - write_int(fid, FIFF.FIFF_LAST_SAMPLE, hpi_meas['last_samp']) - for hpi_coil in hpi_meas['hpi_coils']: + if hpi_meas.get("creator") is not None: + write_string(fid, FIFF.FIFF_CREATOR, hpi_meas["creator"]) + if hpi_meas.get("sfreq") is not None: + write_float(fid, FIFF.FIFF_SFREQ, hpi_meas["sfreq"]) + if hpi_meas.get("nchan") is not None: + write_int(fid, FIFF.FIFF_NCHAN, hpi_meas["nchan"]) + if hpi_meas.get("nave") is not None: + write_int(fid, FIFF.FIFF_NAVE, hpi_meas["nave"]) + if hpi_meas.get("ncoil") is not None: + write_int(fid, FIFF.FIFF_HPI_NCOIL, hpi_meas["ncoil"]) + if hpi_meas.get("first_samp") is not None: + write_int(fid, FIFF.FIFF_FIRST_SAMPLE, hpi_meas["first_samp"]) + if hpi_meas.get("last_samp") is not None: + write_int(fid, FIFF.FIFF_LAST_SAMPLE, hpi_meas["last_samp"]) + for hpi_coil in hpi_meas["hpi_coils"]: 
start_block(fid, FIFF.FIFFB_HPI_COIL) - if hpi_coil.get('number') is not None: - write_int(fid, FIFF.FIFF_HPI_COIL_NO, hpi_coil['number']) - if hpi_coil.get('epoch') is not None: - write_float_matrix(fid, FIFF.FIFF_EPOCH, hpi_coil['epoch']) - if hpi_coil.get('slopes') is not None: - write_float(fid, FIFF.FIFF_HPI_SLOPES, hpi_coil['slopes']) - if hpi_coil.get('corr_coeff') is not None: - write_float(fid, FIFF.FIFF_HPI_CORR_COEFF, - hpi_coil['corr_coeff']) - if hpi_coil.get('coil_freq') is not None: - write_float(fid, FIFF.FIFF_HPI_COIL_FREQ, - hpi_coil['coil_freq']) + if hpi_coil.get("number") is not None: + write_int(fid, FIFF.FIFF_HPI_COIL_NO, hpi_coil["number"]) + if hpi_coil.get("epoch") is not None: + write_float_matrix(fid, FIFF.FIFF_EPOCH, hpi_coil["epoch"]) + if hpi_coil.get("slopes") is not None: + write_float(fid, FIFF.FIFF_HPI_SLOPES, hpi_coil["slopes"]) + if hpi_coil.get("corr_coeff") is not None: + write_float(fid, FIFF.FIFF_HPI_CORR_COEFF, hpi_coil["corr_coeff"]) + if hpi_coil.get("coil_freq") is not None: + write_float(fid, FIFF.FIFF_HPI_COIL_FREQ, hpi_coil["coil_freq"]) end_block(fid, FIFF.FIFFB_HPI_COIL) end_block(fid, FIFF.FIFFB_HPI_MEAS) # Polhemus data - write_dig_points(fid, info['dig'], block=True) + write_dig_points(fid, info["dig"], block=True) # megacq parameters - if info['acq_pars'] is not None or info['acq_stim'] is not None: + if info["acq_pars"] is not None or info["acq_stim"] is not None: start_block(fid, FIFF.FIFFB_DACQ_PARS) - if info['acq_pars'] is not None: - write_string(fid, FIFF.FIFF_DACQ_PARS, info['acq_pars']) + if info["acq_pars"] is not None: + write_string(fid, FIFF.FIFF_DACQ_PARS, info["acq_pars"]) - if info['acq_stim'] is not None: - write_string(fid, FIFF.FIFF_DACQ_STIM, info['acq_stim']) + if info["acq_stim"] is not None: + write_string(fid, FIFF.FIFF_DACQ_STIM, info["acq_stim"]) end_block(fid, FIFF.FIFFB_DACQ_PARS) # Coordinate transformations if the HPI result block was not there - if info['dev_head_t'] is not None: - write_coord_trans(fid, info['dev_head_t']) + if info["dev_head_t"] is not None: + write_coord_trans(fid, info["dev_head_t"]) - if info['ctf_head_t'] is not None: - write_coord_trans(fid, info['ctf_head_t']) + if info["ctf_head_t"] is not None: + write_coord_trans(fid, info["ctf_head_t"]) - if info['dev_ctf_t'] is not None: - write_coord_trans(fid, info['dev_ctf_t']) + if info["dev_ctf_t"] is not None: + write_coord_trans(fid, info["dev_ctf_t"]) # Projectors - ch_names_mapping = _make_ch_names_mapping(info['chs']) - _write_proj(fid, info['projs'], ch_names_mapping=ch_names_mapping) + ch_names_mapping = _make_ch_names_mapping(info["chs"]) + _write_proj(fid, info["projs"], ch_names_mapping=ch_names_mapping) # Bad channels - _write_bad_channels(fid, info['bads'], ch_names_mapping=ch_names_mapping) + _write_bad_channels(fid, info["bads"], ch_names_mapping=ch_names_mapping) # General - if info.get('experimenter') is not None: - write_string(fid, FIFF.FIFF_EXPERIMENTER, info['experimenter']) - if info.get('description') is not None: - write_string(fid, FIFF.FIFF_DESCRIPTION, info['description']) - if info.get('proj_id') is not None: - write_int(fid, FIFF.FIFF_PROJ_ID, info['proj_id']) - if info.get('proj_name') is not None: - write_string(fid, FIFF.FIFF_PROJ_NAME, info['proj_name']) - if info.get('meas_date') is not None: - write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(info['meas_date'])) - if info.get('utc_offset') is not None: - write_string(fid, FIFF.FIFF_UTC_OFFSET, info['utc_offset']) - write_int(fid, FIFF.FIFF_NCHAN, 
info['nchan']) - write_float(fid, FIFF.FIFF_SFREQ, info['sfreq']) - if info['lowpass'] is not None: - write_float(fid, FIFF.FIFF_LOWPASS, info['lowpass']) - if info['highpass'] is not None: - write_float(fid, FIFF.FIFF_HIGHPASS, info['highpass']) - if info.get('line_freq') is not None: - write_float(fid, FIFF.FIFF_LINE_FREQ, info['line_freq']) - if info.get('gantry_angle') is not None: - write_float(fid, FIFF.FIFF_GANTRY_ANGLE, info['gantry_angle']) + if info.get("experimenter") is not None: + write_string(fid, FIFF.FIFF_EXPERIMENTER, info["experimenter"]) + if info.get("description") is not None: + write_string(fid, FIFF.FIFF_DESCRIPTION, info["description"]) + if info.get("proj_id") is not None: + write_int(fid, FIFF.FIFF_PROJ_ID, info["proj_id"]) + if info.get("proj_name") is not None: + write_string(fid, FIFF.FIFF_PROJ_NAME, info["proj_name"]) + if info.get("meas_date") is not None: + write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(info["meas_date"])) + if info.get("utc_offset") is not None: + write_string(fid, FIFF.FIFF_UTC_OFFSET, info["utc_offset"]) + write_int(fid, FIFF.FIFF_NCHAN, info["nchan"]) + write_float(fid, FIFF.FIFF_SFREQ, info["sfreq"]) + if info["lowpass"] is not None: + write_float(fid, FIFF.FIFF_LOWPASS, info["lowpass"]) + if info["highpass"] is not None: + write_float(fid, FIFF.FIFF_HIGHPASS, info["highpass"]) + if info.get("line_freq") is not None: + write_float(fid, FIFF.FIFF_LINE_FREQ, info["line_freq"]) + if info.get("gantry_angle") is not None: + write_float(fid, FIFF.FIFF_GANTRY_ANGLE, info["gantry_angle"]) if data_type is not None: write_int(fid, FIFF.FIFF_DATA_PACK, data_type) - if info.get('custom_ref_applied'): - write_int(fid, FIFF.FIFF_MNE_CUSTOM_REF, info['custom_ref_applied']) - if info.get('xplotter_layout'): - write_string(fid, FIFF.FIFF_XPLOTTER_LAYOUT, info['xplotter_layout']) + if info.get("custom_ref_applied"): + write_int(fid, FIFF.FIFF_MNE_CUSTOM_REF, info["custom_ref_applied"]) + if info.get("xplotter_layout"): + write_string(fid, FIFF.FIFF_XPLOTTER_LAYOUT, info["xplotter_layout"]) # Channel information - _write_ch_infos(fid, info['chs'], reset_range, ch_names_mapping) + _write_ch_infos(fid, info["chs"], reset_range, ch_names_mapping) # Subject information - if info.get('subject_info') is not None: + if info.get("subject_info") is not None: start_block(fid, FIFF.FIFFB_SUBJECT) - si = info['subject_info'] - if si.get('id') is not None: - write_int(fid, FIFF.FIFF_SUBJ_ID, si['id']) - if si.get('his_id') is not None: - write_string(fid, FIFF.FIFF_SUBJ_HIS_ID, si['his_id']) - if si.get('last_name') is not None: - write_string(fid, FIFF.FIFF_SUBJ_LAST_NAME, si['last_name']) - if si.get('first_name') is not None: - write_string(fid, FIFF.FIFF_SUBJ_FIRST_NAME, si['first_name']) - if si.get('middle_name') is not None: - write_string(fid, FIFF.FIFF_SUBJ_MIDDLE_NAME, si['middle_name']) - if si.get('birthday') is not None: - write_julian(fid, FIFF.FIFF_SUBJ_BIRTH_DAY, si['birthday']) - if si.get('sex') is not None: - write_int(fid, FIFF.FIFF_SUBJ_SEX, si['sex']) - if si.get('hand') is not None: - write_int(fid, FIFF.FIFF_SUBJ_HAND, si['hand']) - if si.get('weight') is not None: - write_float(fid, FIFF.FIFF_SUBJ_WEIGHT, si['weight']) - if si.get('height') is not None: - write_float(fid, FIFF.FIFF_SUBJ_HEIGHT, si['height']) + si = info["subject_info"] + if si.get("id") is not None: + write_int(fid, FIFF.FIFF_SUBJ_ID, si["id"]) + if si.get("his_id") is not None: + write_string(fid, FIFF.FIFF_SUBJ_HIS_ID, si["his_id"]) + if si.get("last_name") is not None: + 
write_string(fid, FIFF.FIFF_SUBJ_LAST_NAME, si["last_name"]) + if si.get("first_name") is not None: + write_string(fid, FIFF.FIFF_SUBJ_FIRST_NAME, si["first_name"]) + if si.get("middle_name") is not None: + write_string(fid, FIFF.FIFF_SUBJ_MIDDLE_NAME, si["middle_name"]) + if si.get("birthday") is not None: + write_julian(fid, FIFF.FIFF_SUBJ_BIRTH_DAY, si["birthday"]) + if si.get("sex") is not None: + write_int(fid, FIFF.FIFF_SUBJ_SEX, si["sex"]) + if si.get("hand") is not None: + write_int(fid, FIFF.FIFF_SUBJ_HAND, si["hand"]) + if si.get("weight") is not None: + write_float(fid, FIFF.FIFF_SUBJ_WEIGHT, si["weight"]) + if si.get("height") is not None: + write_float(fid, FIFF.FIFF_SUBJ_HEIGHT, si["height"]) end_block(fid, FIFF.FIFFB_SUBJECT) del si - if info.get('device_info') is not None: + if info.get("device_info") is not None: start_block(fid, FIFF.FIFFB_DEVICE) - di = info['device_info'] - write_string(fid, FIFF.FIFF_DEVICE_TYPE, di['type']) - for key in ('model', 'serial', 'site'): + di = info["device_info"] + write_string(fid, FIFF.FIFF_DEVICE_TYPE, di["type"]) + for key in ("model", "serial", "site"): if di.get(key) is not None: - write_string(fid, getattr(FIFF, 'FIFF_DEVICE_' + key.upper()), - di[key]) + write_string(fid, getattr(FIFF, "FIFF_DEVICE_" + key.upper()), di[key]) end_block(fid, FIFF.FIFFB_DEVICE) del di - if info.get('helium_info') is not None: + if info.get("helium_info") is not None: start_block(fid, FIFF.FIFFB_HELIUM) - hi = info['helium_info'] - if hi.get('he_level_raw') is not None: - write_float(fid, FIFF.FIFF_HE_LEVEL_RAW, hi['he_level_raw']) - if hi.get('helium_level') is not None: - write_float(fid, FIFF.FIFF_HELIUM_LEVEL, hi['helium_level']) - if hi.get('orig_file_guid') is not None: - write_string(fid, FIFF.FIFF_ORIG_FILE_GUID, hi['orig_file_guid']) - write_int(fid, FIFF.FIFF_MEAS_DATE, hi['meas_date']) + hi = info["helium_info"] + if hi.get("he_level_raw") is not None: + write_float(fid, FIFF.FIFF_HE_LEVEL_RAW, hi["he_level_raw"]) + if hi.get("helium_level") is not None: + write_float(fid, FIFF.FIFF_HELIUM_LEVEL, hi["helium_level"]) + if hi.get("orig_file_guid") is not None: + write_string(fid, FIFF.FIFF_ORIG_FILE_GUID, hi["orig_file_guid"]) + write_int(fid, FIFF.FIFF_MEAS_DATE, hi["meas_date"]) end_block(fid, FIFF.FIFFB_HELIUM) del hi - if info.get('hpi_subsystem') is not None: - hs = info['hpi_subsystem'] + if info.get("hpi_subsystem") is not None: + hs = info["hpi_subsystem"] start_block(fid, FIFF.FIFFB_HPI_SUBSYSTEM) - if hs.get('ncoil') is not None: - write_int(fid, FIFF.FIFF_HPI_NCOIL, hs['ncoil']) - if hs.get('event_channel') is not None: - write_string(fid, FIFF.FIFF_EVENT_CHANNEL, hs['event_channel']) - if hs.get('hpi_coils') is not None: - for coil in hs['hpi_coils']: + if hs.get("ncoil") is not None: + write_int(fid, FIFF.FIFF_HPI_NCOIL, hs["ncoil"]) + if hs.get("event_channel") is not None: + write_string(fid, FIFF.FIFF_EVENT_CHANNEL, hs["event_channel"]) + if hs.get("hpi_coils") is not None: + for coil in hs["hpi_coils"]: start_block(fid, FIFF.FIFFB_HPI_COIL) - if coil.get('event_bits') is not None: - write_int(fid, FIFF.FIFF_EVENT_BITS, - coil['event_bits']) + if coil.get("event_bits") is not None: + write_int(fid, FIFF.FIFF_EVENT_BITS, coil["event_bits"]) end_block(fid, FIFF.FIFFB_HPI_COIL) end_block(fid, FIFF.FIFFB_HPI_SUBSYSTEM) del hs # CTF compensation info - comps = info['comps'] + comps = info["comps"] if ch_names_mapping: comps = deepcopy(comps) _rename_comps(comps, ch_names_mapping) write_ctf_comp(fid, comps) # KIT system ID - if 
info.get('kit_system_id') is not None: - write_int(fid, FIFF.FIFF_MNE_KIT_SYSTEM_ID, info['kit_system_id']) + if info.get("kit_system_id") is not None: + write_int(fid, FIFF.FIFF_MNE_KIT_SYSTEM_ID, info["kit_system_id"]) end_block(fid, FIFF.FIFFB_MEAS_INFO) @@ -2232,9 +2480,10 @@ def _merge_info_values(infos, key, verbose=None): Does special things for "projs", "bads", and "meas_date". """ values = [d[key] for d in infos] - msg = ("Don't know how to merge '%s'. Make sure values are " - "compatible, got types:\n %s" - % (key, [type(v) for v in values])) + msg = ( + "Don't know how to merge '%s'. Make sure values are " + "compatible, got types:\n %s" % (key, [type(v) for v in values]) + ) def _flatten(lists): return [item for sublist in lists for item in sublist] @@ -2249,9 +2498,9 @@ def _where_isinstance(values, kind): # list if _check_isinstance(values, list, all): lists = (d[key] for d in infos) - if key == 'projs': + if key == "projs": return _uniquify_projs(_flatten(lists)) - elif key == 'bads': + elif key == "bads": return sorted(set(_flatten(lists))) else: return _flatten(lists) @@ -2264,7 +2513,7 @@ def _where_isinstance(values, kind): return _flatten(lists) # dict elif _check_isinstance(values, dict, all): - is_qual = all(object_diff(values[0], v) == '' for v in values[1:]) + is_qual = all(object_diff(values[0], v) == "" for v in values[1:]) if is_qual: return values[0] else: @@ -2276,14 +2525,16 @@ def _where_isinstance(values, kind): elif len(idx) > 1: raise RuntimeError(msg) # ndarray - elif _check_isinstance(values, np.ndarray, all) or \ - _check_isinstance(values, tuple, all): + elif _check_isinstance(values, np.ndarray, all) or _check_isinstance( + values, tuple, all + ): is_qual = all(np.array_equal(values[0], x) for x in values[1:]) if is_qual: return values[0] - elif key == 'meas_date': - logger.info('Found multiple entries for %s. ' - 'Setting value to `None`' % key) + elif key == "meas_date": + logger.info( + "Found multiple entries for %s. " "Setting value to `None`" % key + ) return None else: raise RuntimeError(msg) @@ -2299,12 +2550,10 @@ def _where_isinstance(values, kind): if len(unique_values) == 1: return list(values)[0] elif isinstance(list(unique_values)[0], BytesIO): - logger.info('Found multiple StringIO instances. ' - 'Setting value to `None`') + logger.info("Found multiple StringIO instances. " "Setting value to `None`") return None elif isinstance(list(unique_values)[0], str): - logger.info('Found multiple filenames. ' - 'Setting value to `None`') + logger.info("Found multiple filenames. 
" "Setting value to `None`") return None else: raise RuntimeError(msg) @@ -2345,77 +2594,100 @@ def _merge_info(infos, force_update_to_first=False, verbose=None): _force_update_info(infos[0], infos[1:]) info = Info() info._unlocked = True - info['chs'] = [] + info["chs"] = [] for this_info in infos: - info['chs'].extend(this_info['chs']) + info["chs"].extend(this_info["chs"]) info._update_redundant() - duplicates = {ch for ch in info['ch_names'] - if info['ch_names'].count(ch) > 1} + duplicates = {ch for ch in info["ch_names"] if info["ch_names"].count(ch) > 1} if len(duplicates) > 0: - msg = ("The following channels are present in more than one input " - "measurement info objects: %s" % list(duplicates)) + msg = ( + "The following channels are present in more than one input " + "measurement info objects: %s" % list(duplicates) + ) raise ValueError(msg) - transforms = ['ctf_head_t', 'dev_head_t', 'dev_ctf_t'] + transforms = ["ctf_head_t", "dev_head_t", "dev_ctf_t"] for trans_name in transforms: trans = [i[trans_name] for i in infos if i[trans_name]] if len(trans) == 0: info[trans_name] = None elif len(trans) == 1: info[trans_name] = trans[0] - elif all(np.all(trans[0]['trans'] == x['trans']) and - trans[0]['from'] == x['from'] and - trans[0]['to'] == x['to'] - for x in trans[1:]): + elif all( + np.all(trans[0]["trans"] == x["trans"]) + and trans[0]["from"] == x["from"] + and trans[0]["to"] == x["to"] + for x in trans[1:] + ): info[trans_name] = trans[0] else: - msg = ("Measurement infos provide mutually inconsistent %s" % - trans_name) + msg = "Measurement infos provide mutually inconsistent %s" % trans_name raise ValueError(msg) # KIT system-IDs - kit_sys_ids = [i['kit_system_id'] for i in infos if i['kit_system_id']] + kit_sys_ids = [i["kit_system_id"] for i in infos if i["kit_system_id"]] if len(kit_sys_ids) == 0: - info['kit_system_id'] = None + info["kit_system_id"] = None elif len(set(kit_sys_ids)) == 1: - info['kit_system_id'] = kit_sys_ids[0] + info["kit_system_id"] = kit_sys_ids[0] else: raise ValueError("Trying to merge channels from different KIT systems") # hpi infos and digitization data: - fields = ['hpi_results', 'hpi_meas', 'dig'] + fields = ["hpi_results", "hpi_meas", "dig"] for k in fields: values = [i[k] for i in infos if i[k]] if len(values) == 0: info[k] = [] elif len(values) == 1: info[k] = values[0] - elif all(object_diff(values[0], v) == '' for v in values[1:]): + elif all(object_diff(values[0], v) == "" for v in values[1:]): info[k] = values[0] else: - msg = ("Measurement infos are inconsistent for %s" % k) + msg = "Measurement infos are inconsistent for %s" % k raise ValueError(msg) # other fields - other_fields = ['acq_pars', 'acq_stim', 'bads', - 'comps', 'custom_ref_applied', 'description', - 'experimenter', 'file_id', 'highpass', 'utc_offset', - 'hpi_subsystem', 'events', 'device_info', 'helium_info', - 'line_freq', 'lowpass', 'meas_id', - 'proj_id', 'proj_name', 'projs', 'sfreq', 'gantry_angle', - 'subject_info', 'sfreq', 'xplotter_layout', 'proc_history'] + other_fields = [ + "acq_pars", + "acq_stim", + "bads", + "comps", + "custom_ref_applied", + "description", + "experimenter", + "file_id", + "highpass", + "utc_offset", + "hpi_subsystem", + "events", + "device_info", + "helium_info", + "line_freq", + "lowpass", + "meas_id", + "proj_id", + "proj_name", + "projs", + "sfreq", + "gantry_angle", + "subject_info", + "sfreq", + "xplotter_layout", + "proc_history", + ] for k in other_fields: info[k] = _merge_info_values(infos, k) - info['meas_date'] = 
infos[0]['meas_date'] + info["meas_date"] = infos[0]["meas_date"] info._unlocked = False return info @verbose -def create_info(ch_names, sfreq, ch_types='misc', verbose=None): +def create_info(ch_names, sfreq, ch_types="misc", verbose=None): """Create a basic Info instance suitable for use with create_raw. Parameters @@ -2463,41 +2735,50 @@ def create_info(ch_names, sfreq, ch_types='misc', verbose=None): pass else: ch_names = list(np.arange(ch_names).astype(str)) - _validate_type(ch_names, (list, tuple), "ch_names", - ("list, tuple, or int")) + _validate_type(ch_names, (list, tuple), "ch_names", ("list, tuple, or int")) sfreq = float(sfreq) if sfreq <= 0: - raise ValueError('sfreq must be positive') + raise ValueError("sfreq must be positive") nchan = len(ch_names) if isinstance(ch_types, str): ch_types = [ch_types] * nchan ch_types = np.atleast_1d(np.array(ch_types, np.str_)) if ch_types.ndim != 1 or len(ch_types) != nchan: - raise ValueError('ch_types and ch_names must be the same length ' - '(%s != %s) for ch_types=%s' - % (len(ch_types), nchan, ch_types)) + raise ValueError( + "ch_types and ch_names must be the same length " + "(%s != %s) for ch_types=%s" % (len(ch_types), nchan, ch_types) + ) info = _empty_info(sfreq) ch_types_dict = get_channel_type_constants(include_defaults=True) for ci, (ch_name, ch_type) in enumerate(zip(ch_names, ch_types)): - _validate_type(ch_name, 'str', "each entry in ch_names") - _validate_type(ch_type, 'str', "each entry in ch_types") + _validate_type(ch_name, "str", "each entry in ch_names") + _validate_type(ch_type, "str", "each entry in ch_types") if ch_type not in ch_types_dict: - raise KeyError(f'kind must be one of {list(ch_types_dict)}, ' - f'not {ch_type}') + raise KeyError( + f"kind must be one of {list(ch_types_dict)}, " f"not {ch_type}" + ) this_ch_dict = ch_types_dict[ch_type] - kind = this_ch_dict['kind'] + kind = this_ch_dict["kind"] # handle chpi, where kind is a *list* of FIFF constants: kind = kind[0] if isinstance(kind, (list, tuple)) else kind # mirror what tag.py does here coord_frame = _ch_coord_dict.get(kind, FIFF.FIFFV_COORD_UNKNOWN) - coil_type = this_ch_dict.get('coil_type', FIFF.FIFFV_COIL_NONE) - unit = this_ch_dict.get('unit', FIFF.FIFF_UNIT_NONE) - chan_info = dict(loc=np.full(12, np.nan), - unit_mul=FIFF.FIFF_UNITM_NONE, range=1., cal=1., - kind=kind, coil_type=coil_type, unit=unit, - coord_frame=coord_frame, ch_name=str(ch_name), - scanno=ci + 1, logno=ci + 1) - info['chs'].append(chan_info) + coil_type = this_ch_dict.get("coil_type", FIFF.FIFFV_COIL_NONE) + unit = this_ch_dict.get("unit", FIFF.FIFF_UNIT_NONE) + chan_info = dict( + loc=np.full(12, np.nan), + unit_mul=FIFF.FIFF_UNITM_NONE, + range=1.0, + cal=1.0, + kind=kind, + coil_type=coil_type, + unit=unit, + coord_frame=coord_frame, + ch_name=str(ch_name), + scanno=ci + 1, + logno=ci + 1, + ) + info["chs"].append(chan_info) info._update_redundant() info._check_consistency() @@ -2506,38 +2787,93 @@ def create_info(ch_names, sfreq, ch_types='misc', verbose=None): RAW_INFO_FIELDS = ( - 'acq_pars', 'acq_stim', 'bads', 'ch_names', 'chs', - 'comps', 'ctf_head_t', 'custom_ref_applied', 'description', 'dev_ctf_t', - 'dev_head_t', 'dig', 'experimenter', 'events', 'utc_offset', 'device_info', - 'file_id', 'highpass', 'hpi_meas', 'hpi_results', 'helium_info', - 'hpi_subsystem', 'kit_system_id', 'line_freq', 'lowpass', 'meas_date', - 'meas_id', 'nchan', 'proj_id', 'proj_name', 'projs', 'sfreq', - 'subject_info', 'xplotter_layout', 'proc_history', 'gantry_angle', + "acq_pars", + 
"acq_stim", + "bads", + "ch_names", + "chs", + "comps", + "ctf_head_t", + "custom_ref_applied", + "description", + "dev_ctf_t", + "dev_head_t", + "dig", + "experimenter", + "events", + "utc_offset", + "device_info", + "file_id", + "highpass", + "hpi_meas", + "hpi_results", + "helium_info", + "hpi_subsystem", + "kit_system_id", + "line_freq", + "lowpass", + "meas_date", + "meas_id", + "nchan", + "proj_id", + "proj_name", + "projs", + "sfreq", + "subject_info", + "xplotter_layout", + "proc_history", + "gantry_angle", ) def _empty_info(sfreq): """Create an empty info dictionary.""" _none_keys = ( - 'acq_pars', 'acq_stim', 'ctf_head_t', 'description', - 'dev_ctf_t', 'dig', 'experimenter', 'utc_offset', 'device_info', - 'file_id', 'highpass', 'hpi_subsystem', 'kit_system_id', 'helium_info', - 'line_freq', 'lowpass', 'meas_date', 'meas_id', 'proj_id', 'proj_name', - 'subject_info', 'xplotter_layout', 'gantry_angle', + "acq_pars", + "acq_stim", + "ctf_head_t", + "description", + "dev_ctf_t", + "dig", + "experimenter", + "utc_offset", + "device_info", + "file_id", + "highpass", + "hpi_subsystem", + "kit_system_id", + "helium_info", + "line_freq", + "lowpass", + "meas_date", + "meas_id", + "proj_id", + "proj_name", + "subject_info", + "xplotter_layout", + "gantry_angle", + ) + _list_keys = ( + "bads", + "chs", + "comps", + "events", + "hpi_meas", + "hpi_results", + "projs", + "proc_history", ) - _list_keys = ('bads', 'chs', 'comps', 'events', 'hpi_meas', 'hpi_results', - 'projs', 'proc_history') info = Info() info._unlocked = True for k in _none_keys: info[k] = None for k in _list_keys: info[k] = list() - info['custom_ref_applied'] = FIFF.FIFFV_MNE_CUSTOM_REF_OFF - info['highpass'] = 0. - info['sfreq'] = float(sfreq) - info['lowpass'] = info['sfreq'] / 2. - info['dev_head_t'] = Transform('meg', 'head') + info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_OFF + info["highpass"] = 0.0 + info["sfreq"] = float(sfreq) + info["lowpass"] = info["sfreq"] / 2.0 + info["dev_head_t"] = Transform("meg", "head") info._update_redundant() info._check_consistency() return info @@ -2558,13 +2894,12 @@ def _force_update_info(info_base, info_target): The Info object(s) you wish to overwrite using info_base. These objects will be modified in-place. """ - exclude_keys = ['chs', 'ch_names', 'nchan'] + exclude_keys = ["chs", "ch_names", "nchan"] info_target = np.atleast_1d(info_target).ravel() all_infos = np.hstack([info_base, info_target]) for ii in all_infos: if not isinstance(ii, Info): - raise ValueError('Inputs must be of type Info. ' - 'Found type %s' % type(ii)) + raise ValueError("Inputs must be of type Info. " "Found type %s" % type(ii)) for key, val in info_base.items(): if key in exclude_keys: continue @@ -2607,138 +2942,140 @@ def anonymize_info(info, daysback=None, keep_his=False, verbose=None): ----- %(anonymize_info_notes)s """ - _validate_type(info, 'info', "self") + _validate_type(info, "info", "self") - default_anon_dos = datetime.datetime(2000, 1, 1, 0, 0, 0, - tzinfo=datetime.timezone.utc) + default_anon_dos = datetime.datetime( + 2000, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc + ) default_str = "mne_anonymize" default_subject_id = 0 default_sex = 0 - default_desc = ("Anonymized using a time shift" - " to preserve age at acquisition") + default_desc = "Anonymized using a time shift" " to preserve age at acquisition" - none_meas_date = info['meas_date'] is None + none_meas_date = info["meas_date"] is None if none_meas_date: if daysback is not None: - warn('Input info has "meas_date" set to None. 
' - 'Removing all information from time/date structures, ' - '*NOT* performing any time shifts!') + warn( + 'Input info has "meas_date" set to None. ' + "Removing all information from time/date structures, " + "*NOT* performing any time shifts!" + ) else: # compute timeshift delta if daysback is None: - delta_t = info['meas_date'] - default_anon_dos + delta_t = info["meas_date"] - default_anon_dos else: delta_t = datetime.timedelta(days=daysback) with info._unlock(): - info['meas_date'] = info['meas_date'] - delta_t + info["meas_date"] = info["meas_date"] - delta_t # file_id and meas_id - for key in ('file_id', 'meas_id'): + for key in ("file_id", "meas_id"): value = info.get(key) if value is not None: - assert 'msecs' not in value - if (none_meas_date or - ((value['secs'], value['usecs']) == DATE_NONE)): + assert "msecs" not in value + if none_meas_date or ((value["secs"], value["usecs"]) == DATE_NONE): # Don't try to shift backwards in time when no measurement # date is available or when file_id is already a place holder tmp = DATE_NONE else: - tmp = _add_timedelta_to_stamp( - (value['secs'], value['usecs']), -delta_t) - value['secs'] = tmp[0] - value['usecs'] = tmp[1] + tmp = _add_timedelta_to_stamp((value["secs"], value["usecs"]), -delta_t) + value["secs"] = tmp[0] + value["usecs"] = tmp[1] # The following copy is needed for a test CTF dataset # otherwise value['machid'][:] = 0 would suffice - _tmp = value['machid'].copy() + _tmp = value["machid"].copy() _tmp[:] = 0 - value['machid'] = _tmp + value["machid"] = _tmp # subject info - subject_info = info.get('subject_info') + subject_info = info.get("subject_info") if subject_info is not None: - if subject_info.get('id') is not None: - subject_info['id'] = default_subject_id + if subject_info.get("id") is not None: + subject_info["id"] = default_subject_id if keep_his: - logger.info('Not fully anonymizing info - keeping ' - 'his_id, sex, and hand info') + logger.info( + "Not fully anonymizing info - keeping " "his_id, sex, and hand info" + ) else: - if subject_info.get('his_id') is not None: - subject_info['his_id'] = str(default_subject_id) - if subject_info.get('sex') is not None: - subject_info['sex'] = default_sex - if subject_info.get('hand') is not None: - del subject_info['hand'] # there's no "unknown" setting - - for key in ('last_name', 'first_name', 'middle_name'): + if subject_info.get("his_id") is not None: + subject_info["his_id"] = str(default_subject_id) + if subject_info.get("sex") is not None: + subject_info["sex"] = default_sex + if subject_info.get("hand") is not None: + del subject_info["hand"] # there's no "unknown" setting + + for key in ("last_name", "first_name", "middle_name"): if subject_info.get(key) is not None: subject_info[key] = default_str # anonymize the subject birthday if none_meas_date: - subject_info.pop('birthday', None) - elif subject_info.get('birthday') is not None: - dob = datetime.datetime(subject_info['birthday'][0], - subject_info['birthday'][1], - subject_info['birthday'][2]) + subject_info.pop("birthday", None) + elif subject_info.get("birthday") is not None: + dob = datetime.datetime( + subject_info["birthday"][0], + subject_info["birthday"][1], + subject_info["birthday"][2], + ) dob -= delta_t - subject_info['birthday'] = dob.year, dob.month, dob.day + subject_info["birthday"] = dob.year, dob.month, dob.day - for key in ('weight', 'height'): + for key in ("weight", "height"): if subject_info.get(key) is not None: subject_info[key] = 0 - info['experimenter'] = default_str - 
info['description'] = default_desc + info["experimenter"] = default_str + info["description"] = default_desc with info._unlock(): - if info['proj_id'] is not None: - info['proj_id'] = np.zeros_like(info['proj_id']) - if info['proj_name'] is not None: - info['proj_name'] = default_str - if info['utc_offset'] is not None: - info['utc_offset'] = None - - proc_hist = info.get('proc_history') + if info["proj_id"] is not None: + info["proj_id"] = np.zeros_like(info["proj_id"]) + if info["proj_name"] is not None: + info["proj_name"] = default_str + if info["utc_offset"] is not None: + info["utc_offset"] = None + + proc_hist = info.get("proc_history") if proc_hist is not None: for record in proc_hist: - record['block_id']['machid'][:] = 0 - record['experimenter'] = default_str + record["block_id"]["machid"][:] = 0 + record["experimenter"] = default_str if none_meas_date: - record['block_id']['secs'] = DATE_NONE[0] - record['block_id']['usecs'] = DATE_NONE[1] - record['date'] = DATE_NONE + record["block_id"]["secs"] = DATE_NONE[0] + record["block_id"]["usecs"] = DATE_NONE[1] + record["date"] = DATE_NONE else: - this_t0 = (record['block_id']['secs'], - record['block_id']['usecs']) - this_t1 = _add_timedelta_to_stamp( - this_t0, -delta_t) - record['block_id']['secs'] = this_t1[0] - record['block_id']['usecs'] = this_t1[1] - record['date'] = _add_timedelta_to_stamp( - record['date'], -delta_t) - - hi = info.get('helium_info') + this_t0 = (record["block_id"]["secs"], record["block_id"]["usecs"]) + this_t1 = _add_timedelta_to_stamp(this_t0, -delta_t) + record["block_id"]["secs"] = this_t1[0] + record["block_id"]["usecs"] = this_t1[1] + record["date"] = _add_timedelta_to_stamp(record["date"], -delta_t) + + hi = info.get("helium_info") if hi is not None: - if hi.get('orig_file_guid') is not None: - hi['orig_file_guid'] = default_str - if none_meas_date and hi.get('meas_date') is not None: - hi['meas_date'] = DATE_NONE - elif hi.get('meas_date') is not None: - hi['meas_date'] = _add_timedelta_to_stamp( - hi['meas_date'], -delta_t) - - di = info.get('device_info') + if hi.get("orig_file_guid") is not None: + hi["orig_file_guid"] = default_str + if none_meas_date and hi.get("meas_date") is not None: + hi["meas_date"] = DATE_NONE + elif hi.get("meas_date") is not None: + hi["meas_date"] = _add_timedelta_to_stamp(hi["meas_date"], -delta_t) + + di = info.get("device_info") if di is not None: - for k in ('serial', 'site'): + for k in ("serial", "site"): if di.get(k) is not None: di[k] = default_str - err_mesg = ('anonymize_info generated an inconsistent info object. ' - 'Underlying Error:\n') + err_mesg = ( + "anonymize_info generated an inconsistent info object. " "Underlying Error:\n" + ) info._check_consistency(prepend_error=err_mesg) - err_mesg = ('anonymize_info generated an inconsistent info object. ' - 'daysback parameter was too large. ' - 'Underlying Error:\n') + err_mesg = ( + "anonymize_info generated an inconsistent info object. " + "daysback parameter was too large. " + "Underlying Error:\n" + ) _check_dates(info, prepend_error=err_mesg) return info @@ -2770,16 +3107,16 @@ def _bad_chans_comp(info, ch_names): Returns [] if no channels are missing. """ - if 'comps' not in info: + if "comps" not in info: # should this be thought of as a bug? 
return False, [] # only include compensation channels that would affect selected channels ch_names_s = set(ch_names) comp_names = [] - for comp in info['comps']: - if len(ch_names_s.intersection(comp['data']['row_names'])) > 0: - comp_names.extend(comp['data']['col_names']) + for comp in info["comps"]: + if len(ch_names_s.intersection(comp["data"]["row_names"])) > 0: + comp_names.extend(comp["data"]["col_names"]) comp_names = sorted(set(comp_names)) missing_ch_names = sorted(set(comp_names).difference(ch_names)) @@ -2790,8 +3127,7 @@ def _bad_chans_comp(info, ch_names): return False, missing_ch_names -_DIG_CAST = dict( - kind=int, ident=int, r=lambda x: x, coord_frame=int) +_DIG_CAST = dict(kind=int, ident=int, r=lambda x: x, coord_frame=int) # key -> const, cast, write _CH_INFO_MAP = OrderedDict( scanno=(FIFF.FIFF_CH_SCAN_NO, _int_item, write_int), @@ -2809,27 +3145,27 @@ def _bad_chans_comp(info, ch_names): # key -> cast _CH_CAST = OrderedDict((key, val[1]) for key, val in _CH_INFO_MAP.items()) # const -> key, cast -_CH_READ_MAP = OrderedDict((val[0], (key, val[1])) - for key, val in _CH_INFO_MAP.items()) +_CH_READ_MAP = OrderedDict((val[0], (key, val[1])) for key, val in _CH_INFO_MAP.items()) @contextlib.contextmanager def _writing_info_hdf5(info): # Make info writing faster by packing chs and dig into numpy arrays - orig_dig = info.get('dig', None) - orig_chs = info['chs'] + orig_dig = info.get("dig", None) + orig_chs = info["chs"] with info._unlock(): try: if orig_dig is not None and len(orig_dig) > 0: - info['dig'] = _dict_pack(info['dig'], _DIG_CAST) - info['chs'] = _dict_pack(info['chs'], _CH_CAST) - info['chs']['ch_name'] = np.char.encode( - info['chs']['ch_name'], encoding='utf8') + info["dig"] = _dict_pack(info["dig"], _DIG_CAST) + info["chs"] = _dict_pack(info["chs"], _CH_CAST) + info["chs"]["ch_name"] = np.char.encode( + info["chs"]["ch_name"], encoding="utf8" + ) yield finally: if orig_dig is not None: - info['dig'] = orig_dig - info['chs'] = orig_chs + info["dig"] = orig_dig + info["chs"] = orig_chs def _dict_pack(obj, casts): @@ -2840,14 +3176,13 @@ def _dict_pack(obj, casts): def _dict_unpack(obj, casts): # unpack a dict of array into a list of dict n = len(obj[list(casts)[0]]) - return [{key: cast(obj[key][ii]) for key, cast in casts.items()} - for ii in range(n)] + return [{key: cast(obj[key][ii]) for key, cast in casts.items()} for ii in range(n)] def _make_ch_names_mapping(chs): - orig_ch_names = [c['ch_name'] for c in chs] + orig_ch_names = [c["ch_name"] for c in chs] ch_names = orig_ch_names.copy() - _unique_channel_names(ch_names, max_length=15, verbose='error') + _unique_channel_names(ch_names, max_length=15, verbose="error") ch_names_mapping = dict() if orig_ch_names != ch_names: ch_names_mapping.update(zip(orig_ch_names, ch_names)) @@ -2859,27 +3194,28 @@ def _write_ch_infos(fid, chs, reset_range, ch_names_mapping): for k, c in enumerate(chs): # Scan numbers may have been messed up c = c.copy() - c['ch_name'] = ch_names_mapping.get(c['ch_name'], c['ch_name']) - assert len(c['ch_name']) <= 15 - c['scanno'] = k + 1 + c["ch_name"] = ch_names_mapping.get(c["ch_name"], c["ch_name"]) + assert len(c["ch_name"]) <= 15 + c["scanno"] = k + 1 # for float/double, the "range" param is unnecessary if reset_range: - c['range'] = 1.0 + c["range"] = 1.0 write_ch_info(fid, c) # only write new-style channel information if necessary if len(ch_names_mapping): logger.info( - ' Writing channel names to FIF truncated to 15 characters ' - 'with remapping') + " Writing channel names to 
FIF truncated to 15 characters " + "with remapping" + ) for ch in chs: start_block(fid, FIFF.FIFFB_CH_INFO) assert set(ch) == set(_CH_INFO_MAP) - for (key, (const, _, write)) in _CH_INFO_MAP.items(): + for key, (const, _, write) in _CH_INFO_MAP.items(): write(fid, const, ch[key]) end_block(fid, FIFF.FIFFB_CH_INFO) -def _ensure_infos_match(info1, info2, name, *, on_mismatch='raise'): +def _ensure_infos_match(info1, info2, name, *, on_mismatch="raise"): """Check if infos match. Parameters @@ -2893,42 +3229,46 @@ def _ensure_infos_match(info1, info2, name, *, on_mismatch='raise'): What to do in case of a mismatch of ``dev_head_t`` between ``info1`` and ``info2``. """ - _check_on_missing(on_missing=on_mismatch, name='on_mismatch') + _check_on_missing(on_missing=on_mismatch, name="on_mismatch") info1._check_consistency() info2._check_consistency() - if info1['nchan'] != info2['nchan']: - raise ValueError(f'{name}.info[\'nchan\'] must match') - if set(info1['bads']) != set(info2['bads']): - raise ValueError(f'{name}.info[\'bads\'] must match') - if info1['sfreq'] != info2['sfreq']: - raise ValueError(f'{name}.info[\'sfreq\'] must match') - if set(info1['ch_names']) != set(info2['ch_names']): - raise ValueError(f'{name}.info[\'ch_names\'] must match') - if info1['ch_names'] != info2['ch_names']: - msg = (f'{name}.info[\'ch_names\']: Channel order must match. Use ' - '"mne.match_channel_orders()" to sort channels.') + if info1["nchan"] != info2["nchan"]: + raise ValueError(f"{name}.info['nchan'] must match") + if set(info1["bads"]) != set(info2["bads"]): + raise ValueError(f"{name}.info['bads'] must match") + if info1["sfreq"] != info2["sfreq"]: + raise ValueError(f"{name}.info['sfreq'] must match") + if set(info1["ch_names"]) != set(info2["ch_names"]): + raise ValueError(f"{name}.info['ch_names'] must match") + if info1["ch_names"] != info2["ch_names"]: + msg = ( + f"{name}.info['ch_names']: Channel order must match. Use " + '"mne.match_channel_orders()" to sort channels.' + ) raise ValueError(msg) - if len(info2['projs']) != len(info1['projs']): - raise ValueError(f'SSP projectors in {name} must be the same') - if any(not _proj_equal(p1, p2) for p1, p2 in - zip(info2['projs'], info1['projs'])): - raise ValueError(f'SSP projectors in {name} must be the same') - if (info1['dev_head_t'] is None) != (info2['dev_head_t'] is None) or \ - (info1['dev_head_t'] is not None and not - np.allclose(info1['dev_head_t']['trans'], - info2['dev_head_t']['trans'], rtol=1e-6)): - msg = (f"{name}.info['dev_head_t'] differs. The " - f"instances probably come from different runs, and " - f"are therefore associated with different head " - f"positions. Manually change info['dev_head_t'] to " - f"avoid this message but beware that this means the " - f"MEG sensors will not be properly spatially aligned. " - f"See mne.preprocessing.maxwell_filter to realign the " - f"runs to a common head position.") - _on_missing(on_missing=on_mismatch, msg=msg, - name='on_mismatch') + if len(info2["projs"]) != len(info1["projs"]): + raise ValueError(f"SSP projectors in {name} must be the same") + if any(not _proj_equal(p1, p2) for p1, p2 in zip(info2["projs"], info1["projs"])): + raise ValueError(f"SSP projectors in {name} must be the same") + if (info1["dev_head_t"] is None) != (info2["dev_head_t"] is None) or ( + info1["dev_head_t"] is not None + and not np.allclose( + info1["dev_head_t"]["trans"], info2["dev_head_t"]["trans"], rtol=1e-6 + ) + ): + msg = ( + f"{name}.info['dev_head_t'] differs. 
The " + f"instances probably come from different runs, and " + f"are therefore associated with different head " + f"positions. Manually change info['dev_head_t'] to " + f"avoid this message but beware that this means the " + f"MEG sensors will not be properly spatially aligned. " + f"See mne.preprocessing.maxwell_filter to realign the " + f"runs to a common head position." + ) + _on_missing(on_missing=on_mismatch, msg=msg, name="on_mismatch") def _get_fnirs_ch_pos(info): @@ -2940,6 +3280,7 @@ def _get_fnirs_ch_pos(info): returns the location of each source and detector. """ from ..preprocessing.nirs import _fnirs_optode_names, _optode_position + srcs, dets = _fnirs_optode_names(info) ch_pos = {} for optode in [*srcs, *dets]: diff --git a/mne/io/nedf/nedf.py b/mne/io/nedf/nedf.py index 78ae0106b4e..bd463dd87a0 100644 --- a/mne/io/nedf/nedf.py +++ b/mne/io/nedf/nedf.py @@ -29,7 +29,7 @@ def _getsubnodetext(node, name): """ subnode = node.findtext(name) if not subnode: - raise RuntimeError('NEDF header ' + name + ' not found') + raise RuntimeError("NEDF header " + name + " not found") return subnode @@ -64,55 +64,56 @@ def _parse_nedf_header(header): dt = [] # dtype for the binary data block datadt = [] # dtype for a single EEG sample - headerend = header.find(b'\0') + headerend = header.find(b"\0") if headerend == -1: - raise RuntimeError('End of header null not found') + raise RuntimeError("End of header null not found") headerxml = ElementTree.fromstring(header[:headerend]) - nedfversion = headerxml.findtext('NEDFversion', '') - if nedfversion not in ['1.3', '1.4']: - warn('NEDFversion unsupported, use with caution') + nedfversion = headerxml.findtext("NEDFversion", "") + if nedfversion not in ["1.3", "1.4"]: + warn("NEDFversion unsupported, use with caution") - if headerxml.findtext('stepDetails/DeviceClass', '') == 'STARSTIM': - warn('Found Starstim, this hasn\'t been tested extensively!') + if headerxml.findtext("stepDetails/DeviceClass", "") == "STARSTIM": + warn("Found Starstim, this hasn't been tested extensively!") - if headerxml.findtext('AdditionalChannelStatus', 'OFF') != 'OFF': - raise RuntimeError('Unknown additional channel, aborting.') + if headerxml.findtext("AdditionalChannelStatus", "OFF") != "OFF": + raise RuntimeError("Unknown additional channel, aborting.") - n_acc = int(headerxml.findtext('NumberOfChannelsOfAccelerometer', 0)) + n_acc = int(headerxml.findtext("NumberOfChannelsOfAccelerometer", 0)) if n_acc: # expect one sample of u16 accelerometer data per block - dt.append(('acc', '>u2', (n_acc,))) + dt.append(("acc", ">u2", (n_acc,))) - eegset = headerxml.find('EEGSettings') + eegset = headerxml.find("EEGSettings") if eegset is None: - raise RuntimeError('No EEG channels found') - nchantotal = int(_getsubnodetext(eegset, 'TotalNumberOfChannels')) - info['nchan'] = nchantotal + raise RuntimeError("No EEG channels found") + nchantotal = int(_getsubnodetext(eegset, "TotalNumberOfChannels")) + info["nchan"] = nchantotal - info['sfreq'] = int(_getsubnodetext(eegset, 'EEGSamplingRate')) - info['ch_names'] = [e.text for e in eegset.find('EEGMontage')] - if nchantotal != len(info['ch_names']): + info["sfreq"] = int(_getsubnodetext(eegset, "EEGSamplingRate")) + info["ch_names"] = [e.text for e in eegset.find("EEGMontage")] + if nchantotal != len(info["ch_names"]): raise RuntimeError( f"TotalNumberOfChannels ({nchantotal}) != " - f"channel count ({len(info['ch_names'])})") + f"channel count ({len(info['ch_names'])})" + ) # expect nchantotal uint24s - datadt.append(('eeg', 'B', 
(nchantotal, 3))) + datadt.append(("eeg", "B", (nchantotal, 3))) - if headerxml.find('STIMSettings') is not None: + if headerxml.find("STIMSettings") is not None: # 2* -> two stim samples per eeg sample - datadt.append(('stim', 'B', (2, nchantotal, 3))) - warn('stim channels are currently ignored') + datadt.append(("stim", "B", (2, nchantotal, 3))) + warn("stim channels are currently ignored") # Trigger data: 4 bytes in newer versions, 1 byte in older versions - trigger_type = '>i4' if headerxml.findtext('NEDFversion') else 'B' - datadt.append(('trig', trigger_type)) + trigger_type = ">i4" if headerxml.findtext("NEDFversion") else "B" + datadt.append(("trig", trigger_type)) # 5 data samples per block - dt.append(('data', np.dtype(datadt), (5,))) + dt.append(("data", np.dtype(datadt), (5,))) - date = headerxml.findtext('StepDetails/StartDate_firstEEGTimestamp', 0) - info['meas_date'] = datetime.fromtimestamp(int(date) / 1000, timezone.utc) + date = headerxml.findtext("StepDetails/StartDate_firstEEGTimestamp", 0) + info["meas_date"] = datetime.fromtimestamp(int(date) / 1000, timezone.utc) - n_samples = int(_getsubnodetext(eegset, 'NumberOfRecordsOfEEG')) + n_samples = int(_getsubnodetext(eegset, "NumberOfRecordsOfEEG")) n_full, n_last = divmod(n_samples, 5) dt_last = deepcopy(dt) assert dt_last[-1][-1] == (5,) @@ -131,29 +132,34 @@ class RawNedf(BaseRaw): def __init__(self, filename, preload=False, verbose=None): filename = str(_check_fname(filename, "read", True, "filename")) - with open(filename, mode='rb') as fid: + with open(filename, mode="rb") as fid: header = fid.read(_HDRLEN) header, dt, dt_last, n_samp, n_full = _parse_nedf_header(header) - ch_names = header['ch_names'] + ['STI 014'] - ch_types = ['eeg'] * len(ch_names) - ch_types[-1] = 'stim' - info = create_info(ch_names, header['sfreq'], ch_types) + ch_names = header["ch_names"] + ["STI 014"] + ch_types = ["eeg"] * len(ch_names) + ch_types[-1] = "stim" + info = create_info(ch_names, header["sfreq"], ch_types) # scaling factor ADC-values -> volts # taken from the NEDF EEGLAB plugin # (https://www.neuroelectrics.com/resources/software/): - for ch in info['chs'][:-1]: - ch['cal'] = 2.4 / (6.0 * 8388607) + for ch in info["chs"][:-1]: + ch["cal"] = 2.4 / (6.0 * 8388607) with info._unlock(): - info['meas_date'] = header['meas_date'] + info["meas_date"] = header["meas_date"] raw_extra = dict(dt=dt, dt_last=dt_last, n_full=n_full) super().__init__( - info, preload=preload, filenames=[filename], verbose=verbose, - raw_extras=[raw_extra], last_samps=[n_samp - 1]) + info, + preload=preload, + filenames=[filename], + verbose=verbose, + raw_extras=[raw_extra], + last_samps=[n_samp - 1], + ) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): - dt = self._raw_extras[fi]['dt'] - dt_last = self._raw_extras[fi]['dt_last'] - n_full = self._raw_extras[fi]['n_full'] + dt = self._raw_extras[fi]["dt"] + dt_last = self._raw_extras[fi]["dt_last"] + n_full = self._raw_extras[fi]["n_full"] n_eeg = dt[1].subdtype[0][0].shape[0] # data is stored in 5-sample chunks (except maybe the last one!) 
# so we have to do some gymnastics to pick the correct parts to @@ -165,28 +171,28 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): last = None n_chunks = (n_samples_full - 1) // 5 + 1 n_tot = n_chunks * 5 - with open(self._filenames[fi], 'rb') as fid: + with open(self._filenames[fi], "rb") as fid: fid.seek(offset, 0) chunks = np.fromfile(fid, dtype=dt, count=n_chunks) assert len(chunks) == n_chunks if n_samples != n_samples_full: last = np.fromfile(fid, dtype=dt_last, count=1) eeg = _convert_eeg(chunks, n_eeg, n_tot) - trig = chunks['data']['trig'].reshape(1, n_tot) + trig = chunks["data"]["trig"].reshape(1, n_tot) if last is not None: - n_last = dt_last['data'].shape[0] - eeg = np.concatenate( - (eeg, _convert_eeg(last, n_eeg, n_last)), axis=-1) + n_last = dt_last["data"].shape[0] + eeg = np.concatenate((eeg, _convert_eeg(last, n_eeg, n_last)), axis=-1) trig = np.concatenate( - (trig, last['data']['trig'].reshape(1, n_last)), axis=-1) + (trig, last["data"]["trig"].reshape(1, n_last)), axis=-1 + ) one_ = np.concatenate((eeg, trig)) - one = one_[:, start_sl:n_samples + start_sl] + one = one_[:, start_sl : n_samples + start_sl] _mult_cal_one(data, one, idx, cals, mult) def _convert_eeg(chunks, n_eeg, n_tot): # convert uint8-triplet -> int32 - eeg = chunks['data']['eeg'] @ np.array([1 << 16, 1 << 8, 1]) + eeg = chunks["data"]["eeg"] @ np.array([1 << 16, 1 << 8, 1]) # convert sign if necessary eeg[eeg > (1 << 23)] -= 1 << 24 eeg = eeg.reshape((n_tot, n_eeg)).T diff --git a/mne/io/nedf/tests/test_nedf.py b/mne/io/nedf/tests/test_nedf.py index 404dd7af342..36794417e0d 100644 --- a/mne/io/nedf/tests/test_nedf.py +++ b/mne/io/nedf/tests/test_nedf.py @@ -29,28 +29,27 @@ \x00""" -@pytest.mark.parametrize('nacc', (0, 3)) +@pytest.mark.parametrize("nacc", (0, 3)) def test_nedf_header_parser(nacc): """Test NEDF header parsing and dtype extraction.""" - with pytest.warns(RuntimeWarning, match='stim channels.*ignored'): - info, dt, dt_last, n_samples, n_full = _parse_nedf_header( - stimhdr % nacc) + with pytest.warns(RuntimeWarning, match="stim channels.*ignored"): + info, dt, dt_last, n_samples, n_full = _parse_nedf_header(stimhdr % nacc) assert n_samples == 11 assert n_full == 2 nchan = 4 - assert info['nchan'] == nchan + assert info["nchan"] == nchan assert dt.itemsize == 200 + nacc * 2 if nacc: - assert dt.names[0] == 'acc' - assert dt['acc'].shape == (nacc,) + assert dt.names[0] == "acc" + assert dt["acc"].shape == (nacc,) - assert dt['data'].shape == (5,) # blocks of 5 EEG samples each - assert dt_last['data'].shape == (1,) # plus one last extra one + assert dt["data"].shape == (5,) # blocks of 5 EEG samples each + assert dt_last["data"].shape == (1,) # plus one last extra one - eegsampledt = dt['data'].subdtype[0] - assert eegsampledt.names == ('eeg', 'stim', 'trig') - assert eegsampledt['eeg'].shape == (nchan, 3) - assert eegsampledt['stim'].shape == (2, nchan, 3) + eegsampledt = dt["data"].subdtype[0] + assert eegsampledt.names == ("eeg", "stim", "trig") + assert eegsampledt["eeg"].shape == (nchan, 3) + assert eegsampledt["stim"].shape == (2, nchan, 3) def test_invalid_headers(): @@ -62,38 +61,35 @@ def test_invalid_headers(): ABCD \x00""" - nchan = b'4' - sr = b'500' + nchan = b"4" + sr = b"500" hdr = { - 'null': - b'No null terminator', - 'Unknown additional': - (b'1.3' + - b'???\x00'), # noqa: E501 - 'No EEG channels found': - b'1.3\x00', - 'TotalNumberOfChannels not found': - tpl % b'No nchan.', - '!= channel count': - tpl % (sr + b'52'), - 'EEGSamplingRate not 
found': - tpl % nchan, - 'NumberOfRecordsOfEEG not found': - tpl % (sr + nchan), + "null": b"No null terminator", + "Unknown additional": ( + b"1.3" + + b"???\x00" + ), # noqa: E501 + "No EEG channels found": b"1.3\x00", + "TotalNumberOfChannels not found": tpl % b"No nchan.", + "!= channel count": tpl + % (sr + b"52"), + "EEGSamplingRate not found": tpl % nchan, + "NumberOfRecordsOfEEG not found": tpl % (sr + nchan), } for match, invalid_hdr in hdr.items(): with pytest.raises(RuntimeError, match=match): _parse_nedf_header(invalid_hdr) sus_hdrs = { - 'unsupported': b'25\x00', - 'tested': ( - b'1.3' + - b'STARSTIM\x00'), + "unsupported": b"25\x00", + "tested": ( + b"1.3" + + b"STARSTIM\x00" + ), } for match, sus_hdr in sus_hdrs.items(): with pytest.warns(RuntimeWarning, match=match): - with pytest.raises(RuntimeError, match='No EEG channels found'): + with pytest.raises(RuntimeError, match="No EEG channels found"): _parse_nedf_header(sus_hdr) @@ -107,22 +103,22 @@ def test_nedf_data(): events = find_events(raw, shortest_event=1) assert len(events) == 4 assert_array_equal(events[:, 2], [1, 1, 1, 1]) - onsets = events[:, 0] / raw.info['sfreq'] - assert raw.info['sfreq'] == 500 + onsets = events[:, 0] / raw.info["sfreq"] + assert raw.info["sfreq"] == 500 - data_end = raw.get_data('Fp1', nsamples - 100, nsamples).mean() - assert_allclose(data_end, .0176, atol=.01) - assert_allclose(raw.get_data('Fpz', 0, 100).mean(), .0185, atol=.01) + data_end = raw.get_data("Fp1", nsamples - 100, nsamples).mean() + assert_allclose(data_end, 0.0176, atol=0.01) + assert_allclose(raw.get_data("Fpz", 0, 100).mean(), 0.0185, atol=0.01) assert_allclose(onsets, [22.384, 38.238, 49.496, 63.15]) - assert raw.info['meas_date'].year == 2019 - assert raw.ch_names[2] == 'AF7' - - for ch in raw.info['chs'][:-1]: - assert ch['kind'] == FIFF.FIFFV_EEG_CH - assert ch['unit'] == FIFF.FIFF_UNIT_V - assert raw.info['chs'][-1]['kind'] == FIFF.FIFFV_STIM_CH - assert raw.info['chs'][-1]['unit'] == FIFF.FIFF_UNIT_V + assert raw.info["meas_date"].year == 2019 + assert raw.ch_names[2] == "AF7" + + for ch in raw.info["chs"][:-1]: + assert ch["kind"] == FIFF.FIFFV_EEG_CH + assert ch["unit"] == FIFF.FIFF_UNIT_V + assert raw.info["chs"][-1]["kind"] == FIFF.FIFFV_STIM_CH + assert raw.info["chs"][-1]["unit"] == FIFF.FIFF_UNIT_V # full tests _test_raw_reader(read_raw_nedf, filename=eegfile) diff --git a/mne/io/nicolet/nicolet.py b/mne/io/nicolet/nicolet.py index 6681dcff523..ced03a89279 100644 --- a/mne/io/nicolet/nicolet.py +++ b/mne/io/nicolet/nicolet.py @@ -15,8 +15,9 @@ @fill_doc -def read_raw_nicolet(input_fname, ch_type, eog=(), - ecg=(), emg=(), misc=(), preload=False, verbose=None): +def read_raw_nicolet( + input_fname, ch_type, eog=(), ecg=(), emg=(), misc=(), preload=False, verbose=None +): """Read Nicolet data as raw object. ..note:: This reader takes data files with the extension ``.data`` as an @@ -58,68 +59,80 @@ def read_raw_nicolet(input_fname, ch_type, eog=(), -------- mne.io.Raw : Documentation of attributes and methods. """ - return RawNicolet(input_fname, ch_type, eog=eog, ecg=ecg, - emg=emg, misc=misc, preload=preload, verbose=verbose) + return RawNicolet( + input_fname, + ch_type, + eog=eog, + ecg=ecg, + emg=emg, + misc=misc, + preload=preload, + verbose=verbose, + ) def _get_nicolet_info(fname, ch_type, eog, ecg, emg, misc): """Extract info from Nicolet header files.""" fname, extension = path.splitext(fname) - if extension != '.data': - raise ValueError( - f'File name should end with .data not "{extension}".' 
- ) + if extension != ".data": + raise ValueError(f'File name should end with .data not "{extension}".') - header = fname + '.head' + header = fname + ".head" - logger.info('Reading header...') + logger.info("Reading header...") header_info = dict() - with open(header, 'r') as fid: + with open(header, "r") as fid: for line in fid: - var, value = line.split('=') - if var == 'elec_names': - value = value[1:-2].split(',') # strip brackets - elif var == 'conversion_factor': + var, value = line.split("=") + if var == "elec_names": + value = value[1:-2].split(",") # strip brackets + elif var == "conversion_factor": value = float(value) - elif var in ['num_channels', 'rec_id', 'adm_id', 'pat_id', - 'num_samples']: + elif var in ["num_channels", "rec_id", "adm_id", "pat_id", "num_samples"]: value = int(value) - elif var != 'start_ts': + elif var != "start_ts": value = float(value) header_info[var] = value - ch_names = header_info['elec_names'] - if eog == 'auto': - eog = _find_channels(ch_names, 'EOG') - if ecg == 'auto': - ecg = _find_channels(ch_names, 'ECG') - if emg == 'auto': - emg = _find_channels(ch_names, 'EMG') - - date, time = header_info['start_ts'].split() - date = date.split('-') - time = time.split(':') - sec, msec = time[2].split('.') - date = datetime.datetime(int(date[0]), int(date[1]), int(date[2]), - int(time[0]), int(time[1]), int(sec), int(msec)) - info = _empty_info(header_info['sample_freq']) - info['meas_date'] = (calendar.timegm(date.utctimetuple()), 0) - - if ch_type == 'eeg': + ch_names = header_info["elec_names"] + if eog == "auto": + eog = _find_channels(ch_names, "EOG") + if ecg == "auto": + ecg = _find_channels(ch_names, "ECG") + if emg == "auto": + emg = _find_channels(ch_names, "EMG") + + date, time = header_info["start_ts"].split() + date = date.split("-") + time = time.split(":") + sec, msec = time[2].split(".") + date = datetime.datetime( + int(date[0]), + int(date[1]), + int(date[2]), + int(time[0]), + int(time[1]), + int(sec), + int(msec), + ) + info = _empty_info(header_info["sample_freq"]) + info["meas_date"] = (calendar.timegm(date.utctimetuple()), 0) + + if ch_type == "eeg": ch_coil = FIFF.FIFFV_COIL_EEG ch_kind = FIFF.FIFFV_EEG_CH - elif ch_type == 'seeg': + elif ch_type == "seeg": ch_coil = FIFF.FIFFV_COIL_EEG ch_kind = FIFF.FIFFV_SEEG_CH else: - raise TypeError("Channel type not recognized. Available types are " - "'eeg' and 'seeg'.") - cals = np.repeat(header_info['conversion_factor'] * 1e-6, len(ch_names)) - info['chs'] = _create_chs(ch_names, cals, ch_coil, ch_kind, eog, ecg, emg, - misc) - info['highpass'] = 0. - info['lowpass'] = info['sfreq'] / 2.0 + raise TypeError( + "Channel type not recognized. Available types are " "'eeg' and 'seeg'." + ) + cals = np.repeat(header_info["conversion_factor"] * 1e-6, len(ch_names)) + info["chs"] = _create_chs(ch_names, cals, ch_coil, ch_kind, eog, ecg, emg, misc) + info["highpass"] = 0.0 + info["lowpass"] = info["sfreq"] / 2.0 info._unlocked = False info._update_redundant() return info, header_info @@ -158,19 +171,30 @@ class RawNicolet(BaseRaw): mne.io.Raw : Documentation of attributes and methods. 
""" - def __init__(self, input_fname, ch_type, eog=(), - ecg=(), emg=(), misc=(), preload=False, - verbose=None): # noqa: D102 + def __init__( + self, + input_fname, + ch_type, + eog=(), + ecg=(), + emg=(), + misc=(), + preload=False, + verbose=None, + ): # noqa: D102 input_fname = path.abspath(input_fname) - info, header_info = _get_nicolet_info(input_fname, ch_type, eog, ecg, - emg, misc) - last_samps = [header_info['num_samples'] - 1] + info, header_info = _get_nicolet_info(input_fname, ch_type, eog, ecg, emg, misc) + last_samps = [header_info["num_samples"] - 1] super(RawNicolet, self).__init__( - info, preload, filenames=[input_fname], raw_extras=[header_info], - last_samps=last_samps, orig_format='int', - verbose=verbose) + info, + preload, + filenames=[input_fname], + raw_extras=[header_info], + last_samps=last_samps, + orig_format="int", + verbose=verbose, + ) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of raw data.""" - _read_segments_file( - self, data, idx, fi, start, stop, cals, mult, dtype='= len(_chan_labels): n = idx - len(_chan_labels) + 1 - _chan_labels.extend(['UNK'] * n) + _chan_labels.extend(["UNK"] * n) _chan_labels[idx] = name.strip() except UnicodeDecodeError: pass else: break else: - warn(f'Could not decode 21E file as one of {_encodings}; ' - f'Default channel names are chosen.') + warn( + f"Could not decode 21E file as one of {_encodings}; " + f"Default channel names are chosen." + ) return _chan_labels @@ -142,48 +167,48 @@ def _read_nihon_header(fname): fname = _ensure_path(fname) _chan_labels = _read_21e_file(fname) header = {} - logger.info(f'Reading header from {fname}') - with open(fname, 'r') as fid: - version = np.fromfile(fid, '|S16', 1).astype('U16')[0] + logger.info(f"Reading header from {fname}") + with open(fname, "r") as fid: + version = np.fromfile(fid, "|S16", 1).astype("U16")[0] if version not in _valid_headers: - raise ValueError( - 'Not a valid Nihon Kohden EEG file ({})'.format(version)) + raise ValueError("Not a valid Nihon Kohden EEG file ({})".format(version)) fid.seek(0x0081) - control_block = np.fromfile(fid, '|S16', 1).astype('U16')[0] + control_block = np.fromfile(fid, "|S16", 1).astype("U16")[0] if control_block not in _valid_headers: - raise ValueError('Not a valid Nihon Kohden EEG file ' - '(control block {})'.format(version)) + raise ValueError( + "Not a valid Nihon Kohden EEG file " + "(control block {})".format(version) + ) - fid.seek(0x17fe) + fid.seek(0x17FE) waveform_sign = np.fromfile(fid, np.uint8, 1)[0] if waveform_sign != 1: - raise ValueError('Not a valid Nihon Kohden EEG file ' - '(waveform block)') - header['version'] = version + raise ValueError("Not a valid Nihon Kohden EEG file " "(waveform block)") + header["version"] = version fid.seek(0x0091) n_ctlblocks = np.fromfile(fid, np.uint8, 1)[0] - header['n_ctlblocks'] = n_ctlblocks + header["n_ctlblocks"] = n_ctlblocks controlblocks = [] for i_ctl_block in range(n_ctlblocks): t_controlblock = {} fid.seek(0x0092 + i_ctl_block * 20) t_ctl_address = np.fromfile(fid, np.uint32, 1)[0] - t_controlblock['address'] = t_ctl_address + t_controlblock["address"] = t_ctl_address fid.seek(t_ctl_address + 17) n_datablocks = np.fromfile(fid, np.uint8, 1)[0] - t_controlblock['n_datablocks'] = n_datablocks - t_controlblock['datablocks'] = [] + t_controlblock["n_datablocks"] = n_datablocks + t_controlblock["datablocks"] = [] for i_data_block in range(n_datablocks): t_datablock = {} fid.seek(t_ctl_address + i_data_block * 20 + 18) t_data_address = 
np.fromfile(fid, np.uint32, 1)[0] - t_datablock['address'] = t_data_address + t_datablock["address"] = t_data_address fid.seek(t_data_address + 0x26) t_n_channels = np.fromfile(fid, np.uint8, 1)[0] - t_datablock['n_channels'] = t_n_channels + t_datablock["n_channels"] = t_n_channels t_channels = [] for i_ch in range(t_n_channels): @@ -191,76 +216,78 @@ def _read_nihon_header(fname): t_idx = np.fromfile(fid, np.uint8, 1)[0] t_channels.append(_chan_labels[t_idx]) - t_datablock['channels'] = t_channels + t_datablock["channels"] = t_channels fid.seek(t_data_address + 0x1C) t_record_duration = np.fromfile(fid, np.uint32, 1)[0] - t_datablock['duration'] = t_record_duration + t_datablock["duration"] = t_record_duration - fid.seek(t_data_address + 0x1a) + fid.seek(t_data_address + 0x1A) sfreq = np.fromfile(fid, np.uint16, 1)[0] & 0x3FFF - t_datablock['sfreq'] = sfreq + t_datablock["sfreq"] = sfreq - t_datablock['n_samples'] = int(t_record_duration * sfreq / 10) - t_controlblock['datablocks'].append(t_datablock) + t_datablock["n_samples"] = int(t_record_duration * sfreq / 10) + t_controlblock["datablocks"].append(t_datablock) controlblocks.append(t_controlblock) - header['controlblocks'] = controlblocks + header["controlblocks"] = controlblocks # Now check that every data block has the same channels and sfreq chans = [] sfreqs = [] nsamples = [] - for t_ctl in header['controlblocks']: - for t_dtb in t_ctl['datablocks']: - chans.append(t_dtb['channels']) - sfreqs.append(t_dtb['sfreq']) - nsamples.append(t_dtb['n_samples']) + for t_ctl in header["controlblocks"]: + for t_dtb in t_ctl["datablocks"]: + chans.append(t_dtb["channels"]) + sfreqs.append(t_dtb["sfreq"]) + nsamples.append(t_dtb["n_samples"]) for i_elem in range(1, len(chans)): if chans[0] != chans[i_elem]: - raise ValueError('Channel names in datablocks do not match') + raise ValueError("Channel names in datablocks do not match") if sfreqs[0] != sfreqs[i_elem]: - raise ValueError('Sample frequency in datablocks do not match') - header['ch_names'] = chans[0] - header['sfreq'] = sfreqs[0] - header['n_samples'] = np.sum(nsamples) + raise ValueError("Sample frequency in datablocks do not match") + header["ch_names"] = chans[0] + header["sfreq"] = sfreqs[0] + header["n_samples"] = np.sum(nsamples) # TODO: Support more than one controlblock and more than one datablock - if header['n_ctlblocks'] != 1: - raise NotImplementedError('I dont know how to read more than one ' - 'control block for this type of file :(') - if header['controlblocks'][0]['n_datablocks'] > 1: + if header["n_ctlblocks"] != 1: + raise NotImplementedError( + "I dont know how to read more than one " + "control block for this type of file :(" + ) + if header["controlblocks"][0]["n_datablocks"] > 1: # Multiple blocks, check that they all have the same kind of data - datablocks = header['controlblocks'][0]['datablocks'] + datablocks = header["controlblocks"][0]["datablocks"] block_0 = datablocks[0] for t_block in datablocks[1:]: - if block_0['n_channels'] != t_block['n_channels']: + if block_0["n_channels"] != t_block["n_channels"]: raise ValueError( - 'Cannot read NK file with different number of channels ' - 'in each datablock') - if block_0['channels'] != t_block['channels']: + "Cannot read NK file with different number of channels " + "in each datablock" + ) + if block_0["channels"] != t_block["channels"]: raise ValueError( - 'Cannot read NK file with different channels in each ' - 'datablock') - if block_0['sfreq'] != t_block['sfreq']: + "Cannot read NK file with different 
channels in each " "datablock" + ) + if block_0["sfreq"] != t_block["sfreq"]: raise ValueError( - 'Cannot read NK file with different sfreq in each ' - 'datablock') + "Cannot read NK file with different sfreq in each " "datablock" + ) return header def _read_nihon_annotations(fname): fname = _ensure_path(fname) - log_fname = fname.with_suffix('.LOG') + log_fname = fname.with_suffix(".LOG") if not log_fname.exists(): - warn('No LOG file exists. Annotations will not be read') + warn("No LOG file exists. Annotations will not be read") return dict(onset=[], duration=[], description=[]) - logger.info('Found LOG file, reading events.') - with open(log_fname, 'r') as fid: - version = np.fromfile(fid, '|S16', 1).astype('U16')[0] + logger.info("Found LOG file, reading events.") + with open(log_fname, "r") as fid: + version = np.fromfile(fid, "|S16", 1).astype("U16")[0] if version not in _valid_headers: - raise ValueError( - 'Not a valid Nihon Kohden LOG file ({})'.format(version)) + raise ValueError("Not a valid Nihon Kohden LOG file ({})".format(version)) fid.seek(0x91) n_logblocks = np.fromfile(fid, np.uint8, 1)[0] @@ -272,7 +299,7 @@ def _read_nihon_annotations(fname): fid.seek(t_blk_address + 0x12) n_logs = np.fromfile(fid, np.uint8, 1)[0] fid.seek(t_blk_address + 0x14) - t_logs = np.fromfile(fid, '|S45', n_logs) + t_logs = np.fromfile(fid, "|S45", n_logs) for t_log in t_logs: for enc in _encodings: try: @@ -282,30 +309,30 @@ def _read_nihon_annotations(fname): else: break else: - warn(f'Could not decode log as one of {_encodings}') + warn(f"Could not decode log as one of {_encodings}") continue - t_desc = t_log[:20].strip('\x00') - t_onset = datetime.strptime(t_log[20:26], '%H%M%S') - t_onset = (t_onset.hour * 3600 + t_onset.minute * 60 + - t_onset.second) + t_desc = t_log[:20].strip("\x00") + t_onset = datetime.strptime(t_log[20:26], "%H%M%S") + t_onset = t_onset.hour * 3600 + t_onset.minute * 60 + t_onset.second all_onsets.append(t_onset) all_descriptions.append(t_desc) annots = dict( onset=all_onsets, duration=[0] * len(all_onsets), - description=all_descriptions) + description=all_descriptions, + ) return annots def _map_ch_to_type(ch_name): - ch_type_pattern = OrderedDict([ - ('stim', ('Mark',)), ('misc', ('DC', 'NA', 'Z', '$')), - ('bio', ('X',))]) + ch_type_pattern = OrderedDict( + [("stim", ("Mark",)), ("misc", ("DC", "NA", "Z", "$")), ("bio", ("X",))] + ) for key, kinds in ch_type_pattern.items(): if any(kind in ch_name for kind in kinds): return key - return 'eeg' + return "eeg" def _map_ch_to_specs(ch_name): @@ -323,8 +350,14 @@ def _map_ch_to_specs(ch_name): cal = t_range / 65535 offset = phys_min - (dig_min * cal) - out = dict(unit=unit_mult, phys_min=phys_min, phys_max=phys_max, - dig_min=dig_min, cal=cal, offset=offset) + out = dict( + unit=unit_mult, + phys_min=phys_min, + phys_max=phys_max, + dig_min=dig_min, + cal=cal, + offset=offset, + ) return out @@ -347,83 +380,87 @@ class RawNihon(BaseRaw): @verbose def __init__(self, fname, preload=False, verbose=None): - fname = _check_fname(fname, 'read', True, 'fname') + fname = _check_fname(fname, "read", True, "fname") data_name = fname.name - logger.info('Loading %s' % data_name) + logger.info("Loading %s" % data_name) header = _read_nihon_header(fname) metadata = _read_nihon_metadata(fname) # n_chan = len(header['ch_names']) + 1 - sfreq = header['sfreq'] + sfreq = header["sfreq"] # data are multiplexed int16 - ch_names = header['ch_names'] + ch_names = header["ch_names"] ch_types = [_map_ch_to_type(x) for x in ch_names] info = 
create_info(ch_names, sfreq, ch_types) - n_samples = header['n_samples'] + n_samples = header["n_samples"] - if 'meas_date' in metadata: + if "meas_date" in metadata: with info._unlock(): - info['meas_date'] = metadata['meas_date'] - chs = {x: _map_ch_to_specs(x) for x in info['ch_names']} - - cal = np.array( - [chs[x]['cal'] for x in info['ch_names']], float)[:, np.newaxis] - offsets = np.array( - [chs[x]['offset'] for x in info['ch_names']], float)[:, np.newaxis] - gains = np.array( - [chs[x]['unit'] for x in info['ch_names']], float)[:, np.newaxis] - - raw_extras = dict( - cal=cal, offsets=offsets, gains=gains, header=header) + info["meas_date"] = metadata["meas_date"] + chs = {x: _map_ch_to_specs(x) for x in info["ch_names"]} + + cal = np.array([chs[x]["cal"] for x in info["ch_names"]], float)[:, np.newaxis] + offsets = np.array([chs[x]["offset"] for x in info["ch_names"]], float)[ + :, np.newaxis + ] + gains = np.array([chs[x]["unit"] for x in info["ch_names"]], float)[ + :, np.newaxis + ] + + raw_extras = dict(cal=cal, offsets=offsets, gains=gains, header=header) self._header = header - for i_ch, ch_name in enumerate(info['ch_names']): - t_range = (chs[ch_name]['phys_max'] - chs[ch_name]['phys_min']) - info['chs'][i_ch]['range'] = t_range - info['chs'][i_ch]['cal'] = 1 / t_range + for i_ch, ch_name in enumerate(info["ch_names"]): + t_range = chs[ch_name]["phys_max"] - chs[ch_name]["phys_min"] + info["chs"][i_ch]["range"] = t_range + info["chs"][i_ch]["cal"] = 1 / t_range super(RawNihon, self).__init__( - info, preload=preload, last_samps=(n_samples - 1,), - filenames=[fname.as_posix()], orig_format='short', - raw_extras=[raw_extras]) + info, + preload=preload, + last_samps=(n_samples - 1,), + filenames=[fname.as_posix()], + orig_format="short", + raw_extras=[raw_extras], + ) # Get annotations from LOG file annots = _read_nihon_annotations(fname) # Annotate acquisition skips - controlblock = self._header['controlblocks'][0] + controlblock = self._header["controlblocks"][0] cur_sample = 0 - if controlblock['n_datablocks'] > 1: - for i_block in range(controlblock['n_datablocks'] - 1): - t_block = controlblock['datablocks'][i_block] - cur_sample = cur_sample + t_block['n_samples'] - cur_tpoint = (cur_sample - 0.5) / t_block['sfreq'] + if controlblock["n_datablocks"] > 1: + for i_block in range(controlblock["n_datablocks"] - 1): + t_block = controlblock["datablocks"][i_block] + cur_sample = cur_sample + t_block["n_samples"] + cur_tpoint = (cur_sample - 0.5) / t_block["sfreq"] # Add annotations as in append raw - annots['onset'].append(cur_tpoint) - annots['duration'].append(0.0) - annots['description'].append('BAD boundary') - annots['onset'].append(cur_tpoint) - annots['duration'].append(0.0) - annots['description'].append('EDGE boundary') - - annotations = Annotations(**annots, orig_time=info['meas_date']) + annots["onset"].append(cur_tpoint) + annots["duration"].append(0.0) + annots["description"].append("BAD boundary") + annots["onset"].append(cur_tpoint) + annots["duration"].append(0.0) + annots["description"].append("EDGE boundary") + + annotations = Annotations(**annots, orig_time=info["meas_date"]) self.set_annotations(annotations) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of raw data.""" # For now we assume one control block - header = self._raw_extras[fi]['header'] + header = self._raw_extras[fi]["header"] # Get the original cal, offsets and gains - cal = self._raw_extras[fi]['cal'] - offsets = self._raw_extras[fi]['offsets'] - gains = 
self._raw_extras[fi]['gains'] + cal = self._raw_extras[fi]["cal"] + offsets = self._raw_extras[fi]["offsets"] + gains = self._raw_extras[fi]["gains"] # get the right datablock - datablocks = header['controlblocks'][0]['datablocks'] - ends = np.cumsum([t['n_samples'] for t in datablocks]) + datablocks = header["controlblocks"][0]["datablocks"] + ends = np.cumsum([t["n_samples"] for t in datablocks]) start_block = np.where(start < ends)[0][0] stop_block = np.where(stop <= ends)[0][0] @@ -439,13 +476,18 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): new_stop = stop else: # Otherwise, stop on the last sample of the block - new_stop = t_block['n_samples'] + new_start + new_stop = t_block["n_samples"] + new_start samples_to_read = new_stop - new_start sample_stop = sample_start + samples_to_read self._read_segment_file( - data[:, sample_start:sample_stop], idx, fi, - new_start, new_stop, cals, mult + data[:, sample_start:sample_stop], + idx, + fi, + new_start, + new_stop, + cals, + mult, ) # Update variables for next loop @@ -454,9 +496,8 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): else: datablock = datablocks[start_block] - n_channels = datablock['n_channels'] + 1 - datastart = (datablock['address'] + 0x27 + - (datablock['n_channels'] * 10)) + n_channels = datablock["n_channels"] + 1 + datastart = datablock["address"] + 0x27 + (datablock["n_channels"] * 10) # Compute start offset based on the beginning of the block rel_start = start @@ -464,12 +505,12 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): rel_start = start - ends[start_block - 1] start_offset = datastart + rel_start * n_channels * 2 - with open(self._filenames[fi], 'rb') as fid: + with open(self._filenames[fi], "rb") as fid: to_read = (stop - start) * n_channels fid.seek(start_offset) - block_data = np.fromfile(fid, ' \ - len(set(nihon._read_21e_file(fname))) - msg = 'Channel names are not unique, found duplicates for: ' \ - '{\'FP1\'}. Applying running numbers for duplicates.' + assert len(nihon._read_21e_file(fname)) > len(set(nihon._read_21e_file(fname))) + msg = ( + "Channel names are not unique, found duplicates for: " + "{'FP1'}. Applying running numbers for duplicates." 
+ ) with pytest.warns(RuntimeWarning, match=msg): read_raw_nihon(fname) diff --git a/mne/io/nirx/_localized_abbr.py b/mne/io/nirx/_localized_abbr.py index 4e42f7ddcff..c12133ef994 100644 --- a/mne/io/nirx/_localized_abbr.py +++ b/mne/io/nirx/_localized_abbr.py @@ -41,20 +41,104 @@ """ _localized_abbr = { - 'en_US.utf8': { - "month": {'jan': 'jan', 'feb': 'feb', 'mar': 'mar', 'apr': 'apr', 'may': 'may', 'jun': 'jun', 'jul': 'jul', 'aug': 'aug', 'sep': 'sep', 'oct': 'oct', 'nov': 'nov', 'dec': 'dec', }, # noqa - "weekday": {'sat': 'sat', 'sun': 'sun', 'mon': 'mon', 'tue': 'tue', 'wed': 'wed', 'thu': 'thu', 'fri': 'fri', }, # noqa + "en_US.utf8": { + "month": { + "jan": "jan", + "feb": "feb", + "mar": "mar", + "apr": "apr", + "may": "may", + "jun": "jun", + "jul": "jul", + "aug": "aug", + "sep": "sep", + "oct": "oct", + "nov": "nov", + "dec": "dec", + }, # noqa + "weekday": { + "sat": "sat", + "sun": "sun", + "mon": "mon", + "tue": "tue", + "wed": "wed", + "thu": "thu", + "fri": "fri", + }, # noqa }, - 'de_DE': { - "month": {'jan': 'jan', 'feb': 'feb', 'mär': 'mar', 'apr': 'apr', 'mai': 'may', 'jun': 'jun', 'jul': 'jul', 'aug': 'aug', 'sep': 'sep', 'okt': 'oct', 'nov': 'nov', 'dez': 'dec', }, # noqa - "weekday": {'sa': 'sat', 'so': 'sun', 'mo': 'mon', 'di': 'tue', 'mi': 'wed', 'do': 'thu', 'fr': 'fri', }, # noqa + "de_DE": { + "month": { + "jan": "jan", + "feb": "feb", + "mär": "mar", + "apr": "apr", + "mai": "may", + "jun": "jun", + "jul": "jul", + "aug": "aug", + "sep": "sep", + "okt": "oct", + "nov": "nov", + "dez": "dec", + }, # noqa + "weekday": { + "sa": "sat", + "so": "sun", + "mo": "mon", + "di": "tue", + "mi": "wed", + "do": "thu", + "fr": "fri", + }, # noqa }, - 'fr_FR': { - "month": {'janv.': 'jan', 'févr.': 'feb', 'mars': 'mar', 'avril': 'apr', 'mai': 'may', 'juin': 'jun', 'juil.': 'jul', 'août': 'aug', 'sept.': 'sep', 'oct.': 'oct', 'nov.': 'nov', 'déc.': 'dec', }, # noqa - "weekday": {'sam.': 'sat', 'dim.': 'sun', 'lun.': 'mon', 'mar.': 'tue', 'mer.': 'wed', 'jeu.': 'thu', 'ven.': 'fri', }, # noqa + "fr_FR": { + "month": { + "janv.": "jan", + "févr.": "feb", + "mars": "mar", + "avril": "apr", + "mai": "may", + "juin": "jun", + "juil.": "jul", + "août": "aug", + "sept.": "sep", + "oct.": "oct", + "nov.": "nov", + "déc.": "dec", + }, # noqa + "weekday": { + "sam.": "sat", + "dim.": "sun", + "lun.": "mon", + "mar.": "tue", + "mer.": "wed", + "jeu.": "thu", + "ven.": "fri", + }, # noqa }, - 'it_IT': { - "month": {'gen': 'jan', 'feb': 'feb', 'mar': 'mar', 'apr': 'apr', 'mag': 'may', 'giu': 'jun', 'lug': 'jul', 'ago': 'aug', 'set': 'sep', 'ott': 'oct', 'nov': 'nov', 'dic': 'dec', }, # noqa - "weekday": {'sab': 'sat', 'dom': 'sun', 'lun': 'mon', 'mar': 'tue', 'mer': 'wed', 'gio': 'thu', 'ven': 'fri', }, # noqa + "it_IT": { + "month": { + "gen": "jan", + "feb": "feb", + "mar": "mar", + "apr": "apr", + "mag": "may", + "giu": "jun", + "lug": "jul", + "ago": "aug", + "set": "sep", + "ott": "oct", + "nov": "nov", + "dic": "dec", + }, # noqa + "weekday": { + "sab": "sat", + "dom": "sun", + "lun": "mon", + "mar": "tue", + "mer": "wed", + "gio": "thu", + "ven": "fri", + }, # noqa }, } diff --git a/mne/io/nirx/nirx.py b/mne/io/nirx/nirx.py index 3704379b28a..ef3013d3cac 100644 --- a/mne/io/nirx/nirx.py +++ b/mne/io/nirx/nirx.py @@ -19,12 +19,20 @@ from ...annotations import Annotations from ..._freesurfer import get_mni_fiducials from ...transforms import apply_trans, _get_trans -from ...utils import (logger, verbose, fill_doc, warn, _check_fname, - _validate_type, _check_option, 
_mask_to_onsets_offsets) +from ...utils import ( + logger, + verbose, + fill_doc, + warn, + _check_fname, + _validate_type, + _check_option, + _mask_to_onsets_offsets, +) @fill_doc -def read_raw_nirx(fname, saturated='annotate', preload=False, verbose=None): +def read_raw_nirx(fname, saturated="annotate", preload=False, verbose=None): """Reader for a NIRX fNIRS recording. Parameters @@ -53,7 +61,7 @@ def read_raw_nirx(fname, saturated='annotate', preload=False, verbose=None): def _open(fname): - return open(fname, 'r', encoding='latin-1') + return open(fname, "r", encoding="latin-1") @fill_doc @@ -80,17 +88,18 @@ class RawNIRX(BaseRaw): @verbose def __init__(self, fname, saturated, preload=False, verbose=None): from scipy.io import loadmat - logger.info('Loading %s' % fname) - _validate_type(fname, 'path-like', 'fname') - _validate_type(saturated, str, 'saturated') - _check_option('saturated', saturated, ('annotate', 'nan', 'ignore')) + + logger.info("Loading %s" % fname) + _validate_type(fname, "path-like", "fname") + _validate_type(saturated, str, "saturated") + _check_option("saturated", saturated, ("annotate", "nan", "ignore")) fname = str(fname) - if fname.endswith('.hdr'): + if fname.endswith(".hdr"): fname = op.dirname(op.abspath(fname)) fname = str(_check_fname(fname, "read", True, "fname", need_dir=True)) - json_config = glob.glob('%s/*%s' % (fname, "config.json")) + json_config = glob.glob("%s/*%s" % (fname, "config.json")) if len(json_config): is_aurora = True else: @@ -98,93 +107,118 @@ def __init__(self, fname, saturated, preload=False, verbose=None): if is_aurora: # NIRSport2 devices using Aurora software - keys = ('hdr', 'config.json', 'description.json', - 'wl1', 'wl2', 'probeInfo.mat', 'tri') + keys = ( + "hdr", + "config.json", + "description.json", + "wl1", + "wl2", + "probeInfo.mat", + "tri", + ) else: # NIRScout devices and NIRSport1 devices - keys = ('hdr', 'inf', 'set', 'tpl', 'wl1', 'wl2', - 'config.txt', 'probeInfo.mat') - n_dat = len(glob.glob('%s/*%s' % (fname, 'dat'))) + keys = ( + "hdr", + "inf", + "set", + "tpl", + "wl1", + "wl2", + "config.txt", + "probeInfo.mat", + ) + n_dat = len(glob.glob("%s/*%s" % (fname, "dat"))) if n_dat != 1: - warn("A single dat file was expected in the specified path, " - f"but got {n_dat}. This may indicate that the file " - "structure has been modified since the measurement " - "was saved.") + warn( + "A single dat file was expected in the specified path, " + f"but got {n_dat}. This may indicate that the file " + "structure has been modified since the measurement " + "was saved." + ) # Check if required files exist and store names for later use files = dict() nan_mask = dict() for key in keys: - files[key] = glob.glob('%s/*%s' % (fname, key)) + files[key] = glob.glob("%s/*%s" % (fname, key)) fidx = 0 if len(files[key]) != 1: - if key not in ('wl1', 'wl2'): - raise RuntimeError( - f'Need one {key} file, got {len(files[key])}') - noidx = np.where(['nosatflags_' in op.basename(x) - for x in files[key]])[0] + if key not in ("wl1", "wl2"): + raise RuntimeError(f"Need one {key} file, got {len(files[key])}") + noidx = np.where(["nosatflags_" in op.basename(x) for x in files[key]])[ + 0 + ] if len(noidx) != 1 or len(files[key]) != 2: raise RuntimeError( - f'Need one nosatflags and one standard {key} file, ' - f'got {len(files[key])}') + f"Need one nosatflags and one standard {key} file, " + f"got {len(files[key])}" + ) # Here two files have been found, one that is called # no sat flags. The nosatflag file has no NaNs in it. 
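
[Editor's note: a minimal usage sketch of the three `saturated` modes this branch feeds — paths and directory contents are hypothetical, not from this patch:

    import numpy as np
    from mne.io import read_raw_nirx

    sample_dir = "/path/to/nirscout_recording"  # hypothetical dir with nosatflags_* + standard wl files
    raw_ign = read_raw_nirx(sample_dir, saturated="ignore")    # read the nosatflags_* data
    raw_nan = read_raw_nirx(sample_dir, saturated="nan")       # keep NaN at saturated samples
    raw_ann = read_raw_nirx(sample_dir, saturated="annotate")  # default: mark BAD_SATURATED
    data_nan = raw_nan.get_data()  # NaN wherever the device flagged saturation (if any)
    data_ann = raw_ann.get_data()  # no NaNs; saturation recorded as annotations instead
]
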
noidx = noidx[0] - if saturated == 'ignore': + if saturated == "ignore": # Ignore NaN and return values fidx = noidx - elif saturated == 'nan': + elif saturated == "nan": # Return NaN fidx = 0 if noidx == 1 else 1 else: - assert saturated == 'annotate' # guaranteed above + assert saturated == "annotate" # guaranteed above fidx = noidx nan_mask[key] = files[key][0 if noidx == 1 else 1] files[key] = files[key][fidx] # Read number of rows/samples of wavelength data - with _open(files['wl1']) as fid: - last_sample = fid.read().count('\n') - 1 + with _open(files["wl1"]) as fid: + last_sample = fid.read().count("\n") - 1 # Read header file # The header file isn't compliant with the configparser. So all the # text between comments must be removed before passing to parser - with _open(files['hdr']) as f: + with _open(files["hdr"]) as f: hdr_str_all = f.read() - hdr_str = re.sub('#.*?#', '', hdr_str_all, flags=re.DOTALL) + hdr_str = re.sub("#.*?#", "", hdr_str_all, flags=re.DOTALL) if is_aurora: - hdr_str = re.sub('(\\[DataStructure].*)', '', - hdr_str, flags=re.DOTALL) + hdr_str = re.sub("(\\[DataStructure].*)", "", hdr_str, flags=re.DOTALL) hdr = RawConfigParser() hdr.read_string(hdr_str) # Check that the file format version is supported if is_aurora: # We may need to ease this requirement back - if hdr['GeneralInfo']['Version'] not in ['2021.4.0-34-ge9fdbbc8', - '2021.9.0-5-g3eb32851', - '2021.9.0-6-g14ef4a71']: - warn("MNE has not been tested with Aurora version " - f"{hdr['GeneralInfo']['Version']}") + if hdr["GeneralInfo"]["Version"] not in [ + "2021.4.0-34-ge9fdbbc8", + "2021.9.0-5-g3eb32851", + "2021.9.0-6-g14ef4a71", + ]: + warn( + "MNE has not been tested with Aurora version " + f"{hdr['GeneralInfo']['Version']}" + ) else: - if hdr['GeneralInfo']['NIRStar'] not in ['"15.0"', '"15.2"', - '"15.3"']: - raise RuntimeError('MNE does not support this NIRStar version' - ' (%s)' % (hdr['GeneralInfo']['NIRStar'],)) - if "NIRScout" not in hdr['GeneralInfo']['Device'] \ - and "NIRSport" not in hdr['GeneralInfo']['Device']: - warn("Only import of data from NIRScout devices have been " - "thoroughly tested. You are using a %s device. " % - hdr['GeneralInfo']['Device']) + if hdr["GeneralInfo"]["NIRStar"] not in ['"15.0"', '"15.2"', '"15.3"']: + raise RuntimeError( + "MNE does not support this NIRStar version" + " (%s)" % (hdr["GeneralInfo"]["NIRStar"],) + ) + if ( + "NIRScout" not in hdr["GeneralInfo"]["Device"] + and "NIRSport" not in hdr["GeneralInfo"]["Device"] + ): + warn( + "Only import of data from NIRScout devices have been " + "thoroughly tested. You are using a %s device. " + % hdr["GeneralInfo"]["Device"] + ) # Parse required header fields # Extract measurement date and time if is_aurora: - datetime_str = hdr['GeneralInfo']['Date'] + datetime_str = hdr["GeneralInfo"]["Date"] else: - datetime_str = hdr['GeneralInfo']['Date'] + \ - hdr['GeneralInfo']['Time'] + datetime_str = hdr["GeneralInfo"]["Date"] + hdr["GeneralInfo"]["Time"] meas_date = None # Several formats have been observed so we try each in turn @@ -193,19 +227,20 @@ def __init__(self, fname, saturated, preload=False, verbose=None): # So far we are lucky in that all the formats below, if they # include %a (weekday abbr), always come first. Thus we can use # a .split(), replace, and rejoin. 
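
[Editor's note: a sketch of the split/replace/rejoin normalization described in the
comment above, using the fr_FR entries from the `_localized_abbr` table added earlier
in this patch; the header value shown is a made-up example, not the exact code path:

    from datetime import datetime, timezone

    translations = {"weekday": {"jeu.": "thu"}, "month": {"févr.": "feb"}}
    datetime_str = '"jeu. 13 févr. 2020""11:35:05.813"'  # Date + Time, hypothetical
    parts = datetime_str.split(" ")
    # Weekday abbreviation always comes first, so only parts[0] needs it
    for key, val in translations["weekday"].items():
        parts[0] = parts[0].replace(key, val)
    # Month abbreviation can appear in any later token
    for ii in range(1, len(parts)):
        for key, val in translations["month"].items():
            parts[ii] = parts[ii].replace(key, val)
    # '"thu 13 feb 2020""11:35:05.813"' now matches one of the dt_code formats below
    meas_date = datetime.strptime(" ".join(parts), '"%a %d %b %Y""%H:%M:%S.%f"')
    meas_date = meas_date.replace(tzinfo=timezone.utc)
]
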
- loc_datetime_str = datetime_str.split(' ') - for key, val in translations['weekday'].items(): + loc_datetime_str = datetime_str.split(" ") + for key, val in translations["weekday"].items(): loc_datetime_str[0] = loc_datetime_str[0].replace(key, val) for ii in range(1, len(loc_datetime_str)): - for key, val in translations['month'].items(): - loc_datetime_str[ii] = \ - loc_datetime_str[ii].replace(key, val) - loc_datetime_str = ' '.join(loc_datetime_str) - logger.debug(f'Trying {loc} datetime: {loc_datetime_str}') - for dt_code in ['"%a, %b %d, %Y""%H:%M:%S.%f"', - '"%a %d %b %Y""%H:%M:%S.%f"', - '"%a, %d %b %Y""%H:%M:%S.%f"', - '%Y-%m-%d %H:%M:%S.%f']: + for key, val in translations["month"].items(): + loc_datetime_str[ii] = loc_datetime_str[ii].replace(key, val) + loc_datetime_str = " ".join(loc_datetime_str) + logger.debug(f"Trying {loc} datetime: {loc_datetime_str}") + for dt_code in [ + '"%a, %b %d, %Y""%H:%M:%S.%f"', + '"%a %d %b %Y""%H:%M:%S.%f"', + '"%a, %d %b %Y""%H:%M:%S.%f"', + "%Y-%m-%d %H:%M:%S.%f", + ]: try: meas_date = dt.datetime.strptime(loc_datetime_str, dt_code) except ValueError: @@ -213,61 +248,71 @@ def __init__(self, fname, saturated, preload=False, verbose=None): else: meas_date = meas_date.replace(tzinfo=dt.timezone.utc) do_break = True - logger.debug( - f'Measurement date language {loc} detected: {dt_code}') + logger.debug(f"Measurement date language {loc} detected: {dt_code}") break if do_break: break if meas_date is None: - warn("Extraction of measurement date from NIRX file failed. " - "This can be caused by files saved in certain locales " - f"(currently only {list(_localized_abbr)} supported). " - "Please report this as a github issue. " - "The date is being set to January 1st, 2000, " - f"instead of {repr(datetime_str)}.") - meas_date = dt.datetime(2000, 1, 1, 0, 0, 0, - tzinfo=dt.timezone.utc) + warn( + "Extraction of measurement date from NIRX file failed. " + "This can be caused by files saved in certain locales " + f"(currently only {list(_localized_abbr)} supported). " + "Please report this as a github issue. " + "The date is being set to January 1st, 2000, " + f"instead of {repr(datetime_str)}." 
+ ) + meas_date = dt.datetime(2000, 1, 1, 0, 0, 0, tzinfo=dt.timezone.utc) # Extract frequencies of light used by machine if is_aurora: fnirs_wavelengths = [760, 850] else: - fnirs_wavelengths = [int(s) for s in - re.findall(r'(\d+)', - hdr['ImagingParameters'][ - 'Wavelengths'])] + fnirs_wavelengths = [ + int(s) + for s in re.findall(r"(\d+)", hdr["ImagingParameters"]["Wavelengths"]) + ] # Extract source-detectors if is_aurora: - sources = re.findall(r'(\d+)-\d+', hdr_str_all.split("\n")[-2]) - detectors = re.findall(r'\d+-(\d+)', hdr_str_all.split("\n")[-2]) + sources = re.findall(r"(\d+)-\d+", hdr_str_all.split("\n")[-2]) + detectors = re.findall(r"\d+-(\d+)", hdr_str_all.split("\n")[-2]) sources = [int(s) + 1 for s in sources] detectors = [int(d) + 1 for d in detectors] else: - sources = np.asarray([int(s) for s in - re.findall(r'(\d+)-\d+:\d+', - hdr['DataStructure'] - ['S-D-Key'])], int) - detectors = np.asarray([int(s) for s in - re.findall(r'\d+-(\d+):\d+', - hdr['DataStructure'] - ['S-D-Key'])], int) + sources = np.asarray( + [ + int(s) + for s in re.findall( + r"(\d+)-\d+:\d+", hdr["DataStructure"]["S-D-Key"] + ) + ], + int, + ) + detectors = np.asarray( + [ + int(s) + for s in re.findall( + r"\d+-(\d+):\d+", hdr["DataStructure"]["S-D-Key"] + ) + ], + int, + ) # Extract sampling rate if is_aurora: - samplingrate = float(hdr['GeneralInfo']['Sampling rate']) + samplingrate = float(hdr["GeneralInfo"]["Sampling rate"]) else: - samplingrate = float(hdr['ImagingParameters']['SamplingRate']) + samplingrate = float(hdr["ImagingParameters"]["SamplingRate"]) # Read participant information file if is_aurora: - with open(files['description.json']) as f: + with open(files["description.json"]) as f: inf = json.load(f) else: inf = ConfigParser(allow_no_value=True) - inf.read(files['inf']) - inf = inf._sections['Subject Demographics'] + inf.read(files["inf"]) + inf = inf._sections["Subject Demographics"] # Store subject information from inf file in mne format # Note: NIRX also records "Study Type", "Experiment History", @@ -279,29 +324,28 @@ def __init__(self, fname, saturated, preload=False, verbose=None): if is_aurora: names = inf["subject"].split() else: - names = inf['name'].replace('"', "").split() - subject_info['his_id'] = "_".join(names) + names = inf["name"].replace('"', "").split() + subject_info["his_id"] = "_".join(names) if len(names) > 0: - subject_info['first_name'] = \ - names[0].replace("\"", "") + subject_info["first_name"] = names[0].replace('"', "") if len(names) > 1: - subject_info['last_name'] = \ - names[-1].replace("\"", "") + subject_info["last_name"] = names[-1].replace('"', "") if len(names) > 2: - subject_info['middle_name'] = \ - names[-2].replace("\"", "") - subject_info['sex'] = inf['gender'].replace("\"", "") + subject_info["middle_name"] = names[-2].replace('"', "") + subject_info["sex"] = inf["gender"].replace('"', "") # Recode values - if subject_info['sex'] in {'M', 'Male', '1'}: - subject_info['sex'] = FIFF.FIFFV_SUBJ_SEX_MALE - elif subject_info['sex'] in {'F', 'Female', '2'}: - subject_info['sex'] = FIFF.FIFFV_SUBJ_SEX_FEMALE + if subject_info["sex"] in {"M", "Male", "1"}: + subject_info["sex"] = FIFF.FIFFV_SUBJ_SEX_MALE + elif subject_info["sex"] in {"F", "Female", "2"}: + subject_info["sex"] = FIFF.FIFFV_SUBJ_SEX_FEMALE else: - subject_info['sex'] = FIFF.FIFFV_SUBJ_SEX_UNKNOWN - if inf['age'] != '': - subject_info['birthday'] = (meas_date.year - int(inf['age']), - meas_date.month, - meas_date.day) + subject_info["sex"] = FIFF.FIFFV_SUBJ_SEX_UNKNOWN + if 
inf["age"] != "": + subject_info["birthday"] = ( + meas_date.year - int(inf["age"]), + meas_date.month, + meas_date.day, + ) # Read information about probe/montage/optodes # A word on terminology used here: @@ -310,30 +354,33 @@ def __init__(self, fname, saturated, preload=False, verbose=None): # Sources and detectors are both called optodes # Each source - detector pair produces a channel # Channels are defined as the midpoint between source and detector - mat_data = loadmat(files['probeInfo.mat']) - probes = mat_data['probeInfo']['probes'][0, 0] - requested_channels = probes['index_c'][0, 0] - src_locs = probes['coords_s3'][0, 0] / 100. - det_locs = probes['coords_d3'][0, 0] / 100. - ch_locs = probes['coords_c3'][0, 0] / 100. + mat_data = loadmat(files["probeInfo.mat"]) + probes = mat_data["probeInfo"]["probes"][0, 0] + requested_channels = probes["index_c"][0, 0] + src_locs = probes["coords_s3"][0, 0] / 100.0 + det_locs = probes["coords_d3"][0, 0] / 100.0 + ch_locs = probes["coords_c3"][0, 0] / 100.0 # These are all in MNI coordinates, so let's transform them to # the Neuromag head coordinate frame src_locs, det_locs, ch_locs, mri_head_t = _convert_fnirs_to_head( - 'fsaverage', 'mri', 'head', src_locs, det_locs, ch_locs) + "fsaverage", "mri", "head", src_locs, det_locs, ch_locs + ) # Set up digitization - dig = get_mni_fiducials('fsaverage', verbose=False) + dig = get_mni_fiducials("fsaverage", verbose=False) for fid in dig: - fid['r'] = apply_trans(mri_head_t, fid['r']) - fid['coord_frame'] = FIFF.FIFFV_COORD_HEAD + fid["r"] = apply_trans(mri_head_t, fid["r"]) + fid["coord_frame"] = FIFF.FIFFV_COORD_HEAD for ii, ch_loc in enumerate(ch_locs, 1): - dig.append(dict( - kind=FIFF.FIFFV_POINT_EEG, # misnomer but probably okay - r=ch_loc, - ident=ii, - coord_frame=FIFF.FIFFV_COORD_HEAD, - )) + dig.append( + dict( + kind=FIFF.FIFFV_POINT_EEG, # misnomer but probably okay + r=ch_loc, + ident=ii, + coord_frame=FIFF.FIFFV_COORD_HEAD, + ) + ) dig = _format_dig_points(dig) del mri_head_t @@ -343,25 +390,25 @@ def __init__(self, fname, saturated, preload=False, verbose=None): # subset requested in the probe file req_ind = np.array([], int) for req_idx in range(requested_channels.shape[0]): - sd_idx = np.where((sources == requested_channels[req_idx][0]) & - (detectors == requested_channels[req_idx][1])) + sd_idx = np.where( + (sources == requested_channels[req_idx][0]) + & (detectors == requested_channels[req_idx][1]) + ) req_ind = np.concatenate((req_ind, sd_idx[0])) req_ind = req_ind.astype(int) snames = [f"S{sources[idx]}" for idx in req_ind] dnames = [f"_D{detectors[idx]}" for idx in req_ind] sdnames = [m + str(n) for m, n in zip(snames, dnames)] - sd1 = [s + ' ' + str(fnirs_wavelengths[0]) for s in sdnames] - sd2 = [s + ' ' + str(fnirs_wavelengths[1]) for s in sdnames] + sd1 = [s + " " + str(fnirs_wavelengths[0]) for s in sdnames] + sd2 = [s + " " + str(fnirs_wavelengths[1]) for s in sdnames] chnames = [val for pair in zip(sd1, sd2) for val in pair] # Create mne structure - info = create_info(chnames, - samplingrate, - ch_types='fnirs_cw_amplitude') + info = create_info(chnames, samplingrate, ch_types="fnirs_cw_amplitude") with info._unlock(): info.update(subject_info=subject_info, dig=dig) - info['meas_date'] = meas_date + info["meas_date"] = meas_date # Store channel, source, and detector locations # The channel location is stored in the first 3 entries of loc. 
@@ -378,11 +425,11 @@ def __init__(self, fname, saturated, preload=False, verbose=None): midpoint = (src_locs[src, :] + det_locs[det, :]) / 2 for ii in range(2): ch_idx3 = ch_idx2 * 2 + ii - info['chs'][ch_idx3]['loc'][3:6] = src_locs[src, :] - info['chs'][ch_idx3]['loc'][6:9] = det_locs[det, :] - info['chs'][ch_idx3]['loc'][:3] = midpoint - info['chs'][ch_idx3]['loc'][9] = fnirs_wavelengths[ii] - info['chs'][ch_idx3]['coord_frame'] = FIFF.FIFFV_COORD_HEAD + info["chs"][ch_idx3]["loc"][3:6] = src_locs[src, :] + info["chs"][ch_idx3]["loc"][6:9] = det_locs[det, :] + info["chs"][ch_idx3]["loc"][:3] = midpoint + info["chs"][ch_idx3]["loc"][9] = fnirs_wavelengths[ii] + info["chs"][ch_idx3]["coord_frame"] = FIFF.FIFFV_COORD_HEAD # Extract the start/stop numbers for samples in the CSV. In theory the # sample bounds should just be 10 * the number of channels, but some @@ -390,10 +437,10 @@ def __init__(self, fname, saturated, preload=False, verbose=None): # instead make a single pass over the entire file at the beginning so # that we know how to seek and read later. bounds = dict() - for key in ('wl1', 'wl2'): + for key in ("wl1", "wl2"): offset = 0 bounds[key] = [offset] - with open(files[key], 'rb') as fid: + with open(files[key], "rb") as fid: for line in fid: offset += len(line) bounds[key].append(offset) @@ -401,51 +448,60 @@ def __init__(self, fname, saturated, preload=False, verbose=None): # Extras required for reading data raw_extras = { - 'sd_index': req_ind, - 'files': files, - 'bounds': bounds, - 'nan_mask': nan_mask, + "sd_index": req_ind, + "files": files, + "bounds": bounds, + "nan_mask": nan_mask, } # Get our saturated mask annot_mask = None - for ki, key in enumerate(('wl1', 'wl2')): + for ki, key in enumerate(("wl1", "wl2")): if nan_mask.get(key, None) is None: continue - mask = np.isnan(_read_csv_rows_cols( - nan_mask[key], 0, last_sample + 1, req_ind, {0: 0, 1: None}).T) - if saturated == 'nan': + mask = np.isnan( + _read_csv_rows_cols( + nan_mask[key], 0, last_sample + 1, req_ind, {0: 0, 1: None} + ).T + ) + if saturated == "nan": nan_mask[key] = mask else: - assert saturated == 'annotate' + assert saturated == "annotate" if annot_mask is None: annot_mask = np.zeros( - (len(info['ch_names']) // 2, last_sample + 1), bool) + (len(info["ch_names"]) // 2, last_sample + 1), bool + ) annot_mask |= mask nan_mask[key] = None # shouldn't need again super(RawNIRX, self).__init__( - info, preload, filenames=[fname], last_samps=[last_sample], - raw_extras=[raw_extras], verbose=verbose) + info, + preload, + filenames=[fname], + last_samps=[last_sample], + raw_extras=[raw_extras], + verbose=verbose, + ) # make onset/duration/description onset, duration, description, ch_names = list(), list(), list(), list() if annot_mask is not None: for ci, mask in enumerate(annot_mask): on, dur = _mask_to_onsets_offsets(mask) - on = on / info['sfreq'] - dur = dur / info['sfreq'] + on = on / info["sfreq"] + dur = dur / info["sfreq"] dur -= on onset.extend(on) duration.extend(dur) - description.extend(['BAD_SATURATED'] * len(on)) - ch_names.extend([self.ch_names[2 * ci:2 * ci + 2]] * len(on)) + description.extend(["BAD_SATURATED"] * len(on)) + ch_names.extend([self.ch_names[2 * ci : 2 * ci + 2]] * len(on)) # Read triggers from event file if not is_aurora: - files['tri'] = files['hdr'][:-3] + 'evt' - if op.isfile(files['tri']): - with _open(files['tri']) as fid: - t = [re.findall(r'(\d+)', line) for line in fid] + files["tri"] = files["hdr"][:-3] + "evt" + if op.isfile(files["tri"]): + with 
_open(files["tri"]) as fid: + t = [re.findall(r"(\d+)", line) for line in fid] if is_aurora: tf_idx, desc_idx = _determine_tri_idxs(t[0]) for t_ in t: @@ -453,11 +509,11 @@ def __init__(self, fname, saturated, preload=False, verbose=None): trigger_frame = float(t_[tf_idx]) desc = float(t_[desc_idx]) else: - binary_value = ''.join(t_[1:])[::-1] + binary_value = "".join(t_[1:])[::-1] desc = float(int(binary_value, 2)) trigger_frame = float(t_[0]) onset.append(trigger_frame / samplingrate) - duration.append(1.) # No duration info stored in files + duration.append(1.0) # No duration info stored in files description.append(desc) ch_names.append(list()) annot = Annotations(onset, duration, description, ch_names=ch_names) @@ -469,15 +525,18 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): The NIRX machine records raw data as two different wavelengths. The returned data interleaves the wavelengths. """ - sd_index = self._raw_extras[fi]['sd_index'] + sd_index = self._raw_extras[fi]["sd_index"] wls = list() - for key in ('wl1', 'wl2'): + for key in ("wl1", "wl2"): d = _read_csv_rows_cols( - self._raw_extras[fi]['files'][key], - start, stop, sd_index, - self._raw_extras[fi]['bounds'][key]).T - nan_mask = self._raw_extras[fi]['nan_mask'].get(key, None) + self._raw_extras[fi]["files"][key], + start, + stop, + sd_index, + self._raw_extras[fi]["bounds"][key], + ).T + nan_mask = self._raw_extras[fi]["nan_mask"].get(key, None) if nan_mask is not None: d[nan_mask[:, start:stop]] = np.nan wls.append(d) @@ -492,14 +551,13 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): return data -def _read_csv_rows_cols(fname, start, stop, cols, bounds, - sep=' ', replace=None): - with open(fname, 'rb') as fid: +def _read_csv_rows_cols(fname, start, stop, cols, bounds, sep=" ", replace=None): + with open(fname, "rb") as fid: fid.seek(bounds[start]) args = list() if bounds[1] is not None: args.append(bounds[stop] - bounds[start]) - data = fid.read(*args).decode('latin-1') + data = fid.read(*args).decode("latin-1") if replace is not None: data = replace(data) x = np.fromstring(data, float, sep=sep) diff --git a/mne/io/nirx/tests/test_nirx.py b/mne/io/nirx/tests/test_nirx.py index de7a78b0cca..02b4865867f 100644 --- a/mne/io/nirx/tests/test_nirx.py +++ b/mne/io/nirx/tests/test_nirx.py @@ -16,8 +16,11 @@ from mne.io.tests.test_raw import _test_raw_reader from mne.preprocessing import annotate_nan from mne.transforms import apply_trans, _get_trans -from mne.preprocessing.nirs import source_detector_distances,\ - short_channels, _reorder_nirx +from mne.preprocessing.nirs import ( + source_detector_distances, + short_channels, + _reorder_nirx, +) from mne.io.constants import FIFF testing_path = data_path(download=False) @@ -26,9 +29,7 @@ fname_nirx_15_2_short = ( testing_path / "NIRx" / "nirscout" / "nirx_15_2_recording_w_short" ) -fname_nirx_15_3_short = ( - testing_path / "NIRx" / "nirscout" / "nirx_15_3_recording" -) +fname_nirx_15_3_short = testing_path / "NIRx" / "nirscout" / "nirx_15_3_recording" # This file has no saturated sections @@ -51,9 +52,7 @@ ) # NIRSport2 device using Aurora software -nirsport2 = ( - testing_path / "NIRx" / "nirsport_v2" / "aurora_recording _w_short_and_acc" -) +nirsport2 = testing_path / "NIRx" / "nirsport_v2" / "aurora_recording _w_short_and_acc" nirsport2_2021_9 = testing_path / "NIRx" / "nirsport_v2" / "aurora_2021_9" nirsport2_2021_9_6 = testing_path / "NIRx" / "nirsport_v2" / "aurora_2021_9_6" @@ -77,7 +76,7 @@ def 
test_nirsport_v2_matches_snirf(nirx_snirf): @requires_testing_data -@pytest.mark.filterwarnings('ignore:.*Extraction of measurement.*:') +@pytest.mark.filterwarnings("ignore:.*Extraction of measurement.*:") def test_nirsport_v2(): """Test NIRSport2 file.""" raw = read_raw_nirx(nirsport2, preload=True) @@ -87,90 +86,87 @@ def test_nirsport_v2(): # nirsite https://github.com/mne-tools/mne-testing-data/pull/86 # figure 3 allowed_distance_error = 0.005 - assert_allclose(source_detector_distances(raw.copy(). - pick("S1_D1 760").info), - [0.0304], atol=allowed_distance_error) - assert_allclose(source_detector_distances(raw.copy(). - pick("S2_D2 760").info), - [0.0400], atol=allowed_distance_error) + assert_allclose( + source_detector_distances(raw.copy().pick("S1_D1 760").info), + [0.0304], + atol=allowed_distance_error, + ) + assert_allclose( + source_detector_distances(raw.copy().pick("S2_D2 760").info), + [0.0400], + atol=allowed_distance_error, + ) # Test location of detectors # The locations of detectors can be seen in the first # figure on this page... # https://github.com/mne-tools/mne-testing-data/pull/86 allowed_dist_error = 0.0002 - locs = [ch['loc'][6:9] for ch in raw.info['chs']] - head_mri_t, _ = _get_trans('fsaverage', 'head', 'mri') + locs = [ch["loc"][6:9] for ch in raw.info["chs"]] + head_mri_t, _ = _get_trans("fsaverage", "head", "mri") mni_locs = apply_trans(head_mri_t, locs) - assert raw.info['ch_names'][0][3:5] == 'D1' - assert_allclose( - mni_locs[0], [-0.0841, -0.0464, -0.0129], atol=allowed_dist_error) + assert raw.info["ch_names"][0][3:5] == "D1" + assert_allclose(mni_locs[0], [-0.0841, -0.0464, -0.0129], atol=allowed_dist_error) - assert raw.info['ch_names'][2][3:5] == 'D6' - assert_allclose( - mni_locs[2], [-0.0841, -0.0138, 0.0248], atol=allowed_dist_error) + assert raw.info["ch_names"][2][3:5] == "D6" + assert_allclose(mni_locs[2], [-0.0841, -0.0138, 0.0248], atol=allowed_dist_error) - assert raw.info['ch_names'][34][3:5] == 'D5' - assert_allclose( - mni_locs[34], [0.0845, -0.0451, -0.0123], atol=allowed_dist_error) + assert raw.info["ch_names"][34][3:5] == "D5" + assert_allclose(mni_locs[34], [0.0845, -0.0451, -0.0123], atol=allowed_dist_error) # Test location of sensors # The locations of sensors can be seen in the second # figure on this page... 
# https://github.com/mne-tools/mne-testing-data/pull/86 - locs = [ch['loc'][3:6] for ch in raw.info['chs']] - head_mri_t, _ = _get_trans('fsaverage', 'head', 'mri') + locs = [ch["loc"][3:6] for ch in raw.info["chs"]] + head_mri_t, _ = _get_trans("fsaverage", "head", "mri") mni_locs = apply_trans(head_mri_t, locs) - assert raw.info['ch_names'][0][:2] == 'S1' - assert_allclose( - mni_locs[0], [-0.0848, -0.0162, -0.0163], atol=allowed_dist_error) + assert raw.info["ch_names"][0][:2] == "S1" + assert_allclose(mni_locs[0], [-0.0848, -0.0162, -0.0163], atol=allowed_dist_error) - assert raw.info['ch_names'][9][:2] == 'S2' - assert_allclose( - mni_locs[9], [-0.0, -0.1195, 0.0142], atol=allowed_dist_error) + assert raw.info["ch_names"][9][:2] == "S2" + assert_allclose(mni_locs[9], [-0.0, -0.1195, 0.0142], atol=allowed_dist_error) - assert raw.info['ch_names'][39][:2] == 'S8' - assert_allclose( - mni_locs[34], [0.0828, -0.046, 0.0285], atol=allowed_dist_error) + assert raw.info["ch_names"][39][:2] == "S8" + assert_allclose(mni_locs[34], [0.0828, -0.046, 0.0285], atol=allowed_dist_error) assert len(raw.annotations) == 3 - assert raw.annotations.description[0] == '1.0' - assert raw.annotations.description[2] == '6.0' + assert raw.annotations.description[0] == "1.0" + assert raw.annotations.description[2] == "6.0" # Lose tolerance as I am eyeballing the time differences on screen - assert_allclose( - np.diff(raw.annotations.onset), [2.3, 3.1], atol=0.1) + assert_allclose(np.diff(raw.annotations.onset), [2.3, 3.1], atol=0.1) mon = raw.get_montage() assert len(mon.dig) == 27 @requires_testing_data -@pytest.mark.filterwarnings('ignore:.*Extraction of measurement.*:') +@pytest.mark.filterwarnings("ignore:.*Extraction of measurement.*:") def test_nirsport_v1_wo_sat(): """Test NIRSport1 file with no saturation.""" raw = read_raw_nirx(nirsport1_wo_sat, preload=True) # Test data import assert raw._data.shape == (26, 164) - assert raw.info['sfreq'] == 10.416667 + assert raw.info["sfreq"] == 10.416667 # By default real data is returned assert np.sum(np.isnan(raw.get_data())) == 0 - raw = read_raw_nirx(nirsport1_wo_sat, preload=True, saturated='nan') + raw = read_raw_nirx(nirsport1_wo_sat, preload=True, saturated="nan") data = raw.get_data() assert data.shape == (26, 164) assert np.sum(np.isnan(data)) == 0 - raw = read_raw_nirx(nirsport1_wo_sat, saturated='annotate') + raw = read_raw_nirx(nirsport1_wo_sat, saturated="annotate") data = raw.get_data() assert data.shape == (26, 164) assert np.sum(np.isnan(data)) == 0 -@pytest.mark.filterwarnings('ignore:.*Extraction of measurement.*:') +@pytest.mark.filterwarnings("ignore:.*Extraction of measurement.*:") @requires_testing_data def test_nirsport_v1_w_sat(): """Test NIRSport1 file with NaNs but not in channel of interest.""" @@ -179,24 +175,24 @@ def test_nirsport_v1_w_sat(): # Test data import data = raw.get_data() assert data.shape == (26, 176) - assert raw.info['sfreq'] == 10.416667 + assert raw.info["sfreq"] == 10.416667 assert np.sum(np.isnan(data)) == 0 - raw = read_raw_nirx(nirsport1_w_sat, saturated='nan') + raw = read_raw_nirx(nirsport1_w_sat, saturated="nan") data = raw.get_data() assert data.shape == (26, 176) assert np.sum(np.isnan(data)) == 0 - raw = read_raw_nirx(nirsport1_w_sat, saturated='annotate') + raw = read_raw_nirx(nirsport1_w_sat, saturated="annotate") data = raw.get_data() assert data.shape == (26, 176) assert np.sum(np.isnan(data)) == 0 -@pytest.mark.filterwarnings('ignore:.*Extraction of measurement.*:') 
+@pytest.mark.filterwarnings("ignore:.*Extraction of measurement.*:") @requires_testing_data -@pytest.mark.parametrize('preload', (True, False)) -@pytest.mark.parametrize('meas_date', (None, "orig")) +@pytest.mark.parametrize("preload", (True, False)) +@pytest.mark.parametrize("meas_date", (None, "orig")) def test_nirsport_v1_w_bad_sat(preload, meas_date): """Test NIRSport1 file with NaNs.""" fname = nirsport1_w_fullsat @@ -205,12 +201,12 @@ def test_nirsport_v1_w_bad_sat(preload, meas_date): assert not np.isnan(data).any() assert len(raw.annotations) == 5 # annotated version and ignore should have same data but different annot - raw_ignore = read_raw_nirx(fname, saturated='ignore', preload=preload) + raw_ignore = read_raw_nirx(fname, saturated="ignore", preload=preload) assert_allclose(raw_ignore.get_data(), data) assert len(raw_ignore.annotations) == 2 - assert not any('NAN' in d for d in raw_ignore.annotations.description) + assert not any("NAN" in d for d in raw_ignore.annotations.description) # nan version should not have same data, but we can give it the same annot - raw_nan = read_raw_nirx(fname, saturated='nan', preload=preload) + raw_nan = read_raw_nirx(fname, saturated="nan", preload=preload) data_nan = raw_nan.get_data() assert np.isnan(data_nan).any() assert not np.allclose(raw_nan.get_data(), data) @@ -222,8 +218,8 @@ def test_nirsport_v1_w_bad_sat(preload, meas_date): nan_annots = annotate_nan(raw_nan) assert nan_annots.orig_time == raw_nan.info["meas_date"] raw_nan_annot.set_annotations(nan_annots) - use_mask = np.where(raw.annotations.description == 'BAD_SATURATED') - for key in ('onset', 'duration'): + use_mask = np.where(raw.annotations.description == "BAD_SATURATED") + for key in ("onset", "duration"): a = getattr(raw_nan_annot.annotations, key)[::2] # one ch in each b = getattr(raw.annotations, key)[use_mask] # two chs in each assert_allclose(a, b) @@ -237,13 +233,13 @@ def test_nirx_hdr_load(): # Test data import assert raw._data.shape == (26, 145) - assert raw.info['sfreq'] == 12.5 + assert raw.info["sfreq"] == 12.5 @requires_testing_data def test_nirx_missing_warn(): """Test reading NIRX files when missing data.""" - with pytest.raises(FileNotFoundError, match='does not exist'): + with pytest.raises(FileNotFoundError, match="does not exist"): read_raw_nirx(fname_nirx_15_2_short / "1", preload=True) @@ -251,21 +247,25 @@ def test_nirx_missing_warn(): def test_nirx_missing_evt(tmp_path): """Test reading NIRX files when missing data.""" shutil.copytree(fname_nirx_15_2_short, str(tmp_path) + "/data/") - os.rename(tmp_path / "data" / "NIRS-2019-08-23_001.evt", - tmp_path / "data" / "NIRS-2019-08-23_001.xxx") + os.rename( + tmp_path / "data" / "NIRS-2019-08-23_001.evt", + tmp_path / "data" / "NIRS-2019-08-23_001.xxx", + ) fname = tmp_path / "data" / "NIRS-2019-08-23_001.hdr" raw = read_raw_nirx(fname, preload=True) - assert raw.annotations.onset.shape == (0, ) + assert raw.annotations.onset.shape == (0,) @requires_testing_data def test_nirx_dat_warn(tmp_path): """Test reading NIRX files when missing data.""" shutil.copytree(fname_nirx_15_2_short, str(tmp_path) + "/data/") - os.rename(tmp_path / "data" / "NIRS-2019-08-23_001.dat", - tmp_path / "data" / "NIRS-2019-08-23_001.tmp") + os.rename( + tmp_path / "data" / "NIRS-2019-08-23_001.dat", + tmp_path / "data" / "NIRS-2019-08-23_001.tmp", + ) fname = tmp_path / "data" / "NIRS-2019-08-23_001.hdr" - with pytest.warns(RuntimeWarning, match='A single dat'): + with pytest.warns(RuntimeWarning, match="A single dat"): 
read_raw_nirx(fname, preload=True) @@ -276,36 +276,48 @@ def test_nirx_15_2_short(): # Test data import assert raw._data.shape == (26, 145) - assert raw.info['sfreq'] == 12.5 - assert raw.info['meas_date'] == dt.datetime(2019, 8, 23, 7, 37, 4, 540000, - tzinfo=dt.timezone.utc) + assert raw.info["sfreq"] == 12.5 + assert raw.info["meas_date"] == dt.datetime( + 2019, 8, 23, 7, 37, 4, 540000, tzinfo=dt.timezone.utc + ) # Test channel naming - assert raw.info['ch_names'][:4] == ["S1_D1 760", "S1_D1 850", - "S1_D9 760", "S1_D9 850"] - assert raw.info['ch_names'][24:26] == ["S5_D13 760", "S5_D13 850"] + assert raw.info["ch_names"][:4] == [ + "S1_D1 760", + "S1_D1 850", + "S1_D9 760", + "S1_D9 850", + ] + assert raw.info["ch_names"][24:26] == ["S5_D13 760", "S5_D13 850"] # Test frequency encoding - assert raw.info['chs'][0]['loc'][9] == 760 - assert raw.info['chs'][1]['loc'][9] == 850 + assert raw.info["chs"][0]["loc"][9] == 760 + assert raw.info["chs"][1]["loc"][9] == 850 # Test info import - assert raw.info['subject_info'] == dict(sex=1, first_name="MNE", - middle_name="Test", - last_name="Recording", - birthday=(2014, 8, 23), - his_id="MNE_Test_Recording") + assert raw.info["subject_info"] == dict( + sex=1, + first_name="MNE", + middle_name="Test", + last_name="Recording", + birthday=(2014, 8, 23), + his_id="MNE_Test_Recording", + ) # Test distance between optodes matches values from # nirsite https://github.com/mne-tools/mne-testing-data/pull/51 # step 4 figure 2 allowed_distance_error = 0.0002 - assert_allclose(source_detector_distances(raw.copy(). - pick("S1_D1 760").info), - [0.0304], atol=allowed_distance_error) - assert_allclose(source_detector_distances(raw.copy(). - pick("S2_D10 760").info), - [0.0086], atol=allowed_distance_error) + assert_allclose( + source_detector_distances(raw.copy().pick("S1_D1 760").info), + [0.0304], + atol=allowed_distance_error, + ) + assert_allclose( + source_detector_distances(raw.copy().pick("S2_D10 760").info), + [0.0086], + atol=allowed_distance_error, + ) # Test which channels are short # These are the ones marked as red at @@ -318,7 +330,7 @@ def test_nirx_15_2_short(): assert_array_equal(is_short[:3:2], [True, True]) # Test trigger events - assert_array_equal(raw.annotations.description, ['3.0', '2.0', '1.0']) + assert_array_equal(raw.annotations.description, ["3.0", "2.0", "1.0"]) # Test location of detectors # The locations of detectors can be seen in the first @@ -330,37 +342,30 @@ def test_nirx_15_2_short(): # 3d locations should be specified in meters, so that's what's tested below # Detector locations are stored in the third three loc values allowed_dist_error = 0.0002 - locs = [ch['loc'][6:9] for ch in raw.info['chs']] - head_mri_t, _ = _get_trans('fsaverage', 'head', 'mri') + locs = [ch["loc"][6:9] for ch in raw.info["chs"]] + head_mri_t, _ = _get_trans("fsaverage", "head", "mri") mni_locs = apply_trans(head_mri_t, locs) - assert raw.info['ch_names'][0][3:5] == 'D1' - assert_allclose( - mni_locs[0], [-0.0841, -0.0464, -0.0129], atol=allowed_dist_error) + assert raw.info["ch_names"][0][3:5] == "D1" + assert_allclose(mni_locs[0], [-0.0841, -0.0464, -0.0129], atol=allowed_dist_error) - assert raw.info['ch_names'][4][3:5] == 'D3' - assert_allclose( - mni_locs[4], [0.0846, -0.0142, -0.0156], atol=allowed_dist_error) + assert raw.info["ch_names"][4][3:5] == "D3" + assert_allclose(mni_locs[4], [0.0846, -0.0142, -0.0156], atol=allowed_dist_error) - assert raw.info['ch_names'][8][3:5] == 'D2' - assert_allclose( - mni_locs[8], [0.0207, -0.1062, 0.0484], 
atol=allowed_dist_error) + assert raw.info["ch_names"][8][3:5] == "D2" + assert_allclose(mni_locs[8], [0.0207, -0.1062, 0.0484], atol=allowed_dist_error) - assert raw.info['ch_names'][12][3:5] == 'D4' - assert_allclose( - mni_locs[12], [-0.0196, 0.0821, 0.0275], atol=allowed_dist_error) + assert raw.info["ch_names"][12][3:5] == "D4" + assert_allclose(mni_locs[12], [-0.0196, 0.0821, 0.0275], atol=allowed_dist_error) - assert raw.info['ch_names'][16][3:5] == 'D5' - assert_allclose( - mni_locs[16], [-0.0360, 0.0276, 0.0778], atol=allowed_dist_error) + assert raw.info["ch_names"][16][3:5] == "D5" + assert_allclose(mni_locs[16], [-0.0360, 0.0276, 0.0778], atol=allowed_dist_error) - assert raw.info['ch_names'][19][3:5] == 'D6' - assert_allclose( - mni_locs[19], [0.0352, 0.0283, 0.0780], atol=allowed_dist_error) + assert raw.info["ch_names"][19][3:5] == "D6" + assert_allclose(mni_locs[19], [0.0352, 0.0283, 0.0780], atol=allowed_dist_error) - assert raw.info['ch_names'][21][3:5] == 'D7' - assert_allclose( - mni_locs[21], [0.0388, -0.0477, 0.0932], atol=allowed_dist_error) + assert raw.info["ch_names"][21][3:5] == "D7" + assert_allclose(mni_locs[21], [0.0388, -0.0477, 0.0932], atol=allowed_dist_error) @requires_testing_data @@ -370,34 +375,42 @@ def test_nirx_15_3_short(): # Test data import assert raw._data.shape == (26, 220) - assert raw.info['sfreq'] == 12.5 + assert raw.info["sfreq"] == 12.5 # Test channel naming - assert raw.info['ch_names'][:4] == ["S1_D2 760", "S1_D2 850", - "S1_D9 760", "S1_D9 850"] - assert raw.info['ch_names'][24:26] == ["S5_D13 760", "S5_D13 850"] + assert raw.info["ch_names"][:4] == [ + "S1_D2 760", + "S1_D2 850", + "S1_D9 760", + "S1_D9 850", + ] + assert raw.info["ch_names"][24:26] == ["S5_D13 760", "S5_D13 850"] # Test frequency encoding - assert raw.info['chs'][0]['loc'][9] == 760 - assert raw.info['chs'][1]['loc'][9] == 850 + assert raw.info["chs"][0]["loc"][9] == 760 + assert raw.info["chs"][1]["loc"][9] == 850 # Test info import - assert raw.info['subject_info'] == dict(birthday=(2020, 8, 18), - sex=0, - first_name="testMontage\\0A" - "TestMontage", - his_id="testMontage\\0A" - "TestMontage") + assert raw.info["subject_info"] == dict( + birthday=(2020, 8, 18), + sex=0, + first_name="testMontage\\0A" "TestMontage", + his_id="testMontage\\0A" "TestMontage", + ) # Test distance between optodes matches values from # https://github.com/mne-tools/mne-testing-data/pull/72 allowed_distance_error = 0.001 - assert_allclose(source_detector_distances(raw.copy(). - pick("S1_D2 760").info), - [0.0304], atol=allowed_distance_error) - assert_allclose(source_detector_distances(raw.copy(). 
- pick("S5_D13 760").info), - [0.0076], atol=allowed_distance_error) + assert_allclose( + source_detector_distances(raw.copy().pick("S1_D2 760").info), + [0.0304], + atol=allowed_distance_error, + ) + assert_allclose( + source_detector_distances(raw.copy().pick("S5_D13 760").info), + [0.0076], + atol=allowed_distance_error, + ) # Test which channels are short # These are the ones marked as red at @@ -410,7 +423,7 @@ def test_nirx_15_3_short(): assert_array_equal(is_short[:3:2], [True, True]) # Test trigger events - assert_array_equal(raw.annotations.description, ['4.0', '2.0', '1.0']) + assert_array_equal(raw.annotations.description, ["4.0", "2.0", "1.0"]) # Test location of detectors # The locations of detectors can be seen in the first @@ -418,70 +431,62 @@ def test_nirx_15_3_short(): # https://github.com/mne-tools/mne-testing-data/pull/72 # And have been manually copied below allowed_dist_error = 0.0002 - locs = [ch['loc'][6:9] for ch in raw.info['chs']] - head_mri_t, _ = _get_trans('fsaverage', 'head', 'mri') + locs = [ch["loc"][6:9] for ch in raw.info["chs"]] + head_mri_t, _ = _get_trans("fsaverage", "head", "mri") mni_locs = apply_trans(head_mri_t, locs) - assert raw.info['ch_names'][0][3:5] == 'D2' - assert_allclose( - mni_locs[0], [-0.0841, -0.0464, -0.0129], atol=allowed_dist_error) + assert raw.info["ch_names"][0][3:5] == "D2" + assert_allclose(mni_locs[0], [-0.0841, -0.0464, -0.0129], atol=allowed_dist_error) - assert raw.info['ch_names'][4][3:5] == 'D1' - assert_allclose( - mni_locs[4], [0.0846, -0.0142, -0.0156], atol=allowed_dist_error) + assert raw.info["ch_names"][4][3:5] == "D1" + assert_allclose(mni_locs[4], [0.0846, -0.0142, -0.0156], atol=allowed_dist_error) - assert raw.info['ch_names'][8][3:5] == 'D3' - assert_allclose( - mni_locs[8], [0.0207, -0.1062, 0.0484], atol=allowed_dist_error) + assert raw.info["ch_names"][8][3:5] == "D3" + assert_allclose(mni_locs[8], [0.0207, -0.1062, 0.0484], atol=allowed_dist_error) - assert raw.info['ch_names'][12][3:5] == 'D4' - assert_allclose( - mni_locs[12], [-0.0196, 0.0821, 0.0275], atol=allowed_dist_error) + assert raw.info["ch_names"][12][3:5] == "D4" + assert_allclose(mni_locs[12], [-0.0196, 0.0821, 0.0275], atol=allowed_dist_error) - assert raw.info['ch_names'][16][3:5] == 'D5' - assert_allclose( - mni_locs[16], [-0.0360, 0.0276, 0.0778], atol=allowed_dist_error) + assert raw.info["ch_names"][16][3:5] == "D5" + assert_allclose(mni_locs[16], [-0.0360, 0.0276, 0.0778], atol=allowed_dist_error) - assert raw.info['ch_names'][19][3:5] == 'D6' - assert_allclose( - mni_locs[19], [0.0388, -0.0477, 0.0932], atol=allowed_dist_error) + assert raw.info["ch_names"][19][3:5] == "D6" + assert_allclose(mni_locs[19], [0.0388, -0.0477, 0.0932], atol=allowed_dist_error) - assert raw.info['ch_names'][21][3:5] == 'D7' - assert_allclose( - mni_locs[21], [-0.0394, -0.0483, 0.0928], atol=allowed_dist_error) + assert raw.info["ch_names"][21][3:5] == "D7" + assert_allclose(mni_locs[21], [-0.0394, -0.0483, 0.0928], atol=allowed_dist_error) @requires_testing_data def test_locale_encoding(tmp_path): """Test NIRx encoding.""" - fname = tmp_path / 'latin' + fname = tmp_path / "latin" shutil.copytree(fname_nirx_15_2, fname) hdr_fname = fname / "NIRS-2019-10-02_003.hdr" hdr = list() - with open(hdr_fname, 'rb') as fid: + with open(hdr_fname, "rb") as fid: hdr.extend(line for line in fid) # French hdr[2] = b'Date="jeu. 13 f\xe9vr. 
2020"\r\n' - with open(hdr_fname, 'wb') as fid: + with open(hdr_fname, "wb") as fid: for line in hdr: fid.write(line) - read_raw_nirx(fname, verbose='debug') + read_raw_nirx(fname, verbose="debug") # German hdr[2] = b'Date="mi 13 dez 2020"\r\n' - with open(hdr_fname, 'wb') as fid: + with open(hdr_fname, "wb") as fid: for line in hdr: fid.write(line) - read_raw_nirx(fname, verbose='debug') + read_raw_nirx(fname, verbose="debug") # Italian hdr[2] = b'Date="ven 24 gen 2020"\r\n' hdr[3] = b'Time="10:57:41.454"\r\n' - with open(hdr_fname, 'wb') as fid: + with open(hdr_fname, "wb") as fid: for line in hdr: fid.write(line) - raw = read_raw_nirx(fname, verbose='debug') - want_dt = dt.datetime( - 2020, 1, 24, 10, 57, 41, 454000, tzinfo=dt.timezone.utc) - assert raw.info['meas_date'] == want_dt + raw = read_raw_nirx(fname, verbose="debug") + want_dt = dt.datetime(2020, 1, 24, 10, 57, 41, 454000, tzinfo=dt.timezone.utc) + assert raw.info["meas_date"] == want_dt @requires_testing_data @@ -491,43 +496,49 @@ def test_nirx_15_2(): # Test data import assert raw._data.shape == (64, 67) - assert raw.info['sfreq'] == 3.90625 - assert raw.info['meas_date'] == dt.datetime(2019, 10, 2, 9, 8, 47, 511000, - tzinfo=dt.timezone.utc) + assert raw.info["sfreq"] == 3.90625 + assert raw.info["meas_date"] == dt.datetime( + 2019, 10, 2, 9, 8, 47, 511000, tzinfo=dt.timezone.utc + ) # Test channel naming - assert raw.info['ch_names'][:4] == ["S1_D1 760", "S1_D1 850", - "S1_D10 760", "S1_D10 850"] + assert raw.info["ch_names"][:4] == [ + "S1_D1 760", + "S1_D1 850", + "S1_D10 760", + "S1_D10 850", + ] # Test info import - assert raw.info['subject_info'] == dict(sex=1, first_name="TestRecording", - birthday=(1989, 10, 2), - his_id="TestRecording") + assert raw.info["subject_info"] == dict( + sex=1, + first_name="TestRecording", + birthday=(1989, 10, 2), + his_id="TestRecording", + ) # Test trigger events - assert_array_equal(raw.annotations.description, ['4.0', '6.0', '2.0']) + assert_array_equal(raw.annotations.description, ["4.0", "6.0", "2.0"]) print(raw.annotations.onset) # Test location of detectors allowed_dist_error = 0.0002 - locs = [ch['loc'][6:9] for ch in raw.info['chs']] - head_mri_t, _ = _get_trans('fsaverage', 'head', 'mri') + locs = [ch["loc"][6:9] for ch in raw.info["chs"]] + head_mri_t, _ = _get_trans("fsaverage", "head", "mri") mni_locs = apply_trans(head_mri_t, locs) - assert raw.info['ch_names'][0][3:5] == 'D1' - assert_allclose( - mni_locs[0], [-0.0292, 0.0852, -0.0142], atol=allowed_dist_error) + assert raw.info["ch_names"][0][3:5] == "D1" + assert_allclose(mni_locs[0], [-0.0292, 0.0852, -0.0142], atol=allowed_dist_error) - assert raw.info['ch_names'][15][3:5] == 'D4' - assert_allclose( - mni_locs[15], [-0.0739, -0.0756, -0.0075], atol=allowed_dist_error) + assert raw.info["ch_names"][15][3:5] == "D4" + assert_allclose(mni_locs[15], [-0.0739, -0.0756, -0.0075], atol=allowed_dist_error) # Old name aliases for backward compat - assert 'fnirs_cw_amplitude' in raw - with pytest.raises(ValueError, match='Invalid value'): - 'fnirs_raw' in raw - assert 'fnirs_od' not in raw - picks = pick_types(raw.info, fnirs='fnirs_cw_amplitude') + assert "fnirs_cw_amplitude" in raw + with pytest.raises(ValueError, match="Invalid value"): + "fnirs_raw" in raw + assert "fnirs_od" not in raw + picks = pick_types(raw.info, fnirs="fnirs_cw_amplitude") assert len(picks) > 0 @@ -547,78 +558,248 @@ def test_nirx_15_0(): # Test data import assert raw._data.shape == (20, 92) - assert raw.info['sfreq'] == 6.25 - assert 
raw.info['meas_date'] == dt.datetime(2019, 10, 27, 13, 53, 34, - 209000, - tzinfo=dt.timezone.utc) + assert raw.info["sfreq"] == 6.25 + assert raw.info["meas_date"] == dt.datetime( + 2019, 10, 27, 13, 53, 34, 209000, tzinfo=dt.timezone.utc + ) # Test channel naming - assert raw.info['ch_names'][:12] == ["S1_D1 760", "S1_D1 850", - "S2_D2 760", "S2_D2 850", - "S3_D3 760", "S3_D3 850", - "S4_D4 760", "S4_D4 850", - "S5_D5 760", "S5_D5 850", - "S6_D6 760", "S6_D6 850"] + assert raw.info["ch_names"][:12] == [ + "S1_D1 760", + "S1_D1 850", + "S2_D2 760", + "S2_D2 850", + "S3_D3 760", + "S3_D3 850", + "S4_D4 760", + "S4_D4 850", + "S5_D5 760", + "S5_D5 850", + "S6_D6 760", + "S6_D6 850", + ] # Test info import - assert raw.info['subject_info'] == {'birthday': (2004, 10, 27), - 'first_name': 'NIRX', - 'last_name': 'Test', - 'sex': FIFF.FIFFV_SUBJ_SEX_UNKNOWN, - 'his_id': "NIRX_Test"} + assert raw.info["subject_info"] == { + "birthday": (2004, 10, 27), + "first_name": "NIRX", + "last_name": "Test", + "sex": FIFF.FIFFV_SUBJ_SEX_UNKNOWN, + "his_id": "NIRX_Test", + } # Test trigger events - assert_array_equal(raw.annotations.description, ['1.0', '2.0', '2.0']) + assert_array_equal(raw.annotations.description, ["1.0", "2.0", "2.0"]) # Test location of detectors allowed_dist_error = 0.0002 - locs = [ch['loc'][6:9] for ch in raw.info['chs']] - head_mri_t, _ = _get_trans('fsaverage', 'head', 'mri') + locs = [ch["loc"][6:9] for ch in raw.info["chs"]] + head_mri_t, _ = _get_trans("fsaverage", "head", "mri") mni_locs = apply_trans(head_mri_t, locs) - assert raw.info['ch_names'][0][3:5] == 'D1' - assert_allclose( - mni_locs[0], [0.0287, -0.1143, -0.0332], atol=allowed_dist_error) + assert raw.info["ch_names"][0][3:5] == "D1" + assert_allclose(mni_locs[0], [0.0287, -0.1143, -0.0332], atol=allowed_dist_error) - assert raw.info['ch_names'][15][3:5] == 'D8' - assert_allclose( - mni_locs[15], [-0.0693, -0.0480, 0.0657], atol=allowed_dist_error) + assert raw.info["ch_names"][15][3:5] == "D8" + assert_allclose(mni_locs[15], [-0.0693, -0.0480, 0.0657], atol=allowed_dist_error) # Test distance between optodes matches values from allowed_distance_error = 0.0002 - assert_allclose(source_detector_distances(raw.copy(). - pick("S1_D1 760").info), - [0.0300], atol=allowed_distance_error) - assert_allclose(source_detector_distances(raw.copy(). 
- pick("S7_D7 760").info), - [0.0392], atol=allowed_distance_error) + assert_allclose( + source_detector_distances(raw.copy().pick("S1_D1 760").info), + [0.0300], + atol=allowed_distance_error, + ) + assert_allclose( + source_detector_distances(raw.copy().pick("S7_D7 760").info), + [0.0392], + atol=allowed_distance_error, + ) @requires_testing_data -@pytest.mark.parametrize('fname, boundary_decimal', ( - [fname_nirx_15_2_short, 1], - [fname_nirx_15_2, 0], - [fname_nirx_15_2, 0], - [nirsport2_2021_9, 0], -)) +@pytest.mark.parametrize( + "fname, boundary_decimal", + ( + [fname_nirx_15_2_short, 1], + [fname_nirx_15_2, 0], + [fname_nirx_15_2, 0], + [nirsport2_2021_9, 0], + ), +) def test_nirx_standard(fname, boundary_decimal): """Test standard operations.""" - _test_raw_reader(read_raw_nirx, fname=fname, - boundary_decimal=boundary_decimal) # low fs + _test_raw_reader( + read_raw_nirx, fname=fname, boundary_decimal=boundary_decimal + ) # low fs # Below are the native (on-disk) orders, which should be preserved @requires_testing_data -@pytest.mark.parametrize('fname, want_order', [ - (fname_nirx_15_0, ['S1_D1', 'S2_D2', 'S3_D3', 'S4_D4', 'S5_D5', 'S6_D6', 'S7_D7', 'S8_D8', 'S9_D9', 'S10_D10']), # noqa: E501 - (fname_nirx_15_2, ['S1_D1', 'S1_D10', 'S2_D1', 'S2_D2', 'S3_D2', 'S3_D3', 'S4_D3', 'S4_D4', 'S5_D4', 'S5_D5', 'S6_D5', 'S6_D6', 'S7_D6', 'S7_D7', 'S8_D7', 'S8_D8', 'S9_D8', 'S9_D9', 'S10_D9', 'S10_D10', 'S11_D11', 'S11_D12', 'S12_D12', 'S12_D13', 'S13_D13', 'S13_D14', 'S14_D14', 'S14_D15', 'S15_D15', 'S15_D16', 'S16_D11', 'S16_D16']), # noqa: E501 - (fname_nirx_15_2_short, ['S1_D1', 'S1_D9', 'S2_D3', 'S2_D10', 'S3_D2', 'S3_D11', 'S4_D4', 'S4_D12', 'S5_D5', 'S5_D6', 'S5_D7', 'S5_D8', 'S5_D13']), # noqa: E501 - (fname_nirx_15_3_short, ['S1_D2', 'S1_D9', 'S2_D1', 'S2_D10', 'S3_D3', 'S3_D11', 'S4_D4', 'S4_D12', 'S5_D5', 'S5_D6', 'S5_D7', 'S5_D8', 'S5_D13']), # noqa: E501 - (nirsport1_wo_sat, ['S1_D4', 'S1_D5', 'S1_D6', 'S2_D5', 'S2_D6', 'S3_D5', 'S4_D1', 'S4_D3', 'S4_D4', 'S5_D1', 'S5_D2', 'S6_D1', 'S6_D3']), # noqa: E501 - (nirsport2, ['S1_D1', 'S1_D6', 'S1_D9', 'S2_D2', 'S2_D10', 'S3_D5', 'S3_D7', 'S3_D11', 'S4_D8', 'S4_D12', 'S5_D3', 'S5_D13', 'S6_D4', 'S6_D14', 'S7_D1', 'S7_D6', 'S7_D15', 'S8_D5', 'S8_D7', 'S8_D16']), # noqa: E501 - (nirsport2_2021_9, ['S1_D1', 'S1_D3', 'S2_D1', 'S2_D2', 'S2_D4', 'S3_D2', 'S3_D5', 'S4_D1', 'S4_D3', 'S4_D4', 'S4_D6', 'S5_D2', 'S5_D4', 'S5_D5', 'S5_D7', 'S6_D3', 'S6_D6', 'S7_D4', 'S7_D6', 'S7_D7', 'S8_D5', 'S8_D7']), # noqa: E501 -]) +@pytest.mark.parametrize( + "fname, want_order", + [ + ( + fname_nirx_15_0, + [ + "S1_D1", + "S2_D2", + "S3_D3", + "S4_D4", + "S5_D5", + "S6_D6", + "S7_D7", + "S8_D8", + "S9_D9", + "S10_D10", + ], + ), # noqa: E501 + ( + fname_nirx_15_2, + [ + "S1_D1", + "S1_D10", + "S2_D1", + "S2_D2", + "S3_D2", + "S3_D3", + "S4_D3", + "S4_D4", + "S5_D4", + "S5_D5", + "S6_D5", + "S6_D6", + "S7_D6", + "S7_D7", + "S8_D7", + "S8_D8", + "S9_D8", + "S9_D9", + "S10_D9", + "S10_D10", + "S11_D11", + "S11_D12", + "S12_D12", + "S12_D13", + "S13_D13", + "S13_D14", + "S14_D14", + "S14_D15", + "S15_D15", + "S15_D16", + "S16_D11", + "S16_D16", + ], + ), # noqa: E501 + ( + fname_nirx_15_2_short, + [ + "S1_D1", + "S1_D9", + "S2_D3", + "S2_D10", + "S3_D2", + "S3_D11", + "S4_D4", + "S4_D12", + "S5_D5", + "S5_D6", + "S5_D7", + "S5_D8", + "S5_D13", + ], + ), # noqa: E501 + ( + fname_nirx_15_3_short, + [ + "S1_D2", + "S1_D9", + "S2_D1", + "S2_D10", + "S3_D3", + "S3_D11", + "S4_D4", + "S4_D12", + "S5_D5", + "S5_D6", + "S5_D7", + "S5_D8", + "S5_D13", + ], + ), # 
noqa: E501 + ( + nirsport1_wo_sat, + [ + "S1_D4", + "S1_D5", + "S1_D6", + "S2_D5", + "S2_D6", + "S3_D5", + "S4_D1", + "S4_D3", + "S4_D4", + "S5_D1", + "S5_D2", + "S6_D1", + "S6_D3", + ], + ), # noqa: E501 + ( + nirsport2, + [ + "S1_D1", + "S1_D6", + "S1_D9", + "S2_D2", + "S2_D10", + "S3_D5", + "S3_D7", + "S3_D11", + "S4_D8", + "S4_D12", + "S5_D3", + "S5_D13", + "S6_D4", + "S6_D14", + "S7_D1", + "S7_D6", + "S7_D15", + "S8_D5", + "S8_D7", + "S8_D16", + ], + ), # noqa: E501 + ( + nirsport2_2021_9, + [ + "S1_D1", + "S1_D3", + "S2_D1", + "S2_D2", + "S2_D4", + "S3_D2", + "S3_D5", + "S4_D1", + "S4_D3", + "S4_D4", + "S4_D6", + "S5_D2", + "S5_D4", + "S5_D5", + "S5_D7", + "S6_D3", + "S6_D6", + "S7_D4", + "S7_D6", + "S7_D7", + "S8_D5", + "S8_D7", + ], + ), # noqa: E501 + ], +) def test_channel_order(fname, want_order): """Test that logical channel order is preserved.""" raw = read_raw_nirx(fname) diff --git a/mne/io/open.py b/mne/io/open.py index e3b83c31fb1..c3d2cd2a294 100644 --- a/mne/io/open.py +++ b/mne/io/open.py @@ -44,11 +44,11 @@ def _fiff_get_fid(fname): fid.seek(0) else: fname = str(fname) - if op.splitext(fname)[1].lower() == '.gz': - logger.debug('Using gzip') + if op.splitext(fname)[1].lower() == ".gz": + logger.debug("Using gzip") fid = GzipFile(fname, "rb") # Open in binary mode else: - logger.debug('Using normal I/O') + logger.debug("Using normal I/O") fid = open(fname, "rb") # Open in binary mode return fid @@ -59,7 +59,7 @@ def _get_next_fname(fid, fname, tree): next_fname = None for nodes in nodes_list: next_fname = None - for ent in nodes['directory']: + for ent in nodes["directory"]: if ent.kind == FIFF.FIFF_REF_ROLE: tag = read_tag(fid, ent.pos) role = int(tag.data.item()) @@ -76,21 +76,22 @@ def _get_next_fname(fid, fname, tree): continue next_num = read_tag(fid, ent.pos).data.item() path, base = op.split(fname) - idx = base.find('.') - idx2 = base.rfind('-') - num_str = base[idx2 + 1:idx] + idx = base.find(".") + idx2 = base.rfind("-") + num_str = base[idx2 + 1 : idx] if not num_str.isdigit(): idx2 = -1 if idx2 < 0 and next_num == 1: # this is the first file, which may not be numbered next_fname = op.join( - path, '%s-%d.%s' % (base[:idx], next_num, - base[idx + 1:])) + path, "%s-%d.%s" % (base[:idx], next_num, base[idx + 1 :]) + ) continue - next_fname = op.join(path, '%s-%d.%s' - % (base[:idx2], next_num, base[idx + 1:])) + next_fname = op.join( + path, "%s-%d.%s" % (base[:idx2], next_num, base[idx + 1 :]) + ) if next_fname is not None: break return next_fname @@ -139,31 +140,33 @@ def _fiff_open(fname, fid, preload): tag = read_tag_info(fid) # Check that this looks like a fif file - prefix = f'file {repr(fname)} does not' + prefix = f"file {repr(fname)} does not" if tag.kind != FIFF.FIFF_FILE_ID: - raise ValueError(f'{prefix} start with a file id tag') + raise ValueError(f"{prefix} start with a file id tag") if tag.type != FIFF.FIFFT_ID_STRUCT: - raise ValueError(f'{prefix} start with a file id tag') + raise ValueError(f"{prefix} start with a file id tag") if tag.size != 20: - raise ValueError(f'{prefix} start with a file id tag') + raise ValueError(f"{prefix} start with a file id tag") tag = read_tag(fid) if tag.kind != FIFF.FIFF_DIR_POINTER: - raise ValueError(f'{prefix} have a directory pointer') + raise ValueError(f"{prefix} have a directory pointer") # Read or create the directory tree - logger.debug(' Creating tag directory for %s...' % fname) + logger.debug(" Creating tag directory for %s..." 
% fname) dirpos = int(tag.data.item()) read_slow = True if dirpos > 0: dir_tag = read_tag(fid, dirpos) if dir_tag is None: - warn(f'FIF tag directory missing at the end of the file, possibly ' - f'corrupted file: {fname}') + warn( + f"FIF tag directory missing at the end of the file, possibly " + f"corrupted file: {fname}" + ) else: directory = dir_tag.data read_slow = False @@ -181,7 +184,7 @@ def _fiff_open(fname, fid, preload): tree, _ = make_dir_tree(fid, directory) - logger.debug('[done]') + logger.debug("[done]") # Back to the beginning fid.seek(0) @@ -190,8 +193,15 @@ def _fiff_open(fname, fid, preload): @verbose -def show_fiff(fname, indent=' ', read_limit=np.inf, max_str=30, - output=str, tag=None, verbose=None): +def show_fiff( + fname, + indent=" ", + read_limit=np.inf, + max_str=30, + output=str, + tag=None, + verbose=None, +): """Show FIFF information. This function is similar to mne_show_fiff. @@ -221,53 +231,68 @@ def show_fiff(fname, indent=' ', read_limit=np.inf, max_str=30, The contents of the file. """ if output not in [list, str]: - raise ValueError('output must be list or str') + raise ValueError("output must be list or str") if isinstance(tag, str): # command mne show_fiff passes string tag = int(tag) f, tree, directory = fiff_open(fname) # This gets set to 0 (unknown) by fiff_open, but FIFFB_ROOT probably # makes more sense for display - tree['block'] = FIFF.FIFFB_ROOT + tree["block"] = FIFF.FIFFB_ROOT with f as fid: - out = _show_tree(fid, tree, indent=indent, level=0, - read_limit=read_limit, max_str=max_str, tag_id=tag) + out = _show_tree( + fid, + tree, + indent=indent, + level=0, + read_limit=read_limit, + max_str=max_str, + tag_id=tag, + ) if output == str: - out = '\n'.join(out) + out = "\n".join(out) return out -def _find_type(value, fmts=['FIFF_'], exclude=['FIFF_UNIT']): +def _find_type(value, fmts=["FIFF_"], exclude=["FIFF_UNIT"]): """Find matching values.""" value = int(value) - vals = [k for k, v in FIFF.items() - if v == value and any(fmt in k for fmt in fmts) and - not any(exc in k for exc in exclude)] + vals = [ + k + for k, v in FIFF.items() + if v == value + and any(fmt in k for fmt in fmts) + and not any(exc in k for exc in exclude) + ] if len(vals) == 0: - vals = ['???'] + vals = ["???"] return vals def _show_tree(fid, tree, indent, level, read_limit, max_str, tag_id): """Show FIFF tree.""" from scipy import sparse + this_idt = indent * level next_idt = indent * (level + 1) # print block-level information - out = [this_idt + str(int(tree['block'])) + ' = ' + - '/'.join(_find_type(tree['block'], fmts=['FIFFB_']))] + out = [ + this_idt + + str(int(tree["block"])) + + " = " + + "/".join(_find_type(tree["block"], fmts=["FIFFB_"])) + ] tag_found = False if tag_id is None or out[0].strip().startswith(str(tag_id)): tag_found = True - if tree['directory'] is not None: - kinds = [ent.kind for ent in tree['directory']] + [-1] - types = [ent.type for ent in tree['directory']] - sizes = [ent.size for ent in tree['directory']] - poss = [ent.pos for ent in tree['directory']] + if tree["directory"] is not None: + kinds = [ent.kind for ent in tree["directory"]] + [-1] + types = [ent.type for ent in tree["directory"]] + sizes = [ent.size for ent in tree["directory"]] + poss = [ent.pos for ent in tree["directory"]] counter = 0 good = True - for k, kn, size, pos, type_ in zip(kinds[:-1], kinds[1:], sizes, poss, - types): + for k, kn, size, pos, type_ in zip(kinds[:-1], kinds[1:], sizes, poss, types): if not tag_found and k != tag_id: continue tag = Tag(k, size, 0, 
pos) @@ -282,43 +307,51 @@ def _show_tree(fid, tree, indent, level, read_limit, max_str, tag_id): counter += 1 else: # find the tag type - this_type = _find_type(k, fmts=['FIFF_']) + this_type = _find_type(k, fmts=["FIFF_"]) # prepend a count if necessary - prepend = 'x' + str(counter + 1) + ': ' if counter > 0 else '' - postpend = '' + prepend = "x" + str(counter + 1) + ": " if counter > 0 else "" + postpend = "" # print tag data nicely if tag.data is not None: - postpend = ' = ' + str(tag.data)[:max_str] + postpend = " = " + str(tag.data)[:max_str] if isinstance(tag.data, np.ndarray): if tag.data.size > 1: - postpend += ' ... array size=' + str(tag.data.size) + postpend += " ... array size=" + str(tag.data.size) elif isinstance(tag.data, dict): - postpend += ' ... dict len=' + str(len(tag.data)) + postpend += " ... dict len=" + str(len(tag.data)) elif isinstance(tag.data, str): - postpend += ' ... str len=' + str(len(tag.data)) + postpend += " ... str len=" + str(len(tag.data)) elif isinstance(tag.data, (list, tuple)): - postpend += ' ... list len=' + str(len(tag.data)) + postpend += " ... list len=" + str(len(tag.data)) elif sparse.issparse(tag.data): - postpend += (' ... sparse (%s) shape=%s' - % (tag.data.getformat(), tag.data.shape)) + postpend += " ... sparse (%s) shape=%s" % ( + tag.data.getformat(), + tag.data.shape, + ) else: - postpend += ' ... type=' + str(type(tag.data)) - postpend = '>' * 20 + 'BAD' if not good else postpend - type_ = _call_dict_names.get(type_, '?%s?' % (type_,)) - out += [next_idt + prepend + str(k) + ' = ' + - '/'.join(this_type) + - ' (' + str(size) + 'b %s)' % type_ + - postpend] - out[-1] = out[-1].replace('\n', '¶') + postpend += " ... type=" + str(type(tag.data)) + postpend = ">" * 20 + "BAD" if not good else postpend + type_ = _call_dict_names.get(type_, "?%s?" % (type_,)) + out += [ + next_idt + + prepend + + str(k) + + " = " + + "/".join(this_type) + + " (" + + str(size) + + "b %s)" % type_ + + postpend + ] + out[-1] = out[-1].replace("\n", "¶") counter = 0 good = True if tag_id in kinds: tag_found = True if not tag_found: - out = [''] + out = [""] level = -1 # removes extra indent # deal with children - for branch in tree['children']: - out += _show_tree(fid, branch, indent, level + 1, read_limit, max_str, - tag_id) + for branch in tree["children"]: + out += _show_tree(fid, branch, indent, level + 1, read_limit, max_str, tag_id) return out diff --git a/mne/io/persyst/persyst.py b/mne/io/persyst/persyst.py index a9897d37e67..da2ab798591 100644 --- a/mne/io/persyst/persyst.py +++ b/mne/io/persyst/persyst.py @@ -67,16 +67,17 @@ class RawPersyst(BaseRaw): @verbose def __init__(self, fname, preload=False, verbose=None): fname = str(_check_fname(fname, "read", True, "fname")) - logger.info('Loading %s' % fname) + logger.info("Loading %s" % fname) # make sure filename is the Lay file - if not fname.endswith('.lay'): - fname = fname + '.lay' + if not fname.endswith(".lay"): + fname = fname + ".lay" # get the current directory and Lay filename curr_path, lay_fname = op.dirname(fname), op.basename(fname) if not op.exists(fname): - raise FileNotFoundError(f'The path you specified, ' - f'"{lay_fname}", does not exist.') + raise FileNotFoundError( + f"The path you specified, " f'"{lay_fname}", does not exist.'
+ ) # sections and subsections currently unused keys, data, sections = _read_lay_contents(fname) @@ -93,50 +94,51 @@ def __init__(self, fname, preload=False, verbose=None): # loop through each line in the lay file for key, val, section in zip(keys, data, sections): - if key == '': + if key == "": continue # Make sure key are lowercase for everything, but electrodes. # We also do not want to lower-case comments because those # are free-form text where casing may matter. - if key is not None and section not in ['channelmap', - 'comments']: + if key is not None and section not in ["channelmap", "comments"]: key = key.lower() # FileInfo - if section == 'fileinfo': + if section == "fileinfo": # extract the .dat file name - if key == 'file': + if key == "file": dat_fname = op.basename(val) dat_fpath = op.join(curr_path, op.basename(dat_fname)) # determine if .dat file exists where it should - error_msg = f'The data path you specified ' \ - f'does not exist for the lay path, ' \ - f'{lay_fname}. Make sure the dat file ' \ - f'is in the same directory as the lay ' \ - f'file, and the specified dat filename ' \ - f'matches.' + error_msg = ( + f"The data path you specified " + f"does not exist for the lay path, " + f"{lay_fname}. Make sure the dat file " + f"is in the same directory as the lay " + f"file, and the specified dat filename " + f"matches." + ) if not op.exists(dat_fpath): raise FileNotFoundError(error_msg) fileinfo_dict[key] = val # ChannelMap - elif section == 'channelmap': + elif section == "channelmap": # channel map has = for = channelmap_dict[key] = val # Patient (All optional) - elif section == 'patient': + elif section == "patient": patient_dict[key] = val # Comments (turned into mne.Annotations) - elif section == 'comments': + elif section == "comments": comments_dict[key] = comments_dict.get(key, list()) + [val] num_comments += 1 # get numerical metadata # datatype is either 7 for 32 bit, or 0 for 16 bit - datatype = fileinfo_dict.get('datatype') - cal = float(fileinfo_dict.get('calibration')) - n_chs = int(fileinfo_dict.get('waveformcount')) + datatype = fileinfo_dict.get("datatype") + cal = float(fileinfo_dict.get("calibration")) + n_chs = int(fileinfo_dict.get("waveformcount")) # Store subject information from lay file in mne format # Note: Persyst also records "Physician", "Technician", @@ -145,97 +147,103 @@ def __init__(self, fname, preload=False, verbose=None): subject_info = _get_subjectinfo(patient_dict) # set measurement date - testdate = patient_dict.get('testdate') + testdate = patient_dict.get("testdate") if testdate is not None: # TODO: Persyst may change its internal date schemas # without notice # These are the 3 "so far" possible datatime storage # formats in Persyst .lay - if '/' in testdate: - testdate = datetime.strptime(testdate, '%m/%d/%Y') - elif '-' in testdate: - testdate = datetime.strptime(testdate, '%d-%m-%Y') - elif '.' in testdate: - testdate = datetime.strptime(testdate, '%Y.%m.%d') + if "/" in testdate: + testdate = datetime.strptime(testdate, "%m/%d/%Y") + elif "-" in testdate: + testdate = datetime.strptime(testdate, "%d-%m-%Y") + elif "." in testdate: + testdate = datetime.strptime(testdate, "%Y.%m.%d") if not isinstance(testdate, datetime): - warn('Cannot read in the measurement date due ' - 'to incompatible format. Please set manually ' - 'for %s ' % lay_fname) + warn( + "Cannot read in the measurement date due " + "to incompatible format. 
Please set manually " + "for %s " % lay_fname + ) meas_date = None else: - testtime = datetime.strptime(patient_dict.get('testtime'), - '%H:%M:%S') + testtime = datetime.strptime(patient_dict.get("testtime"), "%H:%M:%S") meas_date = datetime( - year=testdate.year, month=testdate.month, - day=testdate.day, hour=testtime.hour, - minute=testtime.minute, second=testtime.second, - tzinfo=timezone.utc) + year=testdate.year, + month=testdate.month, + day=testdate.day, + hour=testtime.hour, + minute=testtime.minute, + second=testtime.second, + tzinfo=timezone.utc, + ) # Create mne structure ch_names = list(channelmap_dict.keys()) if n_chs != len(ch_names): - raise RuntimeError('Channels in lay file do not ' - 'match the number of channels ' - 'in the .dat file.') # noqa + raise RuntimeError( + "Channels in lay file do not " + "match the number of channels " + "in the .dat file." + ) # noqa # get rid of the "-Ref" in channel names - ch_names = [ch.upper().split('-REF')[0] for ch in ch_names] + ch_names = [ch.upper().split("-REF")[0] for ch in ch_names] # get the sampling rate and default channel types to EEG - sfreq = fileinfo_dict.get('samplingrate') - ch_types = 'eeg' + sfreq = fileinfo_dict.get("samplingrate") + ch_types = "eeg" info = create_info(ch_names, sfreq, ch_types=ch_types) info.update(subject_info=subject_info) with info._unlock(): for idx in range(n_chs): # calibration brings to uV then 1e-6 brings to V - info['chs'][idx]['cal'] = cal * 1.0e-6 - info['meas_date'] = meas_date + info["chs"][idx]["cal"] = cal * 1.0e-6 + info["meas_date"] = meas_date # determine number of samples in file # Note: We do not use the lay file to do this # because clips in time may be generated by Persyst that # DO NOT modify the "SampleTimes" section - with open(dat_fpath, 'rb') as f: + with open(dat_fpath, "rb") as f: # determine the precision if int(datatype) == 7: # 32 bit - dtype = np.dtype('i4') + dtype = np.dtype("i4") elif int(datatype) == 0: # 16 bit - dtype = np.dtype('i2') + dtype = np.dtype("i2") else: - raise RuntimeError(f'Unknown format: {datatype}') + raise RuntimeError(f"Unknown format: {datatype}") # allow offset to occur f.seek(0, os.SEEK_END) n_samples = f.tell() n_samples = n_samples // (dtype.itemsize * n_chs) - logger.debug(f'Loaded {n_samples} samples ' - f'for {n_chs} channels.') + logger.debug(f"Loaded {n_samples} samples " f"for {n_chs} channels.") - raw_extras = { - 'dtype': dtype, - 'n_chs': n_chs, - 'n_samples': n_samples - } + raw_extras = {"dtype": dtype, "n_chs": n_chs, "n_samples": n_samples} # create Raw object super(RawPersyst, self).__init__( - info, preload, filenames=[dat_fpath], + info, + preload, + filenames=[dat_fpath], last_samps=[n_samples - 1], - raw_extras=[raw_extras], verbose=verbose) + raw_extras=[raw_extras], + verbose=verbose, + ) # set annotations based on the comments read in onset = np.zeros(num_comments, float) duration = np.zeros(num_comments, float) - description = [''] * num_comments + description = [""] * num_comments # loop through comments dictionary, which may contain # multiple events for the same "text" annotation t_idx = 0 for _description, event_tuples in comments_dict.items(): - for (_onset, _duration) in event_tuples: + for _onset, _duration in event_tuples: # extract the onset, duration, description to # create an Annotations object onset[t_idx] = _onset @@ -252,8 +260,8 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): binary files. In addition, it stores the calibration to convert data to uV in the lay file. 
""" - dtype = self._raw_extras[fi]['dtype'] - n_chs = self._raw_extras[fi]['n_chs'] + dtype = self._raw_extras[fi]["dtype"] + n_chs = self._raw_extras[fi]["n_chs"] dat_fname = self._filenames[fi] # compute samples count based on start and stop @@ -264,17 +272,16 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): count = time_length_samps * n_chs # seek the dat file - with open(dat_fname, 'rb') as dat_file_ID: + with open(dat_fname, "rb") as dat_file_ID: # allow offset to occur dat_file_ID.seek(n_chs * dtype.itemsize * start, 1) # read in the actual record starting at possibly offset - record = np.fromfile(dat_file_ID, dtype=dtype, - count=count) + record = np.fromfile(dat_file_ID, dtype=dtype, count=count) # chs * rows # cast as float32; more than enough precision - record = np.reshape(record, (n_chs, -1), 'F').astype(np.float32) + record = np.reshape(record, (n_chs, -1), "F").astype(np.float32) # calibrate to convert to V and handle mult _mult_cal_one(data, record, idx, cals, mult) @@ -283,28 +290,28 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): def _get_subjectinfo(patient_dict): # attempt to parse out the birthdate, but if it doesn't # meet spec, then it will set to None - birthdate = patient_dict.get('birthdate') - if '/' in birthdate: + birthdate = patient_dict.get("birthdate") + if "/" in birthdate: try: - birthdate = datetime.strptime(birthdate, '%m/%d/%y') + birthdate = datetime.strptime(birthdate, "%m/%d/%y") except ValueError: birthdate = None - print('Unable to process birthdate of %s ' % birthdate) - elif '-' in birthdate: + print("Unable to process birthdate of %s " % birthdate) + elif "-" in birthdate: try: - birthdate = datetime.strptime(birthdate, '%d-%m-%y') + birthdate = datetime.strptime(birthdate, "%d-%m-%y") except ValueError: birthdate = None - print('Unable to process birthdate of %s ' % birthdate) + print("Unable to process birthdate of %s " % birthdate) subject_info = { - 'first_name': patient_dict.get('first'), - 'middle_name': patient_dict.get('middle'), - 'last_name': patient_dict.get('last'), - 'sex': patient_dict.get('sex'), - 'hand': patient_dict.get('hand'), - 'his_id': patient_dict.get('id'), - 'birthday': birthdate, + "first_name": patient_dict.get("first"), + "middle_name": patient_dict.get("middle"), + "last_name": patient_dict.get("last"), + "sex": patient_dict.get("sex"), + "hand": patient_dict.get("hand"), + "his_id": patient_dict.get("id"), + "birthday": birthdate, } # Recode sex values @@ -314,8 +321,7 @@ def _get_subjectinfo(patient_dict): f=FIFF.FIFFV_SUBJ_SEX_FEMALE, female=FIFF.FIFFV_SUBJ_SEX_FEMALE, ) - subject_info['sex'] = sex_dict.get(subject_info['sex'], - FIFF.FIFFV_SUBJ_SEX_UNKNOWN) + subject_info["sex"] = sex_dict.get(subject_info["sex"], FIFF.FIFFV_SUBJ_SEX_UNKNOWN) # Recode hand values hand_dict = dict( @@ -329,9 +335,9 @@ def _get_subjectinfo(patient_dict): ) # no handedness is set when unknown try: - subject_info['hand'] = hand_dict[subject_info['hand']] + subject_info["hand"] = hand_dict[subject_info["hand"]] except KeyError: - subject_info.pop('hand') + subject_info.pop("hand") return subject_info @@ -343,8 +349,8 @@ def _read_lay_contents(fname): keys, data = [], [] # initialize all section to empty str - section = '' - with open(fname, 'r') as fin: + section = "" + with open(fname, "r") as fin: for line in fin: # break a line into a status, key and value status, key, val = _process_lay_line(line, section) @@ -420,19 +426,18 @@ def _process_lay_line(line, section): 4. 
variable type (unused) 5. free-form text describing the annotation """ - key = '' # default; only return value possibly not set + key = "" # default; only return value possibly not set line = line.strip() # remove leading and trailing spaces end_idx = len(line) - 1 # get the last index of the line # empty sequence evaluates to false if not line: status = 0 - key = '' - value = '' + key = "" + value = "" return status, key, value # section found - elif (line[0] == '[') and (line[end_idx] == ']') \ - and (end_idx + 1 >= 3): + elif (line[0] == "[") and (line[end_idx] == "]") and (end_idx + 1 >= 3): status = 1 value = line[1:end_idx].lower() # key found @@ -440,25 +445,27 @@ def _process_lay_line(line, section): # handle Comments section differently from all other sections # TODO: utilize state and var_type in code. # Currently not used - if section == 'comments': + if section == "comments": # Persyst Comments output 5 variables "," separated - time_sec, duration, state, var_type, text = line.split(',', 4) + time_sec, duration, state, var_type, text = line.split(",", 4) status = 2 key = text value = (time_sec, duration) # all other sections else: - if '=' not in line: - raise RuntimeError('The line %s does not conform ' - 'to the standards. Please check the ' - '.lay file.' % line) # noqa - pos = line.index('=') + if "=" not in line: + raise RuntimeError( + "The line %s does not conform " + "to the standards. Please check the " + ".lay file." % line + ) # noqa + pos = line.index("=") status = 2 # the line now is composed of a # = key = line[0:pos] key.strip() - value = line[pos + 1:end_idx + 1] + value = line[pos + 1 : end_idx + 1] value.strip() return status, key, value diff --git a/mne/io/persyst/tests/test_persyst.py b/mne/io/persyst/tests/test_persyst.py index 4d11c728398..734a7b3011c 100644 --- a/mne/io/persyst/tests/test_persyst.py +++ b/mne/io/persyst/tests/test_persyst.py @@ -15,14 +15,10 @@ testing_path = data_path(download=False) fname_lay = ( - testing_path - / "Persyst" - / "sub-pt1_ses-02_task-monitor_acq-ecog_run-01_clip2.lay" + testing_path / "Persyst" / "sub-pt1_ses-02_task-monitor_acq-ecog_run-01_clip2.lay" ) fname_dat = ( - testing_path - / "Persyst" - / "sub-pt1_ses-02_task-monitor_acq-ecog_run-01_clip2.dat" + testing_path / "Persyst" / "sub-pt1_ses-02_task-monitor_acq-ecog_run-01_clip2.dat" ) @@ -32,7 +28,7 @@ def test_persyst_lay_load(): raw = read_raw_persyst(fname_lay, preload=False) # Test data import - assert raw.info['sfreq'] == 200 + assert raw.info["sfreq"] == 200 assert raw.preload is False # load raw data @@ -45,8 +41,7 @@ def test_persyst_lay_load(): assert len(raw.ch_names) == 83 # no "-Ref" in channel names - assert all(['-ref' not in ch.lower() - for ch in raw.ch_names]) + assert all(["-ref" not in ch.lower() for ch in raw.ch_names]) # test with preload True raw = read_raw_persyst(fname_lay, preload=True) @@ -90,39 +85,39 @@ def test_persyst_dates(tmp_path): # reformat the lay file to have testdate with # "/" character with open(fname_lay, "r") as fin: - with open(new_fname_lay, 'w') as fout: + with open(new_fname_lay, "w") as fout: # for each line in the input file for idx, line in enumerate(fin): - if line.startswith('TestDate'): - line = 'TestDate=01/23/2000\n' + if line.startswith("TestDate"): + line = "TestDate=01/23/2000\n" fout.write(line) # file should update correctly with datetime raw = read_raw_persyst(new_fname_lay) - assert raw.info['meas_date'].month == 1 - assert raw.info['meas_date'].day == 23 - assert raw.info['meas_date'].year == 2000 + 
assert raw.info["meas_date"].month == 1 + assert raw.info["meas_date"].day == 23 + assert raw.info["meas_date"].year == 2000 # reformat the lay file to have testdate with # "-" character os.remove(new_fname_lay) with open(fname_lay, "r") as fin: - with open(new_fname_lay, 'w') as fout: + with open(new_fname_lay, "w") as fout: # for each line in the input file for idx, line in enumerate(fin): - if line.startswith('TestDate'): - line = 'TestDate=24-01-2000\n' + if line.startswith("TestDate"): + line = "TestDate=24-01-2000\n" fout.write(line) # file should update correctly with datetime raw = read_raw_persyst(new_fname_lay) - assert raw.info['meas_date'].month == 1 - assert raw.info['meas_date'].day == 24 - assert raw.info['meas_date'].year == 2000 + assert raw.info["meas_date"].month == 1 + assert raw.info["meas_date"].day == 24 + assert raw.info["meas_date"].year == 2000 @requires_testing_data def test_persyst_wrong_file(tmp_path): """Test reading Persyst files when passed in wrong file path.""" - with pytest.raises(FileNotFoundError, match='The path you'): + with pytest.raises(FileNotFoundError, match="The path you"): read_raw_persyst(fname_dat, preload=True) new_fname_lay = tmp_path / fname_lay.name @@ -130,10 +125,11 @@ def test_persyst_wrong_file(tmp_path): shutil.copy(fname_lay, new_fname_lay) # without a .dat file, reader should break - desired_err_msg = \ - 'The data path you specified does ' \ - 'not exist for the lay path, ' \ - 'sub-pt1_ses-02_task-monitor_acq-ecog_run-01_clip2.lay' + desired_err_msg = ( + "The data path you specified does " + "not exist for the lay path, " + "sub-pt1_ses-02_task-monitor_acq-ecog_run-01_clip2.lay" + ) with pytest.raises(FileNotFoundError, match=desired_err_msg): read_raw_persyst(new_fname_lay, preload=True) @@ -154,10 +150,11 @@ def test_persyst_moved_file(tmp_path): # without a .dat file, reader should break # when the lay file was moved - desired_err_msg = \ - 'The data path you specified does ' \ - 'not exist for the lay path, ' \ - 'sub-pt1_ses-02_task-monitor_acq-ecog_run-01_clip2.lay' + desired_err_msg = ( + "The data path you specified does " + "not exist for the lay path, " + "sub-pt1_ses-02_task-monitor_acq-ecog_run-01_clip2.lay" + ) with pytest.raises(FileNotFoundError, match=desired_err_msg): read_raw_persyst(new_fname_lay, preload=True) @@ -166,13 +163,13 @@ def test_persyst_moved_file(tmp_path): # as reader requires lay and dat file to be in # same directory with open(fname_lay, "r") as fin: - with open(new_fname_lay, 'w') as fout: + with open(new_fname_lay, "w") as fout: # for each line in the input file for idx, line in enumerate(fin): - if line.startswith('File='): + if line.startswith("File="): # give it the full path to the old data - test_fpath = fname_dat.parent / line.split('=')[1] - line = f'File={test_fpath}\n' + test_fpath = fname_dat.parent / line.split("=")[1] + line = f"File={test_fpath}\n" fout.write(line) with pytest.raises(FileNotFoundError, match=desired_err_msg): read_raw_persyst(new_fname_lay, preload=True) @@ -203,11 +200,11 @@ def test_persyst_annotations(tmp_path): # get the annotations and make sure that repeated annotations # are in the dataset annotations = raw.annotations - assert np.count_nonzero(annotations.description == 'seizure') == 2 + assert np.count_nonzero(annotations.description == "seizure") == 2 # make sure annotation with a "," character is in there - assert 'seizure1,2' in annotations.description - assert 'CLip2' in annotations.description + assert "seizure1,2" in annotations.description + 
assert "CLip2" in annotations.description @requires_testing_data @@ -219,42 +216,40 @@ def test_persyst_errors(tmp_path): # reformat the lay file with open(fname_lay, "r") as fin: - with open(new_fname_lay, 'w') as fout: + with open(new_fname_lay, "w") as fout: # for each line in the input file for idx, line in enumerate(fin): if idx == 1: - line = line.replace('=', ',') + line = line.replace("=", ",") fout.write(line) # file should break - with pytest.raises(RuntimeError, match='The line'): + with pytest.raises(RuntimeError, match="The line"): read_raw_persyst(new_fname_lay) # reformat the lay file os.remove(new_fname_lay) with open(fname_lay, "r") as fin: - with open(new_fname_lay, 'w') as fout: + with open(new_fname_lay, "w") as fout: # for each line in the input file for idx, line in enumerate(fin): - if line.startswith('WaveformCount'): - line = 'WaveformCount=1\n' + if line.startswith("WaveformCount"): + line = "WaveformCount=1\n" fout.write(line) # file should break - with pytest.raises(RuntimeError, match='Channels in lay ' - 'file do not'): + with pytest.raises(RuntimeError, match="Channels in lay " "file do not"): read_raw_persyst(new_fname_lay) # reformat the lay file to have testdate # improperly specified os.remove(new_fname_lay) with open(fname_lay, "r") as fin: - with open(new_fname_lay, 'w') as fout: + with open(new_fname_lay, "w") as fout: # for each line in the input file for idx, line in enumerate(fin): - if line.startswith('TestDate'): - line = 'TestDate=Jan 23rd 2000\n' + if line.startswith("TestDate"): + line = "TestDate=Jan 23rd 2000\n" fout.write(line) # file should not read in meas date - with pytest.warns(RuntimeWarning, - match='Cannot read in the measurement date'): + with pytest.warns(RuntimeWarning, match="Cannot read in the measurement date"): raw = read_raw_persyst(new_fname_lay) - assert raw.info['meas_date'] is None + assert raw.info["meas_date"] is None diff --git a/mne/io/pick.py b/mne/io/pick.py index e5156974c2c..2cf461d551e 100644 --- a/mne/io/pick.py +++ b/mne/io/pick.py @@ -10,8 +10,16 @@ import numpy as np from .constants import FIFF -from ..utils import (logger, verbose, _validate_type, fill_doc, _ensure_int, - _check_option, warn, deprecated) +from ..utils import ( + logger, + verbose, + _validate_type, + fill_doc, + _ensure_int, + _check_option, + warn, + deprecated, +) def get_channel_type_constants(include_defaults=False): @@ -39,78 +47,99 @@ def get_channel_type_constants(include_defaults=False): (depending on the recording system), so no "coil_type" entry is given for "ref_meg" unless ``include_defaults`` is requested. 
""" - base = dict(grad=dict(kind=FIFF.FIFFV_MEG_CH, unit=FIFF.FIFF_UNIT_T_M), - mag=dict(kind=FIFF.FIFFV_MEG_CH, unit=FIFF.FIFF_UNIT_T), - ref_meg=dict(kind=FIFF.FIFFV_REF_MEG_CH), - eeg=dict(kind=FIFF.FIFFV_EEG_CH, - unit=FIFF.FIFF_UNIT_V, - coil_type=FIFF.FIFFV_COIL_EEG), - seeg=dict(kind=FIFF.FIFFV_SEEG_CH, - unit=FIFF.FIFF_UNIT_V, - coil_type=FIFF.FIFFV_COIL_EEG), - dbs=dict(kind=FIFF.FIFFV_DBS_CH, - unit=FIFF.FIFF_UNIT_V, - coil_type=FIFF.FIFFV_COIL_EEG), - ecog=dict(kind=FIFF.FIFFV_ECOG_CH, - unit=FIFF.FIFF_UNIT_V, - coil_type=FIFF.FIFFV_COIL_EEG), - eog=dict(kind=FIFF.FIFFV_EOG_CH, unit=FIFF.FIFF_UNIT_V), - emg=dict(kind=FIFF.FIFFV_EMG_CH, unit=FIFF.FIFF_UNIT_V), - ecg=dict(kind=FIFF.FIFFV_ECG_CH, unit=FIFF.FIFF_UNIT_V), - resp=dict(kind=FIFF.FIFFV_RESP_CH, unit=FIFF.FIFF_UNIT_V), - bio=dict(kind=FIFF.FIFFV_BIO_CH, unit=FIFF.FIFF_UNIT_V), - misc=dict(kind=FIFF.FIFFV_MISC_CH, unit=FIFF.FIFF_UNIT_V), - stim=dict(kind=FIFF.FIFFV_STIM_CH), - exci=dict(kind=FIFF.FIFFV_EXCI_CH), - syst=dict(kind=FIFF.FIFFV_SYST_CH), - ias=dict(kind=FIFF.FIFFV_IAS_CH), - gof=dict(kind=FIFF.FIFFV_GOODNESS_FIT), - dipole=dict(kind=FIFF.FIFFV_DIPOLE_WAVE), - chpi=dict(kind=[FIFF.FIFFV_QUAT_0, FIFF.FIFFV_QUAT_1, - FIFF.FIFFV_QUAT_2, FIFF.FIFFV_QUAT_3, - FIFF.FIFFV_QUAT_4, FIFF.FIFFV_QUAT_5, - FIFF.FIFFV_QUAT_6, FIFF.FIFFV_HPI_G, - FIFF.FIFFV_HPI_ERR, FIFF.FIFFV_HPI_MOV]), - fnirs_cw_amplitude=dict( - kind=FIFF.FIFFV_FNIRS_CH, - unit=FIFF.FIFF_UNIT_V, - coil_type=FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE), - fnirs_fd_ac_amplitude=dict( - kind=FIFF.FIFFV_FNIRS_CH, - unit=FIFF.FIFF_UNIT_V, - coil_type=FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE), - fnirs_fd_phase=dict( - kind=FIFF.FIFFV_FNIRS_CH, - unit=FIFF.FIFF_UNIT_RAD, - coil_type=FIFF.FIFFV_COIL_FNIRS_FD_PHASE), - fnirs_od=dict(kind=FIFF.FIFFV_FNIRS_CH, - coil_type=FIFF.FIFFV_COIL_FNIRS_OD), - hbo=dict(kind=FIFF.FIFFV_FNIRS_CH, - unit=FIFF.FIFF_UNIT_MOL, - coil_type=FIFF.FIFFV_COIL_FNIRS_HBO), - hbr=dict(kind=FIFF.FIFFV_FNIRS_CH, - unit=FIFF.FIFF_UNIT_MOL, - coil_type=FIFF.FIFFV_COIL_FNIRS_HBR), - csd=dict(kind=FIFF.FIFFV_EEG_CH, - unit=FIFF.FIFF_UNIT_V_M2, - coil_type=FIFF.FIFFV_COIL_EEG_CSD), - temperature=dict(kind=FIFF.FIFFV_TEMPERATURE_CH, - unit=FIFF.FIFF_UNIT_CEL), - gsr=dict(kind=FIFF.FIFFV_GALVANIC_CH, - unit=FIFF.FIFF_UNIT_S), - eyegaze=dict(kind=FIFF.FIFFV_EYETRACK_CH, - coil_type=FIFF.FIFFV_COIL_EYETRACK_POS), - pupil=dict(kind=FIFF.FIFFV_EYETRACK_CH, - coil_type=FIFF.FIFFV_COIL_EYETRACK_PUPIL) - ) + base = dict( + grad=dict(kind=FIFF.FIFFV_MEG_CH, unit=FIFF.FIFF_UNIT_T_M), + mag=dict(kind=FIFF.FIFFV_MEG_CH, unit=FIFF.FIFF_UNIT_T), + ref_meg=dict(kind=FIFF.FIFFV_REF_MEG_CH), + eeg=dict( + kind=FIFF.FIFFV_EEG_CH, unit=FIFF.FIFF_UNIT_V, coil_type=FIFF.FIFFV_COIL_EEG + ), + seeg=dict( + kind=FIFF.FIFFV_SEEG_CH, + unit=FIFF.FIFF_UNIT_V, + coil_type=FIFF.FIFFV_COIL_EEG, + ), + dbs=dict( + kind=FIFF.FIFFV_DBS_CH, unit=FIFF.FIFF_UNIT_V, coil_type=FIFF.FIFFV_COIL_EEG + ), + ecog=dict( + kind=FIFF.FIFFV_ECOG_CH, + unit=FIFF.FIFF_UNIT_V, + coil_type=FIFF.FIFFV_COIL_EEG, + ), + eog=dict(kind=FIFF.FIFFV_EOG_CH, unit=FIFF.FIFF_UNIT_V), + emg=dict(kind=FIFF.FIFFV_EMG_CH, unit=FIFF.FIFF_UNIT_V), + ecg=dict(kind=FIFF.FIFFV_ECG_CH, unit=FIFF.FIFF_UNIT_V), + resp=dict(kind=FIFF.FIFFV_RESP_CH, unit=FIFF.FIFF_UNIT_V), + bio=dict(kind=FIFF.FIFFV_BIO_CH, unit=FIFF.FIFF_UNIT_V), + misc=dict(kind=FIFF.FIFFV_MISC_CH, unit=FIFF.FIFF_UNIT_V), + stim=dict(kind=FIFF.FIFFV_STIM_CH), + exci=dict(kind=FIFF.FIFFV_EXCI_CH), + syst=dict(kind=FIFF.FIFFV_SYST_CH), + 
ias=dict(kind=FIFF.FIFFV_IAS_CH), + gof=dict(kind=FIFF.FIFFV_GOODNESS_FIT), + dipole=dict(kind=FIFF.FIFFV_DIPOLE_WAVE), + chpi=dict( + kind=[ + FIFF.FIFFV_QUAT_0, + FIFF.FIFFV_QUAT_1, + FIFF.FIFFV_QUAT_2, + FIFF.FIFFV_QUAT_3, + FIFF.FIFFV_QUAT_4, + FIFF.FIFFV_QUAT_5, + FIFF.FIFFV_QUAT_6, + FIFF.FIFFV_HPI_G, + FIFF.FIFFV_HPI_ERR, + FIFF.FIFFV_HPI_MOV, + ] + ), + fnirs_cw_amplitude=dict( + kind=FIFF.FIFFV_FNIRS_CH, + unit=FIFF.FIFF_UNIT_V, + coil_type=FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE, + ), + fnirs_fd_ac_amplitude=dict( + kind=FIFF.FIFFV_FNIRS_CH, + unit=FIFF.FIFF_UNIT_V, + coil_type=FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE, + ), + fnirs_fd_phase=dict( + kind=FIFF.FIFFV_FNIRS_CH, + unit=FIFF.FIFF_UNIT_RAD, + coil_type=FIFF.FIFFV_COIL_FNIRS_FD_PHASE, + ), + fnirs_od=dict(kind=FIFF.FIFFV_FNIRS_CH, coil_type=FIFF.FIFFV_COIL_FNIRS_OD), + hbo=dict( + kind=FIFF.FIFFV_FNIRS_CH, + unit=FIFF.FIFF_UNIT_MOL, + coil_type=FIFF.FIFFV_COIL_FNIRS_HBO, + ), + hbr=dict( + kind=FIFF.FIFFV_FNIRS_CH, + unit=FIFF.FIFF_UNIT_MOL, + coil_type=FIFF.FIFFV_COIL_FNIRS_HBR, + ), + csd=dict( + kind=FIFF.FIFFV_EEG_CH, + unit=FIFF.FIFF_UNIT_V_M2, + coil_type=FIFF.FIFFV_COIL_EEG_CSD, + ), + temperature=dict(kind=FIFF.FIFFV_TEMPERATURE_CH, unit=FIFF.FIFF_UNIT_CEL), + gsr=dict(kind=FIFF.FIFFV_GALVANIC_CH, unit=FIFF.FIFF_UNIT_S), + eyegaze=dict( + kind=FIFF.FIFFV_EYETRACK_CH, coil_type=FIFF.FIFFV_COIL_EYETRACK_POS + ), + pupil=dict( + kind=FIFF.FIFFV_EYETRACK_CH, coil_type=FIFF.FIFFV_COIL_EYETRACK_PUPIL + ), + ) if include_defaults: coil_none = dict(coil_type=FIFF.FIFFV_COIL_NONE) unit_none = dict(unit=FIFF.FIFF_UNIT_NONE) defaults = dict( grad=dict(coil_type=FIFF.FIFFV_COIL_VV_PLANAR_T1), mag=dict(coil_type=FIFF.FIFFV_COIL_VV_MAG_T3), - ref_meg=dict(coil_type=FIFF.FIFFV_COIL_VV_MAG_T3, - unit=FIFF.FIFF_UNIT_T), + ref_meg=dict(coil_type=FIFF.FIFFV_COIL_VV_MAG_T3, unit=FIFF.FIFF_UNIT_T), misc=dict(**coil_none, **unit_none), # NB: overwrites UNIT_V stim=dict(unit=FIFF.FIFF_UNIT_V, **coil_none), eog=coil_none, @@ -127,61 +156,69 @@ def get_channel_type_constants(include_defaults=False): _first_rule = { - FIFF.FIFFV_MEG_CH: 'meg', - FIFF.FIFFV_REF_MEG_CH: 'ref_meg', - FIFF.FIFFV_EEG_CH: 'eeg', - FIFF.FIFFV_STIM_CH: 'stim', - FIFF.FIFFV_EOG_CH: 'eog', - FIFF.FIFFV_EMG_CH: 'emg', - FIFF.FIFFV_ECG_CH: 'ecg', - FIFF.FIFFV_RESP_CH: 'resp', - FIFF.FIFFV_MISC_CH: 'misc', - FIFF.FIFFV_EXCI_CH: 'exci', - FIFF.FIFFV_IAS_CH: 'ias', - FIFF.FIFFV_SYST_CH: 'syst', - FIFF.FIFFV_SEEG_CH: 'seeg', - FIFF.FIFFV_DBS_CH: 'dbs', - FIFF.FIFFV_BIO_CH: 'bio', - FIFF.FIFFV_QUAT_0: 'chpi', - FIFF.FIFFV_QUAT_1: 'chpi', - FIFF.FIFFV_QUAT_2: 'chpi', - FIFF.FIFFV_QUAT_3: 'chpi', - FIFF.FIFFV_QUAT_4: 'chpi', - FIFF.FIFFV_QUAT_5: 'chpi', - FIFF.FIFFV_QUAT_6: 'chpi', - FIFF.FIFFV_HPI_G: 'chpi', - FIFF.FIFFV_HPI_ERR: 'chpi', - FIFF.FIFFV_HPI_MOV: 'chpi', - FIFF.FIFFV_DIPOLE_WAVE: 'dipole', - FIFF.FIFFV_GOODNESS_FIT: 'gof', - FIFF.FIFFV_ECOG_CH: 'ecog', - FIFF.FIFFV_FNIRS_CH: 'fnirs', - FIFF.FIFFV_TEMPERATURE_CH: 'temperature', - FIFF.FIFFV_GALVANIC_CH: 'gsr', - FIFF.FIFFV_EYETRACK_CH: 'eyetrack', + FIFF.FIFFV_MEG_CH: "meg", + FIFF.FIFFV_REF_MEG_CH: "ref_meg", + FIFF.FIFFV_EEG_CH: "eeg", + FIFF.FIFFV_STIM_CH: "stim", + FIFF.FIFFV_EOG_CH: "eog", + FIFF.FIFFV_EMG_CH: "emg", + FIFF.FIFFV_ECG_CH: "ecg", + FIFF.FIFFV_RESP_CH: "resp", + FIFF.FIFFV_MISC_CH: "misc", + FIFF.FIFFV_EXCI_CH: "exci", + FIFF.FIFFV_IAS_CH: "ias", + FIFF.FIFFV_SYST_CH: "syst", + FIFF.FIFFV_SEEG_CH: "seeg", + FIFF.FIFFV_DBS_CH: "dbs", + FIFF.FIFFV_BIO_CH: "bio", + FIFF.FIFFV_QUAT_0: "chpi", + 
FIFF.FIFFV_QUAT_1: "chpi", + FIFF.FIFFV_QUAT_2: "chpi", + FIFF.FIFFV_QUAT_3: "chpi", + FIFF.FIFFV_QUAT_4: "chpi", + FIFF.FIFFV_QUAT_5: "chpi", + FIFF.FIFFV_QUAT_6: "chpi", + FIFF.FIFFV_HPI_G: "chpi", + FIFF.FIFFV_HPI_ERR: "chpi", + FIFF.FIFFV_HPI_MOV: "chpi", + FIFF.FIFFV_DIPOLE_WAVE: "dipole", + FIFF.FIFFV_GOODNESS_FIT: "gof", + FIFF.FIFFV_ECOG_CH: "ecog", + FIFF.FIFFV_FNIRS_CH: "fnirs", + FIFF.FIFFV_TEMPERATURE_CH: "temperature", + FIFF.FIFFV_GALVANIC_CH: "gsr", + FIFF.FIFFV_EYETRACK_CH: "eyetrack", } # How to reduce our categories in channel_type (originally) _second_rules = { - 'meg': ('unit', {FIFF.FIFF_UNIT_T_M: 'grad', - FIFF.FIFF_UNIT_T: 'mag'}), - 'fnirs': ('coil_type', {FIFF.FIFFV_COIL_FNIRS_HBO: 'hbo', - FIFF.FIFFV_COIL_FNIRS_HBR: 'hbr', - FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE: - 'fnirs_cw_amplitude', - FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE: - 'fnirs_fd_ac_amplitude', - FIFF.FIFFV_COIL_FNIRS_FD_PHASE: - 'fnirs_fd_phase', - FIFF.FIFFV_COIL_FNIRS_OD: 'fnirs_od', - }), - 'eeg': ('coil_type', {FIFF.FIFFV_COIL_EEG: 'eeg', - FIFF.FIFFV_COIL_EEG_BIPOLAR: 'eeg', - FIFF.FIFFV_COIL_NONE: 'eeg', # MNE-C backward compat - FIFF.FIFFV_COIL_EEG_CSD: 'csd', - }), - 'eyetrack': ('coil_type', {FIFF.FIFFV_COIL_EYETRACK_POS: 'eyegaze', - FIFF.FIFFV_COIL_EYETRACK_PUPIL: 'pupil' - }) + "meg": ("unit", {FIFF.FIFF_UNIT_T_M: "grad", FIFF.FIFF_UNIT_T: "mag"}), + "fnirs": ( + "coil_type", + { + FIFF.FIFFV_COIL_FNIRS_HBO: "hbo", + FIFF.FIFFV_COIL_FNIRS_HBR: "hbr", + FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE: "fnirs_cw_amplitude", + FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE: "fnirs_fd_ac_amplitude", + FIFF.FIFFV_COIL_FNIRS_FD_PHASE: "fnirs_fd_phase", + FIFF.FIFFV_COIL_FNIRS_OD: "fnirs_od", + }, + ), + "eeg": ( + "coil_type", + { + FIFF.FIFFV_COIL_EEG: "eeg", + FIFF.FIFFV_COIL_EEG_BIPOLAR: "eeg", + FIFF.FIFFV_COIL_NONE: "eeg", # MNE-C backward compat + FIFF.FIFFV_COIL_EEG_CSD: "csd", + }, + ), + "eyetrack": ( + "coil_type", + { + FIFF.FIFFV_COIL_EYETRACK_POS: "eyegaze", + FIFF.FIFFV_COIL_EYETRACK_PUPIL: "pupil", + }, + ), } @@ -208,12 +245,13 @@ def channel_type(info, idx): # This is faster than the original _channel_type_old now in test_pick.py # because it uses (at most!) two dict lookups plus one conditional # to get the channel type string. 
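For reference, here is a minimal sketch of the two-step lookup that the comment above describes, driven by the _first_rule/_second_rules tables defined in this file; the function name resolve_channel_type and the example channel dict are illustrative only, and the private tables are imported purely for demonstration:

    from mne.io.constants import FIFF
    from mne.io.pick import _first_rule, _second_rules

    def resolve_channel_type(ch):
        # First dict lookup: FIFF "kind" -> coarse type ("meg", "eeg", ...)
        kind = _first_rule[ch["kind"]]
        # One conditional plus a second dict lookup refines ambiguous kinds
        if kind in _second_rules:
            key, table = _second_rules[kind]
            kind = table[ch[key]]
        return kind

    # A magnetometer: MEG kind with unit Tesla resolves "meg" -> "mag"
    ch = dict(kind=FIFF.FIFFV_MEG_CH, unit=FIFF.FIFF_UNIT_T)
    assert resolve_channel_type(ch) == "mag"
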
- ch = info['chs'][idx] + ch = info["chs"][idx] try: - first_kind = _first_rule[ch['kind']] + first_kind = _first_rule[ch["kind"]] except KeyError: - raise ValueError('Unknown channel type (%s) for channel "%s"' - % (ch['kind'], ch["ch_name"])) + raise ValueError( + 'Unknown channel type (%s) for channel "%s"' % (ch["kind"], ch["ch_name"]) + ) if first_kind in _second_rules: key, second_rule = _second_rules[first_kind] first_kind = second_rule[ch[key]] @@ -252,8 +290,8 @@ def pick_channels(ch_names, include, exclude=[], ordered=None, *, verbose=None): pick_channels_regexp, pick_types """ if len(np.unique(ch_names)) != len(ch_names): - raise RuntimeError('ch_names is not a unique list, picking is unsafe') - _validate_type(ordered, (bool, None), 'ordered') + raise RuntimeError("ch_names is not a unique list, picking is unsafe") + _validate_type(ordered, (bool, None), "ordered") _check_excludes_includes(include) _check_excludes_includes(exclude) if not isinstance(include, list): @@ -270,27 +308,31 @@ def pick_channels(ch_names, include, exclude=[], ordered=None, *, verbose=None): else: missing.append(name) dep_msg = ( - 'The default for pick_channels will change from ordered=False to ' - 'ordered=True in 1.5' + "The default for pick_channels will change from ordered=False to " + "ordered=True in 1.5" ) if len(missing): if ordered is None: warn( - f'{dep_msg} and this will result in an error because the ' - f'following channel names are missing:\n{missing}\n' - 'Either fix your included names or explicitly pass ' - 'ordered=False.', FutureWarning) + f"{dep_msg} and this will result in an error because the " + f"following channel names are missing:\n{missing}\n" + "Either fix your included names or explicitly pass " + "ordered=False.", + FutureWarning, + ) elif ordered: - raise ValueError('Missing channels from ch_names required by ' - 'include:\n%s' % (missing,)) + raise ValueError( + "Missing channels from ch_names required by " + "include:\n%s" % (missing,) + ) if not ordered: out_sel = np.unique(sel) if ordered is None and not np.array_equal(out_sel, sel): warn( - f'{dep_msg} and this will result in a change of behavior ' - 'because the resulting channel order will not match. Either ' - 'use a channel order that matches your instance or ' - 'pass ordered=False.', + f"{dep_msg} and this will result in a change of behavior " + "because the resulting channel order will not match. 
Either " + "use a channel order that matches your instance or " + "pass ordered=False.", FutureWarning, ) sel = out_sel @@ -335,14 +377,14 @@ def _triage_meg_pick(ch, meg): """Triage an MEG pick type.""" if meg is True: return True - elif ch['unit'] == FIFF.FIFF_UNIT_T_M: - if meg == 'grad': + elif ch["unit"] == FIFF.FIFF_UNIT_T_M: + if meg == "grad": return True - elif meg == 'planar1' and ch['ch_name'].endswith('2'): + elif meg == "planar1" and ch["ch_name"].endswith("2"): return True - elif meg == 'planar2' and ch['ch_name'].endswith('3'): + elif meg == "planar2" and ch["ch_name"].endswith("3"): return True - elif (meg == 'mag' and ch['unit'] == FIFF.FIFF_UNIT_T): + elif meg == "mag" and ch["unit"] == FIFF.FIFF_UNIT_T: return True return False @@ -351,20 +393,25 @@ def _triage_fnirs_pick(ch, fnirs, warned): """Triage an fNIRS pick type.""" if fnirs is True: return True - elif ch['coil_type'] == FIFF.FIFFV_COIL_FNIRS_HBO and 'hbo' in fnirs: + elif ch["coil_type"] == FIFF.FIFFV_COIL_FNIRS_HBO and "hbo" in fnirs: return True - elif ch['coil_type'] == FIFF.FIFFV_COIL_FNIRS_HBR and 'hbr' in fnirs: + elif ch["coil_type"] == FIFF.FIFFV_COIL_FNIRS_HBR and "hbr" in fnirs: return True - elif ch['coil_type'] == FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE and \ - 'fnirs_cw_amplitude' in fnirs: + elif ( + ch["coil_type"] == FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE + and "fnirs_cw_amplitude" in fnirs + ): return True - elif ch['coil_type'] == FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE and \ - 'fnirs_fd_ac_amplitude' in fnirs: + elif ( + ch["coil_type"] == FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE + and "fnirs_fd_ac_amplitude" in fnirs + ): return True - elif ch['coil_type'] == FIFF.FIFFV_COIL_FNIRS_FD_PHASE and \ - 'fnirs_fd_phase' in fnirs: + elif ( + ch["coil_type"] == FIFF.FIFFV_COIL_FNIRS_FD_PHASE and "fnirs_fd_phase" in fnirs + ): return True - elif ch['coil_type'] == FIFF.FIFFV_COIL_FNIRS_OD and 'fnirs_od' in fnirs: + elif ch["coil_type"] == FIFF.FIFFV_COIL_FNIRS_OD and "fnirs_od" in fnirs: return True return False @@ -375,11 +422,9 @@ def _triage_eyetrack_pick(ch, eyetrack): return False elif eyetrack is True: return True - elif ch['coil_type'] == FIFF.FIFFV_COIL_EYETRACK_PUPIL and \ - 'pupil' in eyetrack: + elif ch["coil_type"] == FIFF.FIFFV_COIL_EYETRACK_PUPIL and "pupil" in eyetrack: return True - elif ch['coil_type'] == FIFF.FIFFV_COIL_EYETRACK_POS and \ - 'eyegaze' in eyetrack: + elif ch["coil_type"] == FIFF.FIFFV_COIL_EYETRACK_POS and "eyegaze" in eyetrack: return True return False @@ -387,11 +432,12 @@ def _triage_eyetrack_pick(ch, eyetrack): def _check_meg_type(meg, allow_auto=False): """Ensure a valid meg type.""" if isinstance(meg, str): - allowed_types = ['grad', 'mag', 'planar1', 'planar2'] - allowed_types += ['auto'] if allow_auto else [] + allowed_types = ["grad", "mag", "planar1", "planar2"] + allowed_types += ["auto"] if allow_auto else [] if meg not in allowed_types: - raise ValueError('meg value must be one of %s or bool, not %s' - % (allowed_types, meg)) + raise ValueError( + "meg value must be one of %s or bool, not %s" % (allowed_types, meg) + ) def _check_info_exclude(info, exclude): @@ -399,22 +445,49 @@ def _check_info_exclude(info, exclude): info._check_consistency() if exclude is None: raise ValueError('exclude must be a list of strings or "bads"') - elif exclude == 'bads': - exclude = info.get('bads', []) + elif exclude == "bads": + exclude = info.get("bads", []) elif not isinstance(exclude, (list, tuple)): - raise ValueError('exclude must either be "bads" or a list of strings.' 
- ' If only one channel is to be excluded, use ' - '[ch_name] instead of passing ch_name.') + raise ValueError( + 'exclude must either be "bads" or a list of strings.' + " If only one channel is to be excluded, use " + "[ch_name] instead of passing ch_name." + ) return exclude @fill_doc -def pick_types(info, meg=False, eeg=False, stim=False, eog=False, ecg=False, - emg=False, ref_meg='auto', *, misc=False, resp=False, - chpi=False, exci=False, ias=False, syst=False, seeg=False, - dipole=False, gof=False, bio=False, ecog=False, fnirs=False, - csd=False, dbs=False, temperature=False, gsr=False, - eyetrack=False, include=(), exclude='bads', selection=None): +def pick_types( + info, + meg=False, + eeg=False, + stim=False, + eog=False, + ecg=False, + emg=False, + ref_meg="auto", + *, + misc=False, + resp=False, + chpi=False, + exci=False, + ias=False, + syst=False, + seeg=False, + dipole=False, + gof=False, + bio=False, + ecog=False, + fnirs=False, + csd=False, + dbs=False, + temperature=False, + gsr=False, + eyetrack=False, + include=(), + exclude="bads", + selection=None, +): """Pick channels by type and names. Parameters @@ -429,36 +502,79 @@ def pick_types(info, meg=False, eeg=False, stim=False, eog=False, ecg=False, """ # NOTE: Changes to this function's signature should also be changed in # PickChannelsMixin - _validate_type(meg, (bool, str), 'meg') + _validate_type(meg, (bool, str), "meg") exclude = _check_info_exclude(info, exclude) - nchan = info['nchan'] + nchan = info["nchan"] pick = np.zeros(nchan, dtype=bool) _check_meg_type(ref_meg, allow_auto=True) _check_meg_type(meg) - if isinstance(ref_meg, str) and ref_meg == 'auto': - ref_meg = ('comps' in info and info['comps'] is not None and - len(info['comps']) > 0 and meg is not False) + if isinstance(ref_meg, str) and ref_meg == "auto": + ref_meg = ( + "comps" in info + and info["comps"] is not None + and len(info["comps"]) > 0 + and meg is not False + ) - for param in (eeg, stim, eog, ecg, emg, misc, resp, chpi, exci, - ias, syst, seeg, dipole, gof, bio, ecog, csd, dbs, - temperature, gsr): + for param in ( + eeg, + stim, + eog, + ecg, + emg, + misc, + resp, + chpi, + exci, + ias, + syst, + seeg, + dipole, + gof, + bio, + ecog, + csd, + dbs, + temperature, + gsr, + ): if not isinstance(param, bool): - w = ('Parameters for all channel types (with the exception of ' - '"meg", "ref_meg", "fnirs", and "eyetrack") must be of type ' - 'bool, not {}.') + w = ( + "Parameters for all channel types (with the exception of " + '"meg", "ref_meg", "fnirs", and "eyetrack") must be of type ' + "bool, not {}." 
+ ) raise ValueError(w.format(type(param))) - param_dict = dict(eeg=eeg, stim=stim, eog=eog, ecg=ecg, emg=emg, - misc=misc, resp=resp, chpi=chpi, exci=exci, - ias=ias, syst=syst, seeg=seeg, dbs=dbs, dipole=dipole, - gof=gof, bio=bio, ecog=ecog, csd=csd, - temperature=temperature, gsr=gsr, eyetrack=eyetrack) + param_dict = dict( + eeg=eeg, + stim=stim, + eog=eog, + ecg=ecg, + emg=emg, + misc=misc, + resp=resp, + chpi=chpi, + exci=exci, + ias=ias, + syst=syst, + seeg=seeg, + dbs=dbs, + dipole=dipole, + gof=gof, + bio=bio, + ecog=ecog, + csd=csd, + temperature=temperature, + gsr=gsr, + eyetrack=eyetrack, + ) # avoid triage if possible if isinstance(meg, bool): - for key in ('grad', 'mag'): + for key in ("grad", "mag"): param_dict[key] = meg if isinstance(fnirs, bool): for key in _FNIRS_CH_TYPES_SPLIT: @@ -469,35 +585,39 @@ def pick_types(info, meg=False, eeg=False, stim=False, eog=False, ecg=False, try: pick[k] = param_dict[ch_type] except KeyError: # not so simple - assert ch_type in ('grad', 'mag', 'ref_meg') + \ - _FNIRS_CH_TYPES_SPLIT + _EYETRACK_CH_TYPES_SPLIT - if ch_type in ('grad', 'mag'): - pick[k] = _triage_meg_pick(info['chs'][k], meg) - elif ch_type == 'ref_meg': - pick[k] = _triage_meg_pick(info['chs'][k], ref_meg) - elif ch_type in ('eyegaze', 'pupil'): - pick[k] = _triage_eyetrack_pick(info['chs'][k], eyetrack) + assert ( + ch_type + in ("grad", "mag", "ref_meg") + + _FNIRS_CH_TYPES_SPLIT + + _EYETRACK_CH_TYPES_SPLIT + ) + if ch_type in ("grad", "mag"): + pick[k] = _triage_meg_pick(info["chs"][k], meg) + elif ch_type == "ref_meg": + pick[k] = _triage_meg_pick(info["chs"][k], ref_meg) + elif ch_type in ("eyegaze", "pupil"): + pick[k] = _triage_eyetrack_pick(info["chs"][k], eyetrack) else: # ch_type in ('hbo', 'hbr') - pick[k] = _triage_fnirs_pick(info['chs'][k], fnirs, warned) + pick[k] = _triage_fnirs_pick(info["chs"][k], fnirs, warned) # restrict channels to selection if provided if selection is not None: # the selection only restricts these types of channels - sel_kind = [FIFF.FIFFV_MEG_CH, FIFF.FIFFV_REF_MEG_CH, - FIFF.FIFFV_EEG_CH] + sel_kind = [FIFF.FIFFV_MEG_CH, FIFF.FIFFV_REF_MEG_CH, FIFF.FIFFV_EEG_CH] for k in np.where(pick)[0]: - if (info['chs'][k]['kind'] in sel_kind and - info['ch_names'][k] not in selection): + if ( + info["chs"][k]["kind"] in sel_kind + and info["ch_names"][k] not in selection + ): pick[k] = False - myinclude = [info['ch_names'][k] for k in range(nchan) if pick[k]] + myinclude = [info["ch_names"][k] for k in range(nchan) if pick[k]] myinclude += include if len(myinclude) == 0: sel = np.array([], int) else: - sel = pick_channels( - info['ch_names'], myinclude, exclude, ordered=False) + sel = pick_channels(info["ch_names"], myinclude, exclude, ordered=False) return sel @@ -529,42 +649,45 @@ def pick_info(info, sel=(), copy=True, verbose=None): if sel is None: return info elif len(sel) == 0: - raise ValueError('No channels match the selection.') - n_unique = len(np.unique(np.arange(len(info['ch_names']))[sel])) + raise ValueError("No channels match the selection.") + n_unique = len(np.unique(np.arange(len(info["ch_names"]))[sel])) if n_unique != len(sel): - raise ValueError('Found %d / %d unique names, sel is not unique' - % (n_unique, len(sel))) + raise ValueError( + "Found %d / %d unique names, sel is not unique" % (n_unique, len(sel)) + ) # make sure required the compensation channels are present - if len(info.get('comps', [])) > 0: - ch_names = [info['ch_names'][idx] for idx in sel] + if len(info.get("comps", [])) > 0: + ch_names = 
[info["ch_names"][idx] for idx in sel] _, comps_missing = _bad_chans_comp(info, ch_names) if len(comps_missing) > 0: - logger.info('Removing %d compensators from info because ' - 'not all compensation channels were picked.' - % (len(info['comps']),)) + logger.info( + "Removing %d compensators from info because " + "not all compensation channels were picked." % (len(info["comps"]),) + ) with info._unlock(): - info['comps'] = [] + info["comps"] = [] with info._unlock(): - info['chs'] = [info['chs'][k] for k in sel] + info["chs"] = [info["chs"][k] for k in sel] info._update_redundant() - info['bads'] = [ch for ch in info['bads'] if ch in info['ch_names']] - if 'comps' in info: - comps = deepcopy(info['comps']) + info["bads"] = [ch for ch in info["bads"] if ch in info["ch_names"]] + if "comps" in info: + comps = deepcopy(info["comps"]) for c in comps: - row_idx = [k for k, n in enumerate(c['data']['row_names']) - if n in info['ch_names']] - row_names = [c['data']['row_names'][i] for i in row_idx] - rowcals = c['rowcals'][row_idx] - c['rowcals'] = rowcals - c['data']['nrow'] = len(row_names) - c['data']['row_names'] = row_names - c['data']['data'] = c['data']['data'][row_idx] + row_idx = [ + k for k, n in enumerate(c["data"]["row_names"]) if n in info["ch_names"] + ] + row_names = [c["data"]["row_names"][i] for i in row_idx] + rowcals = c["rowcals"][row_idx] + c["rowcals"] = rowcals + c["data"]["nrow"] = len(row_names) + c["data"]["row_names"] = row_names + c["data"]["data"] = c["data"]["data"][row_idx] with info._unlock(): - info['comps'] = comps - if info.get('custom_ref_applied', False) and not _electrode_types(info): + info["comps"] = comps + if info.get("custom_ref_applied", False) and not _electrode_types(info): with info._unlock(): - info['custom_ref_applied'] = FIFF.FIFFV_MNE_CUSTOM_REF_OFF + info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_OFF info._check_consistency() return info @@ -577,14 +700,16 @@ def _has_kit_refs(info, picks): run when KIT reference channels are included. """ for p in picks: - if info['chs'][p]['coil_type'] == FIFF.FIFFV_COIL_KIT_REF_MAG: + if info["chs"][p]["coil_type"] == FIFF.FIFFV_COIL_KIT_REF_MAG: return True return False -@deprecated('pick_channels_evoked is deprecated and will be removed in 1.5, ' - 'use evoked.copy().pick(...) instead.') -def pick_channels_evoked(orig, include=[], exclude='bads'): +@deprecated( + "pick_channels_evoked is deprecated and will be removed in 1.5, " + "use evoked.copy().pick(...) instead." +) +def pick_channels_evoked(orig, include=[], exclude="bads"): """Pick channels from evoked data.
Parameters @@ -606,13 +731,11 @@ def pick_channels_evoked(orig, include=[], exclude='bads'): if len(include) == 0 and len(exclude) == 0: return orig - exclude = _check_excludes_includes(exclude, info=orig.info, - allow_bads=True) - sel = pick_channels(orig.info['ch_names'], include=include, - exclude=exclude) + exclude = _check_excludes_includes(exclude, info=orig.info, allow_bads=True) + sel = pick_channels(orig.info["ch_names"], include=include, exclude=exclude) if len(sel) == 0: - raise ValueError('Warning : No channels match the selection.') + raise ValueError("Warning : No channels match the selection.") res = deepcopy(orig) # @@ -628,8 +751,9 @@ def pick_channels_evoked(orig, include=[], exclude='bads'): @verbose -def pick_channels_forward(orig, include=[], exclude=[], ordered=None, - copy=True, *, verbose=None): +def pick_channels_forward( + orig, include=[], exclude=[], ordered=None, copy=True, *, verbose=None +): """Pick channels from forward operator. Parameters @@ -655,61 +779,73 @@ def pick_channels_forward(orig, include=[], exclude=[], ordered=None, Forward solution restricted to selected channels. If include and exclude are empty it returns orig without copy. """ - orig['info']._check_consistency() + orig["info"]._check_consistency() if len(include) == 0 and len(exclude) == 0: return orig.copy() if copy else orig - exclude = _check_excludes_includes(exclude, - info=orig['info'], allow_bads=True) + exclude = _check_excludes_includes(exclude, info=orig["info"], allow_bads=True) # Allow for possibility of channel ordering in forward solution being # different from that of the M/EEG file it is based on. - sel_sol = pick_channels(orig['sol']['row_names'], include=include, - exclude=exclude, ordered=ordered) - sel_info = pick_channels(orig['info']['ch_names'], include=include, - exclude=exclude, ordered=ordered) + sel_sol = pick_channels( + orig["sol"]["row_names"], include=include, exclude=exclude, ordered=ordered + ) + sel_info = pick_channels( + orig["info"]["ch_names"], include=include, exclude=exclude, ordered=ordered + ) fwd = deepcopy(orig) if copy else orig # Check that forward solution and original data file agree on #channels if len(sel_sol) != len(sel_info): - raise ValueError('Forward solution and functional data appear to ' - 'have different channel names, please check.') + raise ValueError( + "Forward solution and functional data appear to " + "have different channel names, please check." + ) # Do we have something? 
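
For reference, the row selection performed around here reduces to plain integer
indexing; a minimal numpy sketch with hypothetical channel names, not MNE's
implementation:

    import numpy as np

    row_names = ["MEG 0111", "MEG 0112", "EEG 001"]  # like fwd["sol"]["row_names"]
    data = np.arange(12.0).reshape(3, 4)             # one row per channel
    include = ["MEG 0112", "EEG 001"]

    sel = [k for k, name in enumerate(row_names) if name in include]
    picked = data[sel, :]                            # like fwd["sol"]["data"][sel_sol, :]
    picked_names = [row_names[k] for k in sel]
    assert picked.shape == (2, 4) and picked_names == include
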
nuse = len(sel_sol) if nuse == 0: - raise ValueError('Nothing remains after picking') + raise ValueError("Nothing remains after picking") - logger.info(' %d out of %d channels remain after picking' - % (nuse, fwd['nchan'])) + logger.info(" %d out of %d channels remain after picking" % (nuse, fwd["nchan"])) # Pick the correct rows of the forward operator using sel_sol - fwd['sol']['data'] = fwd['sol']['data'][sel_sol, :] - fwd['_orig_sol'] = fwd['_orig_sol'][sel_sol, :] - fwd['sol']['nrow'] = nuse + fwd["sol"]["data"] = fwd["sol"]["data"][sel_sol, :] + fwd["_orig_sol"] = fwd["_orig_sol"][sel_sol, :] + fwd["sol"]["nrow"] = nuse - ch_names = [fwd['sol']['row_names'][k] for k in sel_sol] - fwd['nchan'] = nuse - fwd['sol']['row_names'] = ch_names + ch_names = [fwd["sol"]["row_names"][k] for k in sel_sol] + fwd["nchan"] = nuse + fwd["sol"]["row_names"] = ch_names # Pick the appropriate channel names from the info-dict using sel_info - with fwd['info']._unlock(): - fwd['info']['chs'] = [fwd['info']['chs'][k] for k in sel_info] - fwd['info']._update_redundant() - fwd['info']['bads'] = [b for b in fwd['info']['bads'] if b in ch_names] - - if fwd['sol_grad'] is not None: - fwd['sol_grad']['data'] = fwd['sol_grad']['data'][sel_sol, :] - fwd['_orig_sol_grad'] = fwd['_orig_sol_grad'][sel_sol, :] - fwd['sol_grad']['nrow'] = nuse - fwd['sol_grad']['row_names'] = [fwd['sol_grad']['row_names'][k] - for k in sel_sol] + with fwd["info"]._unlock(): + fwd["info"]["chs"] = [fwd["info"]["chs"][k] for k in sel_info] + fwd["info"]._update_redundant() + fwd["info"]["bads"] = [b for b in fwd["info"]["bads"] if b in ch_names] + + if fwd["sol_grad"] is not None: + fwd["sol_grad"]["data"] = fwd["sol_grad"]["data"][sel_sol, :] + fwd["_orig_sol_grad"] = fwd["_orig_sol_grad"][sel_sol, :] + fwd["sol_grad"]["nrow"] = nuse + fwd["sol_grad"]["row_names"] = [ + fwd["sol_grad"]["row_names"][k] for k in sel_sol + ] return fwd -def pick_types_forward(orig, meg=False, eeg=False, ref_meg=True, seeg=False, - ecog=False, dbs=False, include=[], exclude=[]): +def pick_types_forward( + orig, + meg=False, + eeg=False, + ref_meg=True, + seeg=False, + ecog=False, + dbs=False, + include=[], + exclude=[], +): """Pick by channel type and names from a forward operator. Parameters @@ -741,12 +877,21 @@ def pick_types_forward(orig, meg=False, eeg=False, ref_meg=True, seeg=False, res : dict Forward solution restricted to selected channel types. """ - info = orig['info'] - sel = pick_types(info, meg, eeg, ref_meg=ref_meg, seeg=seeg, - ecog=ecog, dbs=dbs, include=include, exclude=exclude) + info = orig["info"] + sel = pick_types( + info, + meg, + eeg, + ref_meg=ref_meg, + seeg=seeg, + ecog=ecog, + dbs=dbs, + include=include, + exclude=exclude, + ) if len(sel) == 0: - raise ValueError('No valid channels found') - include_ch_names = [info['ch_names'][k] for k in sel] + raise ValueError("No valid channels found") + include_ch_names = [info["ch_names"][k] for k in sel] return pick_channels_forward(orig, include_ch_names) @@ -766,14 +911,24 @@ def channel_indices_by_type(info, picks=None): A dictionary that maps each channel type to a (possibly empty) list of channel indices. 
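
    A usage sketch with a toy info, assuming the top-level ``mne`` namespace
    re-exports this helper as in recent releases:

        import mne

        info = mne.create_info(
            ["EEG 001", "MEG 0111", "STI 014"], 1000.0, ["eeg", "grad", "stim"]
        )
        idx = mne.channel_indices_by_type(info)
        assert idx["eeg"] == [0] and idx["grad"] == [1] and idx["stim"] == [2]
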
""" - idx_by_type = {key: list() for key in _PICK_TYPES_KEYS if - key not in ('meg', 'fnirs', 'eyetrack')} - idx_by_type.update(mag=list(), grad=list(), hbo=list(), hbr=list(), - fnirs_cw_amplitude=list(), fnirs_fd_ac_amplitude=list(), - fnirs_fd_phase=list(), fnirs_od=list(), - eyegaze=list(), pupil=list()) - picks = _picks_to_idx(info, picks, - none='all', exclude=(), allow_empty=True) + idx_by_type = { + key: list() + for key in _PICK_TYPES_KEYS + if key not in ("meg", "fnirs", "eyetrack") + } + idx_by_type.update( + mag=list(), + grad=list(), + hbo=list(), + hbr=list(), + fnirs_cw_amplitude=list(), + fnirs_fd_ac_amplitude=list(), + fnirs_fd_phase=list(), + fnirs_od=list(), + eyegaze=list(), + pupil=list(), + ) + picks = _picks_to_idx(info, picks, none="all", exclude=(), allow_empty=True) for k in picks: ch_type = channel_type(info, k) for key in idx_by_type.keys(): @@ -783,8 +938,9 @@ def channel_indices_by_type(info, picks=None): @verbose -def pick_channels_cov(orig, include=[], exclude='bads', ordered=None, - copy=True, *, verbose=None): +def pick_channels_cov( + orig, include=[], exclude="bads", ordered=None, copy=True, *, verbose=None +): """Pick channels from covariance matrix. Parameters @@ -812,22 +968,23 @@ def pick_channels_cov(orig, include=[], exclude='bads', ordered=None, orig = orig.copy() # A little peculiarity of the cov objects is that these two fields # should not be copied over when None. - if 'method' in orig and orig['method'] is None: - del orig['method'] - if 'loglik' in orig and orig['loglik'] is None: - del orig['loglik'] - - exclude = orig['bads'] if exclude == 'bads' else exclude - sel = pick_channels(orig['names'], include=include, exclude=exclude, - ordered=ordered) - data = orig['data'][sel][:, sel] if not orig['diag'] else orig['data'][sel] - names = [orig['names'][k] for k in sel] - bads = [name for name in orig['bads'] if name in orig['names']] - - orig['data'] = data - orig['names'] = names - orig['bads'] = bads - orig['dim'] = len(data) + if "method" in orig and orig["method"] is None: + del orig["method"] + if "loglik" in orig and orig["loglik"] is None: + del orig["loglik"] + + exclude = orig["bads"] if exclude == "bads" else exclude + sel = pick_channels( + orig["names"], include=include, exclude=exclude, ordered=ordered + ) + data = orig["data"][sel][:, sel] if not orig["diag"] else orig["data"][sel] + names = [orig["names"][k] for k in sel] + bads = [name for name in orig["bads"] if name in orig["names"]] + + orig["data"] = data + orig["names"] = names + orig["bads"] = bads + orig["dim"] = len(data) return orig @@ -836,8 +993,10 @@ def _mag_grad_dependent(info): """Determine of mag and grad should be dealt with jointly.""" # right now just uses SSS, could be computed / checked from cov # but probably overkill - return any(ph.get('max_info', {}).get('sss_info', {}).get('in_order', 0) - for ph in info.get('proc_history', [])) + return any( + ph.get("max_info", {}).get("sss_info", {}).get("in_order", 0) + for ph in info.get("proc_history", []) + ) @fill_doc @@ -855,24 +1014,28 @@ def _contains_ch_type(info, ch_type): has_ch_type : bool Whether the channel type is present or not. 
""" - _validate_type(ch_type, 'str', "ch_type") + _validate_type(ch_type, "str", "ch_type") meg_extras = list(_MEG_CH_TYPES_SPLIT) fnirs_extras = list(_FNIRS_CH_TYPES_SPLIT) et_extras = list(_EYETRACK_CH_TYPES_SPLIT) - valid_channel_types = sorted([key for key in _PICK_TYPES_KEYS - if key != 'meg'] - + meg_extras + fnirs_extras + et_extras) - _check_option('ch_type', ch_type, valid_channel_types) + valid_channel_types = sorted( + [key for key in _PICK_TYPES_KEYS if key != "meg"] + + meg_extras + + fnirs_extras + + et_extras + ) + _check_option("ch_type", ch_type, valid_channel_types) if info is None: - raise ValueError('Cannot check for channels of type "%s" because info ' - 'is None' % (ch_type,)) - return any(ch_type == channel_type(info, ii) - for ii in range(info['nchan'])) + raise ValueError( + 'Cannot check for channels of type "%s" because info ' + "is None" % (ch_type,) + ) + return any(ch_type == channel_type(info, ii) for ii in range(info["nchan"])) @fill_doc -def _picks_by_type(info, meg_combined=False, ref_meg=False, exclude='bads'): +def _picks_by_type(info, meg_combined=False, ref_meg=False, exclude="bads"): """Get data channel indices as separate list of tuples. Parameters @@ -892,36 +1055,41 @@ def _picks_by_type(info, meg_combined=False, ref_meg=False, exclude='bads'): picks_list : list of tuples The list of tuples of picks and the type string. """ - _validate_type(ref_meg, bool, 'ref_meg') + _validate_type(ref_meg, bool, "ref_meg") exclude = _check_info_exclude(info, exclude) - if meg_combined == 'auto': + if meg_combined == "auto": meg_combined = _mag_grad_dependent(info) picks_list = [] picks_list = {ch_type: list() for ch_type in _DATA_CH_TYPES_SPLIT} - for k in range(info['nchan']): - if info['chs'][k]['ch_name'] not in exclude: + for k in range(info["nchan"]): + if info["chs"][k]["ch_name"] not in exclude: this_type = channel_type(info, k) try: picks_list[this_type].append(k) except KeyError: # This annoyance is due to differences in pick_types # and channel_type behavior - if this_type == 'ref_meg': - ch = info['chs'][k] + if this_type == "ref_meg": + ch = info["chs"][k] if _triage_meg_pick(ch, ref_meg): - if ch['unit'] == FIFF.FIFF_UNIT_T: - picks_list['mag'].append(k) - elif ch['unit'] == FIFF.FIFF_UNIT_T_M: - picks_list['grad'].append(k) + if ch["unit"] == FIFF.FIFF_UNIT_T: + picks_list["mag"].append(k) + elif ch["unit"] == FIFF.FIFF_UNIT_T_M: + picks_list["grad"].append(k) else: pass # not a data channel type - picks_list = [(ch_type, np.array(picks_list[ch_type], int)) - for ch_type in _DATA_CH_TYPES_SPLIT] - assert _DATA_CH_TYPES_SPLIT[:2] == ('mag', 'grad') + picks_list = [ + (ch_type, np.array(picks_list[ch_type], int)) + for ch_type in _DATA_CH_TYPES_SPLIT + ] + assert _DATA_CH_TYPES_SPLIT[:2] == ("mag", "grad") if meg_combined and len(picks_list[0][1]) and len(picks_list[1][1]): picks_list.insert( - 0, ('meg', np.unique(np.concatenate([picks_list.pop(0)[1], - picks_list.pop(0)[1]]))) + 0, + ( + "meg", + np.unique(np.concatenate([picks_list.pop(0)[1], picks_list.pop(0)[1]])), + ), ) picks_list = [p for p in picks_list if len(p[1])] return picks_list @@ -944,56 +1112,133 @@ def _check_excludes_includes(chs, info=None, allow_bads=False): this will be the bad channels found in 'info'. 
""" from .meas_info import Info + if not isinstance(chs, (list, tuple, set, np.ndarray)): if allow_bads is True: if not isinstance(info, Info): - raise ValueError('Supply an info object if allow_bads is true') - elif chs != 'bads': + raise ValueError("Supply an info object if allow_bads is true") + elif chs != "bads": raise ValueError('If chs is a string, it must be "bads"') else: - chs = info['bads'] + chs = info["bads"] else: raise ValueError( - 'include/exclude must be list, tuple, ndarray, or "bads". ' + - 'You provided type {}'.format(type(chs))) + 'include/exclude must be list, tuple, ndarray, or "bads". ' + + "You provided type {}".format(type(chs)) + ) return chs _PICK_TYPES_DATA_DICT = dict( - meg=True, eeg=True, csd=True, stim=False, eog=False, ecg=False, emg=False, - misc=False, resp=False, chpi=False, exci=False, ias=False, syst=False, - seeg=True, dipole=False, gof=False, bio=False, ecog=True, fnirs=True, - dbs=True, temperature=False, gsr=False, eyetrack=True) -_PICK_TYPES_KEYS = tuple(list(_PICK_TYPES_DATA_DICT) + ['ref_meg']) -_MEG_CH_TYPES_SPLIT = ('mag', 'grad', 'planar1', 'planar2') -_FNIRS_CH_TYPES_SPLIT = ('hbo', 'hbr', 'fnirs_cw_amplitude', - 'fnirs_fd_ac_amplitude', 'fnirs_fd_phase', 'fnirs_od') -_EYETRACK_CH_TYPES_SPLIT = ('eyegaze', 'pupil') + meg=True, + eeg=True, + csd=True, + stim=False, + eog=False, + ecg=False, + emg=False, + misc=False, + resp=False, + chpi=False, + exci=False, + ias=False, + syst=False, + seeg=True, + dipole=False, + gof=False, + bio=False, + ecog=True, + fnirs=True, + dbs=True, + temperature=False, + gsr=False, + eyetrack=True, +) +_PICK_TYPES_KEYS = tuple(list(_PICK_TYPES_DATA_DICT) + ["ref_meg"]) +_MEG_CH_TYPES_SPLIT = ("mag", "grad", "planar1", "planar2") +_FNIRS_CH_TYPES_SPLIT = ( + "hbo", + "hbr", + "fnirs_cw_amplitude", + "fnirs_fd_ac_amplitude", + "fnirs_fd_phase", + "fnirs_od", +) +_EYETRACK_CH_TYPES_SPLIT = ("eyegaze", "pupil") _DATA_CH_TYPES_ORDER_DEFAULT = ( - 'mag', 'grad', 'eeg', 'csd', 'eog', 'ecg', 'resp', 'emg', 'ref_meg', - 'misc', 'stim', 'chpi', 'exci', 'ias', 'syst', 'seeg', 'bio', 'ecog', - 'dbs', 'temperature', 'gsr', 'gof', 'dipole', -) + _FNIRS_CH_TYPES_SPLIT + _EYETRACK_CH_TYPES_SPLIT + ('whitened',) + ( + "mag", + "grad", + "eeg", + "csd", + "eog", + "ecg", + "resp", + "emg", + "ref_meg", + "misc", + "stim", + "chpi", + "exci", + "ias", + "syst", + "seeg", + "bio", + "ecog", + "dbs", + "temperature", + "gsr", + "gof", + "dipole", + ) + + _FNIRS_CH_TYPES_SPLIT + + _EYETRACK_CH_TYPES_SPLIT + + ("whitened",) +) # Valid data types, ordered for consistency, used in viz/evoked. 
_VALID_CHANNEL_TYPES = ( - 'eeg', 'grad', 'mag', 'seeg', 'eog', 'ecg', 'resp', 'emg', 'dipole', 'gof', - 'bio', 'ecog', 'dbs' -) + _FNIRS_CH_TYPES_SPLIT + _EYETRACK_CH_TYPES_SPLIT + ('misc', 'csd') + ( + "eeg", + "grad", + "mag", + "seeg", + "eog", + "ecg", + "resp", + "emg", + "dipole", + "gof", + "bio", + "ecog", + "dbs", + ) + + _FNIRS_CH_TYPES_SPLIT + + _EYETRACK_CH_TYPES_SPLIT + + ("misc", "csd") +) _DATA_CH_TYPES_SPLIT = ( - 'mag', 'grad', 'eeg', 'csd', 'seeg', 'ecog', 'dbs' + "mag", + "grad", + "eeg", + "csd", + "seeg", + "ecog", + "dbs", ) + _FNIRS_CH_TYPES_SPLIT # Electrode types (e.g., can be average-referenced together or separately) -_ELECTRODE_CH_TYPES = ('eeg', 'ecog', 'seeg', 'dbs') +_ELECTRODE_CH_TYPES = ("eeg", "ecog", "seeg", "dbs") -def _electrode_types(info, *, exclude='bads'): - return [ch_type for ch_type in _ELECTRODE_CH_TYPES - if len(pick_types(info, exclude=exclude, **{ch_type: True}))] +def _electrode_types(info, *, exclude="bads"): + return [ + ch_type + for ch_type in _ELECTRODE_CH_TYPES + if len(pick_types(info, exclude=exclude, **{ch_type: True})) + ] -def _pick_data_channels(info, exclude='bads', with_ref_meg=True, - with_aux=False): +def _pick_data_channels(info, exclude="bads", with_ref_meg=True, with_aux=False): """Pick only data channels.""" kwargs = _PICK_TYPES_DATA_DICT if with_aux: @@ -1004,15 +1249,23 @@ def _pick_data_channels(info, exclude='bads', with_ref_meg=True, def _pick_data_or_ica(info, exclude=()): """Pick only data or ICA channels.""" - if any(ch_name.startswith('ICA') for ch_name in info['ch_names']): + if any(ch_name.startswith("ICA") for ch_name in info["ch_names"]): picks = pick_types(info, exclude=exclude, misc=True) else: picks = _pick_data_channels(info, exclude=exclude, with_ref_meg=True) return picks -def _picks_to_idx(info, picks, none='data', exclude='bads', allow_empty=False, - with_ref_meg=True, return_kind=False, picks_on="channels"): +def _picks_to_idx( + info, + picks, + none="data", + exclude="bads", + allow_empty=False, + with_ref_meg=True, + return_kind=False, + picks_on="channels", +): """Convert and check pick validity. Parameters @@ -1022,25 +1275,26 @@ def _picks_to_idx(info, picks, none='data', exclude='bads', allow_empty=False, 'components' for error messages about selection of components. 
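
    A behavior sketch with a toy info (this is a private helper, so the import
    path may move between releases):

        import mne
        import numpy as np
        from mne.io.pick import _picks_to_idx

        info = mne.create_info(
            ["EEG 001", "EEG 002", "STI 014"], 100.0, ["eeg", "eeg", "stim"]
        )
        np.testing.assert_array_equal(_picks_to_idx(info, "eeg"), [0, 1])
        np.testing.assert_array_equal(_picks_to_idx(info, ["EEG 002"]), [1])
        np.testing.assert_array_equal(_picks_to_idx(info, None), [0, 1])  # none="data"
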
""" from .meas_info import Info + picked_ch_type_or_generic = False # # None -> all, data, or data_or_ica (ndarray of int) # if isinstance(info, Info): - n_chan = info['nchan'] + n_chan = info["nchan"] else: - info = _ensure_int(info, 'info', 'an int or Info') + info = _ensure_int(info, "info", "an int or Info") n_chan = info assert n_chan >= 0 orig_picks = picks # We do some extra_repr gymnastics to avoid calling repr(orig_picks) too # soon as it can be a performance bottleneck (repr on ndarray is slow) - extra_repr = '' + extra_repr = "" if picks is None: if isinstance(info, int): # special wrapper for no real info picks = np.arange(n_chan) - extra_repr = ', treated as range(%d)' % (n_chan,) + extra_repr = ", treated as range(%d)" % (n_chan,) else: picks = none # let _picks_str_to_idx handle it extra_repr = 'None, treated as "%s"' % (none,) @@ -1057,15 +1311,22 @@ def _picks_to_idx(info, picks, none='data', exclude='bads', allow_empty=False, picks = np.atleast_1d(picks) # this works even for picks == 'something' picks = np.array([], dtype=int) if len(picks) == 0 else picks if picks.ndim != 1: - raise ValueError('picks must be 1D, got %sD' % (picks.ndim,)) - if picks.dtype.char in ('S', 'U'): - picks = _picks_str_to_idx(info, picks, exclude, with_ref_meg, - return_kind, extra_repr, allow_empty, - orig_picks) + raise ValueError("picks must be 1D, got %sD" % (picks.ndim,)) + if picks.dtype.char in ("S", "U"): + picks = _picks_str_to_idx( + info, + picks, + exclude, + with_ref_meg, + return_kind, + extra_repr, + allow_empty, + orig_picks, + ) if return_kind: picked_ch_type_or_generic = picks[1] picks = picks[0] - if picks.dtype.kind not in ['i', 'u']: + if picks.dtype.kind not in ["i", "u"]: extra_ch = " or list of str (names)" if picks_on == "channels" else "" msg = ( f"picks must be a list of int (indices){extra_ch}. 
" @@ -1079,27 +1340,31 @@ def _picks_to_idx(info, picks, none='data', exclude='bads', allow_empty=False, # ensure we have (optionally non-empty) ndarray of valid int # if len(picks) == 0 and not allow_empty: - raise ValueError('No appropriate %s found for the given picks ' - '(%r)' % (picks_on, orig_picks)) + raise ValueError( + "No appropriate %s found for the given picks " + "(%r)" % (picks_on, orig_picks) + ) if (picks < -n_chan).any(): - raise ValueError('All picks must be >= %d, got %r' - % (-n_chan, orig_picks)) + raise ValueError("All picks must be >= %d, got %r" % (-n_chan, orig_picks)) if (picks >= n_chan).any(): - raise ValueError('All picks must be < n_%s (%d), got %r' - % (picks_on, n_chan, orig_picks)) + raise ValueError( + "All picks must be < n_%s (%d), got %r" % (picks_on, n_chan, orig_picks) + ) picks %= n_chan # ensure positive if return_kind: return picks, picked_ch_type_or_generic return picks -def _picks_str_to_idx(info, picks, exclude, with_ref_meg, return_kind, - extra_repr, allow_empty, orig_picks): +def _picks_str_to_idx( + info, picks, exclude, with_ref_meg, return_kind, extra_repr, allow_empty, orig_picks +): """Turn a list of str into ndarray of int.""" # special case for _picks_to_idx w/no info: shouldn't really happen if isinstance(info, int): - raise ValueError('picks as str can only be used when measurement ' - 'info is available') + raise ValueError( + "picks as str can only be used when measurement " "info is available" + ) # # first: check our special cases @@ -1107,21 +1372,23 @@ def _picks_str_to_idx(info, picks, exclude, with_ref_meg, return_kind, picks_generic = list() if len(picks) == 1: - if picks[0] in ('all', 'data', 'data_or_ica'): - if picks[0] == 'all': - use_exclude = info['bads'] if exclude == 'bads' else exclude + if picks[0] in ("all", "data", "data_or_ica"): + if picks[0] == "all": + use_exclude = info["bads"] if exclude == "bads" else exclude picks_generic = pick_channels( - info['ch_names'], info['ch_names'], exclude=use_exclude) - elif picks[0] == 'data': - picks_generic = _pick_data_channels(info, exclude=exclude, - with_ref_meg=with_ref_meg) - elif picks[0] == 'data_or_ica': + info["ch_names"], info["ch_names"], exclude=use_exclude + ) + elif picks[0] == "data": + picks_generic = _pick_data_channels( + info, exclude=exclude, with_ref_meg=with_ref_meg + ) + elif picks[0] == "data_or_ica": picks_generic = _pick_data_or_ica(info, exclude=exclude) - if len(picks_generic) == 0 and orig_picks is None and \ - not allow_empty: - raise ValueError('picks (%s) yielded no channels, consider ' - 'passing picks explicitly' - % (repr(orig_picks) + extra_repr,)) + if len(picks_generic) == 0 and orig_picks is None and not allow_empty: + raise ValueError( + "picks (%s) yielded no channels, consider " + "passing picks explicitly" % (repr(orig_picks) + extra_repr,) + ) # # second: match all to channel names @@ -1131,7 +1398,7 @@ def _picks_str_to_idx(info, picks, exclude, with_ref_meg, return_kind, picks_name = list() for pick in picks: try: - picks_name.append(info['ch_names'].index(pick)) + picks_name.append(info["ch_names"].index(pick)) except ValueError: bad_names.append(pick) @@ -1157,18 +1424,19 @@ def _picks_str_to_idx(info, picks, exclude, with_ref_meg, return_kind, bad_type = list(picks) # triage MEG and FNIRS, which are complicated due to non-bool entries extra_picks = set() - if 'ref_meg' not in picks and not with_ref_meg: - kwargs['ref_meg'] = False - if len(meg) > 0 and not kwargs.get('meg', False): + if "ref_meg" not in picks and not 
with_ref_meg: + kwargs["ref_meg"] = False + if len(meg) > 0 and not kwargs.get("meg", False): # easiest just to iterate for use_meg in meg: - extra_picks |= set(pick_types( - info, meg=use_meg, ref_meg=False, exclude=exclude)) - if len(fnirs) > 0 and not kwargs.get('fnirs', False): + extra_picks |= set( + pick_types(info, meg=use_meg, ref_meg=False, exclude=exclude) + ) + if len(fnirs) > 0 and not kwargs.get("fnirs", False): if len(fnirs) == 1: - kwargs['fnirs'] = list(fnirs)[0] + kwargs["fnirs"] = list(fnirs)[0] else: - kwargs['fnirs'] = list(fnirs) + kwargs["fnirs"] = list(fnirs) picks_type = pick_types(info, exclude=exclude, **kwargs) if len(extra_picks) > 0: picks_type = sorted(set(picks_type) | set(extra_picks)) @@ -1181,23 +1449,27 @@ def _picks_str_to_idx(info, picks, exclude, with_ref_meg, return_kind, if sum(any_found) == 0: if not allow_empty: raise ValueError( - 'picks (%s) could not be interpreted as ' + "picks (%s) could not be interpreted as " 'channel names (no channel "%s"), channel types (no ' 'type "%s" present), or a generic type (just "all" or "data")' - % (repr(orig_picks) + extra_repr, str(bad_names), bad_type)) + % (repr(orig_picks) + extra_repr, str(bad_names), bad_type) + ) picks = np.array([], int) elif sum(any_found) > 1: - raise RuntimeError('Some channel names are ambiguously equivalent to ' - 'channel types, cannot use string-based ' - 'picks for these') + raise RuntimeError( + "Some channel names are ambiguously equivalent to " + "channel types, cannot use string-based " + "picks for these" + ) else: picks = np.array(all_picks[np.where(any_found)[0][0]]) picked_ch_type_or_generic = not len(picks_name) if len(bad_names) > 0 and not picked_ch_type_or_generic: raise ValueError( - f'Channel(s) {bad_names} could not be picked, because ' - 'they are not present in the info instance.') + f"Channel(s) {bad_names} could not be picked, because " + "they are not present in the info instance." 
+ ) if return_kind: return picks, picked_ch_type_or_generic @@ -1209,12 +1481,11 @@ def _pick_inst(inst, picks, exclude, copy=True): if copy is True: inst = inst.copy() picks = _picks_to_idx(inst.info, picks, exclude=[]) - pick_names = [inst.info['ch_names'][pick] for pick in picks] + pick_names = [inst.info["ch_names"][pick] for pick in picks] inst.pick_channels(pick_names) - if exclude == 'bads': - exclude = [ch for ch in inst.info['bads'] - if ch in inst.info['ch_names']] + if exclude == "bads": + exclude = [ch for ch in inst.info["bads"] if ch in inst.info["ch_names"]] if exclude is not None: inst.drop_channels(exclude) return inst @@ -1222,12 +1493,11 @@ def _pick_inst(inst, picks, exclude, copy=True): def _get_channel_types(info, picks=None, unique=False, only_data_chs=False): """Get the data channel types in an info instance.""" - none = 'data' if only_data_chs else 'all' + none = "data" if only_data_chs else "all" picks = _picks_to_idx(info, picks, none, (), allow_empty=False) ch_types = [channel_type(info, pick) for pick in picks] if only_data_chs: - ch_types = [ch_type for ch_type in ch_types - if ch_type in _DATA_CH_TYPES_SPLIT] + ch_types = [ch_type for ch_type in ch_types if ch_type in _DATA_CH_TYPES_SPLIT] if unique: # set does not preserve order but dict does, so let's just use it ch_types = list({k: k for k in ch_types}.keys()) diff --git a/mne/io/proc_history.py b/mne/io/proc_history.py index 290730a2aeb..d62963a44c0 100644 --- a/mne/io/proc_history.py +++ b/mne/io/proc_history.py @@ -6,25 +6,41 @@ from .open import read_tag, fiff_open from .tree import dir_tree_find -from .write import (start_block, end_block, write_int, write_float, - write_string, write_float_matrix, write_int_matrix, - write_float_sparse, write_id, write_name_list_sanitized, - _safe_name_list) +from .write import ( + start_block, + end_block, + write_int, + write_float, + write_string, + write_float_matrix, + write_int_matrix, + write_float_sparse, + write_id, + write_name_list_sanitized, + _safe_name_list, +) from .tag import find_tag, _int_item, _float_item from .constants import FIFF from ..fixes import _csc_matrix_cast from ..utils import warn, _check_fname -_proc_keys = ['parent_file_id', 'block_id', 'parent_block_id', - 'date', 'experimenter', 'creator'] -_proc_ids = [FIFF.FIFF_PARENT_FILE_ID, - FIFF.FIFF_BLOCK_ID, - FIFF.FIFF_PARENT_BLOCK_ID, - FIFF.FIFF_MEAS_DATE, - FIFF.FIFF_EXPERIMENTER, - FIFF.FIFF_CREATOR] -_proc_writers = [write_id, write_id, write_id, - write_int, write_string, write_string] +_proc_keys = [ + "parent_file_id", + "block_id", + "parent_block_id", + "date", + "experimenter", + "creator", +] +_proc_ids = [ + FIFF.FIFF_PARENT_FILE_ID, + FIFF.FIFF_BLOCK_ID, + FIFF.FIFF_PARENT_BLOCK_ID, + FIFF.FIFF_MEAS_DATE, + FIFF.FIFF_EXPERIMENTER, + FIFF.FIFF_CREATOR, +] +_proc_writers = [write_id, write_id, write_id, write_int, write_string, write_string] _proc_casters = [dict, dict, dict, np.array, str, str] @@ -76,44 +92,42 @@ def _read_proc_history(fid, tree): out = list() if len(proc_history) > 0: proc_history = proc_history[0] - proc_records = dir_tree_find(proc_history, - FIFF.FIFFB_PROCESSING_RECORD) + proc_records = dir_tree_find(proc_history, FIFF.FIFFB_PROCESSING_RECORD) for proc_record in proc_records: record = dict() - for i_ent in range(proc_record['nent']): - kind = proc_record['directory'][i_ent].kind - pos = proc_record['directory'][i_ent].pos - for key, id_, cast in zip(_proc_keys, _proc_ids, - _proc_casters): + for i_ent in range(proc_record["nent"]): + kind = 
proc_record["directory"][i_ent].kind + pos = proc_record["directory"][i_ent].pos + for key, id_, cast in zip(_proc_keys, _proc_ids, _proc_casters): if kind == id_: tag = read_tag(fid, pos) record[key] = cast(tag.data) break else: - warn('Unknown processing history item %s' % kind) - record['max_info'] = _read_maxfilter_record(fid, proc_record) + warn("Unknown processing history item %s" % kind) + record["max_info"] = _read_maxfilter_record(fid, proc_record) iass = dir_tree_find(proc_record, FIFF.FIFFB_IAS) if len(iass) > 0: # XXX should eventually populate this ss = [dict() for _ in range(len(iass))] - record['ias'] = ss - if len(record['max_info']) > 0: + record["ias"] = ss + if len(record["max_info"]) > 0: out.append(record) return out def _write_proc_history(fid, info): """Write processing history to file.""" - if len(info['proc_history']) > 0: + if len(info["proc_history"]) > 0: start_block(fid, FIFF.FIFFB_PROCESSING_HISTORY) - for record in info['proc_history']: + for record in info["proc_history"]: start_block(fid, FIFF.FIFFB_PROCESSING_RECORD) for key, id_, writer in zip(_proc_keys, _proc_ids, _proc_writers): if key in record: writer(fid, id_, record[key]) - _write_maxfilter_record(fid, record['max_info']) - if 'ias' in record: - for _ in record['ias']: + _write_maxfilter_record(fid, record["max_info"]) + if "ias" in record: + for _ in record["ias"]: start_block(fid, FIFF.FIFFB_IAS) # XXX should eventually populate this end_block(fid, FIFF.FIFFB_IAS) @@ -121,41 +135,71 @@ def _write_proc_history(fid, info): end_block(fid, FIFF.FIFFB_PROCESSING_HISTORY) -_sss_info_keys = ('job', 'frame', 'origin', 'in_order', - 'out_order', 'nchan', 'components', 'nfree', - 'hpi_g_limit', 'hpi_dist_limit') -_sss_info_ids = (FIFF.FIFF_SSS_JOB, - FIFF.FIFF_SSS_FRAME, - FIFF.FIFF_SSS_ORIGIN, - FIFF.FIFF_SSS_ORD_IN, - FIFF.FIFF_SSS_ORD_OUT, - FIFF.FIFF_SSS_NMAG, - FIFF.FIFF_SSS_COMPONENTS, - FIFF.FIFF_SSS_NFREE, - FIFF.FIFF_HPI_FIT_GOOD_LIMIT, - FIFF.FIFF_HPI_FIT_DIST_LIMIT) -_sss_info_writers = (write_int, write_int, write_float, write_int, - write_int, write_int, write_int, write_int, - write_float, write_float) -_sss_info_casters = (_int_item, _int_item, np.array, _int_item, - _int_item, _int_item, np.array, _int_item, - _float_item, _float_item) +_sss_info_keys = ( + "job", + "frame", + "origin", + "in_order", + "out_order", + "nchan", + "components", + "nfree", + "hpi_g_limit", + "hpi_dist_limit", +) +_sss_info_ids = ( + FIFF.FIFF_SSS_JOB, + FIFF.FIFF_SSS_FRAME, + FIFF.FIFF_SSS_ORIGIN, + FIFF.FIFF_SSS_ORD_IN, + FIFF.FIFF_SSS_ORD_OUT, + FIFF.FIFF_SSS_NMAG, + FIFF.FIFF_SSS_COMPONENTS, + FIFF.FIFF_SSS_NFREE, + FIFF.FIFF_HPI_FIT_GOOD_LIMIT, + FIFF.FIFF_HPI_FIT_DIST_LIMIT, +) +_sss_info_writers = ( + write_int, + write_int, + write_float, + write_int, + write_int, + write_int, + write_int, + write_int, + write_float, + write_float, +) +_sss_info_casters = ( + _int_item, + _int_item, + np.array, + _int_item, + _int_item, + _int_item, + np.array, + _int_item, + _float_item, + _float_item, +) -_max_st_keys = ('job', 'subspcorr', 'buflen') -_max_st_ids = (FIFF.FIFF_SSS_JOB, FIFF.FIFF_SSS_ST_CORR, - FIFF.FIFF_SSS_ST_LENGTH) +_max_st_keys = ("job", "subspcorr", "buflen") +_max_st_ids = (FIFF.FIFF_SSS_JOB, FIFF.FIFF_SSS_ST_CORR, FIFF.FIFF_SSS_ST_LENGTH) _max_st_writers = (write_int, write_float, write_float) _max_st_casters = (_int_item, _float_item, _float_item) -_sss_ctc_keys = ('block_id', 'date', 'creator', 'decoupler') -_sss_ctc_ids = (FIFF.FIFF_BLOCK_ID, - FIFF.FIFF_MEAS_DATE, - FIFF.FIFF_CREATOR, - 
FIFF.FIFF_DECOUPLER_MATRIX) +_sss_ctc_keys = ("block_id", "date", "creator", "decoupler") +_sss_ctc_ids = ( + FIFF.FIFF_BLOCK_ID, + FIFF.FIFF_MEAS_DATE, + FIFF.FIFF_CREATOR, + FIFF.FIFF_DECOUPLER_MATRIX, +) _sss_ctc_writers = (write_id, write_int, write_string, write_float_sparse) _sss_ctc_casters = (dict, np.array, str, _csc_matrix_cast) -_sss_cal_keys = ('cal_chans', 'cal_corrs') +_sss_cal_keys = ("cal_chans", "cal_corrs") _sss_cal_ids = (FIFF.FIFF_SSS_CAL_CHANS, FIFF.FIFF_SSS_CAL_CORRS) _sss_cal_writers = (write_int_matrix, write_float_matrix) _sss_cal_casters = (np.array, np.array) @@ -163,19 +207,19 @@ def _write_proc_history(fid, info): def _read_ctc(fname): """Read cross-talk correction matrix.""" - fname = _check_fname(fname, overwrite='read', must_exist=True) + fname = _check_fname(fname, overwrite="read", must_exist=True) f, tree, _ = fiff_open(fname) with f as fid: - sss_ctc = _read_maxfilter_record(fid, tree)['sss_ctc'] - bad_str = 'Invalid cross-talk FIF: %s' % fname + sss_ctc = _read_maxfilter_record(fid, tree)["sss_ctc"] + bad_str = "Invalid cross-talk FIF: %s" % fname if len(sss_ctc) == 0: raise ValueError(bad_str) node = dir_tree_find(tree, FIFF.FIFFB_DATA_CORRECTION)[0] comment = find_tag(fid, node, FIFF.FIFF_COMMENT).data - if comment != 'cross-talk compensation matrix': + if comment != "cross-talk compensation matrix": raise ValueError(bad_str) - sss_ctc['creator'] = find_tag(fid, node, FIFF.FIFF_CREATOR).data - sss_ctc['date'] = find_tag(fid, node, FIFF.FIFF_MEAS_DATE).data + sss_ctc["creator"] = find_tag(fid, node, FIFF.FIFF_CREATOR).data + sss_ctc["date"] = find_tag(fid, node, FIFF.FIFF_MEAS_DATE).data return sss_ctc @@ -185,11 +229,10 @@ def _read_maxfilter_record(fid, tree): sss_info = dict() if len(sss_info_block) > 0: sss_info_block = sss_info_block[0] - for i_ent in range(sss_info_block['nent']): - kind = sss_info_block['directory'][i_ent].kind - pos = sss_info_block['directory'][i_ent].pos - for key, id_, cast in zip(_sss_info_keys, _sss_info_ids, - _sss_info_casters): + for i_ent in range(sss_info_block["nent"]): + kind = sss_info_block["directory"][i_ent].kind + pos = sss_info_block["directory"][i_ent].pos + for key, id_, cast in zip(_sss_info_keys, _sss_info_ids, _sss_info_casters): if kind == id_: tag = read_tag(fid, pos) sss_info[key] = cast(tag.data) @@ -199,11 +242,10 @@ def _read_maxfilter_record(fid, tree): max_st = dict() if len(max_st_block) > 0: max_st_block = max_st_block[0] - for i_ent in range(max_st_block['nent']): - kind = max_st_block['directory'][i_ent].kind - pos = max_st_block['directory'][i_ent].pos - for key, id_, cast in zip(_max_st_keys, _max_st_ids, - _max_st_casters): + for i_ent in range(max_st_block["nent"]): + kind = max_st_block["directory"][i_ent].kind + pos = max_st_block["directory"][i_ent].pos + for key, id_, cast in zip(_max_st_keys, _max_st_ids, _max_st_casters): if kind == id_: tag = read_tag(fid, pos) max_st[key] = cast(tag.data) @@ -213,11 +255,10 @@ def _read_maxfilter_record(fid, tree): sss_ctc = dict() if len(sss_ctc_block) > 0: sss_ctc_block = sss_ctc_block[0] - for i_ent in range(sss_ctc_block['nent']): - kind = sss_ctc_block['directory'][i_ent].kind - pos = sss_ctc_block['directory'][i_ent].pos - for key, id_, cast in zip(_sss_ctc_keys, _sss_ctc_ids, - _sss_ctc_casters): + for i_ent in range(sss_ctc_block["nent"]): + kind = sss_ctc_block["directory"][i_ent].kind + pos = sss_ctc_block["directory"][i_ent].pos + for key, id_, cast in zip(_sss_ctc_keys, _sss_ctc_ids, _sss_ctc_casters): if kind == id_: tag = 
read_tag(fid, pos) sss_ctc[key] = cast(tag.data) @@ -225,69 +266,66 @@ def _read_maxfilter_record(fid, tree): else: if kind == FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST: tag = read_tag(fid, pos) - chs = _safe_name_list(tag.data, 'read', 'proj_items_chs') + chs = _safe_name_list(tag.data, "read", "proj_items_chs") # This list can null chars in the last entry, e.g.: # [..., 'MEG2642', 'MEG2643', 'MEG2641\x00 ... \x00'] - chs[-1] = chs[-1].split('\x00')[0] - sss_ctc['proj_items_chs'] = chs + chs[-1] = chs[-1].split("\x00")[0] + sss_ctc["proj_items_chs"] = chs sss_cal_block = dir_tree_find(tree, FIFF.FIFFB_SSS_CAL) # 503 sss_cal = dict() if len(sss_cal_block) > 0: sss_cal_block = sss_cal_block[0] - for i_ent in range(sss_cal_block['nent']): - kind = sss_cal_block['directory'][i_ent].kind - pos = sss_cal_block['directory'][i_ent].pos - for key, id_, cast in zip(_sss_cal_keys, _sss_cal_ids, - _sss_cal_casters): + for i_ent in range(sss_cal_block["nent"]): + kind = sss_cal_block["directory"][i_ent].kind + pos = sss_cal_block["directory"][i_ent].pos + for key, id_, cast in zip(_sss_cal_keys, _sss_cal_ids, _sss_cal_casters): if kind == id_: tag = read_tag(fid, pos) sss_cal[key] = cast(tag.data) break - max_info = dict(sss_info=sss_info, sss_ctc=sss_ctc, - sss_cal=sss_cal, max_st=max_st) + max_info = dict(sss_info=sss_info, sss_ctc=sss_ctc, sss_cal=sss_cal, max_st=max_st) return max_info def _write_maxfilter_record(fid, record): """Write maxfilter processing record to file.""" - sss_info = record['sss_info'] + sss_info = record["sss_info"] if len(sss_info) > 0: start_block(fid, FIFF.FIFFB_SSS_INFO) - for key, id_, writer in zip(_sss_info_keys, _sss_info_ids, - _sss_info_writers): + for key, id_, writer in zip(_sss_info_keys, _sss_info_ids, _sss_info_writers): if key in sss_info: writer(fid, id_, sss_info[key]) end_block(fid, FIFF.FIFFB_SSS_INFO) - max_st = record['max_st'] + max_st = record["max_st"] if len(max_st) > 0: start_block(fid, FIFF.FIFFB_SSS_ST_INFO) - for key, id_, writer in zip(_max_st_keys, _max_st_ids, - _max_st_writers): + for key, id_, writer in zip(_max_st_keys, _max_st_ids, _max_st_writers): if key in max_st: writer(fid, id_, max_st[key]) end_block(fid, FIFF.FIFFB_SSS_ST_INFO) - sss_ctc = record['sss_ctc'] + sss_ctc = record["sss_ctc"] if len(sss_ctc) > 0: # dict has entries start_block(fid, FIFF.FIFFB_CHANNEL_DECOUPLER) - for key, id_, writer in zip(_sss_ctc_keys, _sss_ctc_ids, - _sss_ctc_writers): + for key, id_, writer in zip(_sss_ctc_keys, _sss_ctc_ids, _sss_ctc_writers): if key in sss_ctc: writer(fid, id_, sss_ctc[key]) - if 'proj_items_chs' in sss_ctc: + if "proj_items_chs" in sss_ctc: write_name_list_sanitized( - fid, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST, - sss_ctc['proj_items_chs'], 'proj_items_chs') + fid, + FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST, + sss_ctc["proj_items_chs"], + "proj_items_chs", + ) end_block(fid, FIFF.FIFFB_CHANNEL_DECOUPLER) - sss_cal = record['sss_cal'] + sss_cal = record["sss_cal"] if len(sss_cal) > 0: start_block(fid, FIFF.FIFFB_SSS_CAL) - for key, id_, writer in zip(_sss_cal_keys, _sss_cal_ids, - _sss_cal_writers): + for key, id_, writer in zip(_sss_cal_keys, _sss_cal_ids, _sss_cal_writers): if key in sss_cal: writer(fid, id_, sss_cal[key]) end_block(fid, FIFF.FIFFB_SSS_CAL) diff --git a/mne/io/proj.py b/mne/io/proj.py index e874c47ebd5..1cf2e61a0e1 100644 --- a/mne/io/proj.py +++ b/mne/io/proj.py @@ -15,13 +15,26 @@ from .pick import pick_types, pick_info, _electrode_types, _ELECTRODE_CH_TYPES from .tag import find_tag, _rename_list from .tree import dir_tree_find 
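
The import rewrites that follow are purely mechanical: black wraps any
from-import that no longer fits on one line into a parenthesized, one name
per line form with a trailing comma, shaped like this short stdlib stand-in:

    from os.path import (
        basename,
        dirname,
        join,
    )

    print(join(dirname("/tmp/raw.fif"), basename("/tmp/raw.fif")))
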
-from .write import (write_int, write_float, write_string, write_float_matrix, - end_block, start_block, write_name_list_sanitized, - _safe_name_list) -from ..defaults import (_INTERPOLATION_DEFAULT, _BORDER_DEFAULT, - _EXTRAPOLATE_DEFAULT) -from ..utils import (logger, verbose, warn, fill_doc, _validate_type, - object_diff, _check_option) +from .write import ( + write_int, + write_float, + write_string, + write_float_matrix, + end_block, + start_block, + write_name_list_sanitized, + _safe_name_list, +) +from ..defaults import _INTERPOLATION_DEFAULT, _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT +from ..utils import ( + logger, + verbose, + warn, + fill_doc, + _validate_type, + object_diff, + _check_option, +) class Projection(dict): @@ -46,16 +59,24 @@ class Projection(dict): The explained variance (proportion). """ - def __init__(self, *, data, desc='', kind=FIFF.FIFFV_PROJ_ITEM_FIELD, - active=False, explained_var=None): - super().__init__(desc=desc, kind=kind, active=active, data=data, - explained_var=explained_var) + def __init__( + self, + *, + data, + desc="", + kind=FIFF.FIFFV_PROJ_ITEM_FIELD, + active=False, + explained_var=None, + ): + super().__init__( + desc=desc, kind=kind, active=active, data=data, explained_var=explained_var + ) def __repr__(self): # noqa: D105 - s = "%s" % self['desc'] - s += ", active : %s" % self['active'] + s = "%s" % self["desc"] + s += ", active : %s" % self["active"] s += f", n_channels : {len(self['data']['col_names'])}" - if self['explained_var'] is not None: + if self["explained_var"] is not None: s += f', exp. var : {self["explained_var"] * 100:0.2f}%' return "" % s @@ -65,9 +86,9 @@ def __deepcopy__(self, memodict): cls = self.__class__ result = cls.__new__(cls) for k, v in self.items(): - if k == 'data': + if k == "data": v = v.copy() - v['data'] = v['data'].copy() + v["data"] = v["data"].copy() result[k] = v else: result[k] = v # kind, active, desc, explained_var immutable @@ -83,11 +104,28 @@ def __ne__(self, other): @fill_doc def plot_topomap( - self, info, *, sensors=True, show_names=False, contours=6, - outlines='head', sphere=None, image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, - size=1, cmap=None, vlim=(None, None), cnorm=None, colorbar=False, - cbar_fmt='%3.1f', units=None, axes=None, show=True): + self, + info, + *, + sensors=True, + show_names=False, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap=None, + vlim=(None, None), + cnorm=None, + colorbar=False, + cbar_fmt="%3.1f", + units=None, + axes=None, + show=True, + ): """Plot topographic maps of SSP projections. Parameters @@ -132,13 +170,29 @@ def plot_topomap( .. 
versionadded:: 0.15.0 """ # noqa: E501 from ..viz.topomap import plot_projs_topomap + return plot_projs_topomap( - self, info, sensors=sensors, show_names=show_names, - contours=contours, outlines=outlines, sphere=sphere, - image_interp=image_interp, extrapolate=extrapolate, border=border, - res=res, size=size, cmap=cmap, vlim=vlim, cnorm=cnorm, - colorbar=colorbar, cbar_fmt=cbar_fmt, units=units, axes=axes, - show=show) + self, + info, + sensors=sensors, + show_names=show_names, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + show=show, + ) class ProjMixin: @@ -170,8 +224,9 @@ class ProjMixin: @property def proj(self): """Whether or not projections are active.""" - return (len(self.info['projs']) > 0 and - all(p['active'] for p in self.info['projs'])) + return len(self.info["projs"]) > 0 and all( + p["active"] for p in self.info["projs"] + ) @verbose def add_proj(self, projs, remove_existing=False, verbose=None): @@ -193,27 +248,28 @@ def add_proj(self, projs, remove_existing=False, verbose=None): if isinstance(projs, Projection): projs = [projs] - if (not isinstance(projs, list) and - not all(isinstance(p, Projection) for p in projs)): - raise ValueError('Only projs can be added. You supplied ' - 'something else.') + if not isinstance(projs, list) and not all( + isinstance(p, Projection) for p in projs + ): + raise ValueError("Only projs can be added. You supplied " "something else.") # mark proj as inactive, as they have not been applied projs = deactivate_proj(projs, copy=True) if remove_existing: # we cannot remove the proj if they are active - if any(p['active'] for p in self.info['projs']): - raise ValueError('Cannot remove projectors that have ' - 'already been applied') + if any(p["active"] for p in self.info["projs"]): + raise ValueError( + "Cannot remove projectors that have " "already been applied" + ) with self.info._unlock(): - self.info['projs'] = projs + self.info["projs"] = projs else: - self.info['projs'].extend(projs) + self.info["projs"].extend(projs) # We don't want to add projectors that are activated again. with self.info._unlock(): - self.info['projs'] = _uniquify_projs(self.info['projs'], - check_active=False, - sort=False) + self.info["projs"] = _uniquify_projs( + self.info["projs"], check_active=False, sort=False + ) return self @verbose @@ -252,27 +308,32 @@ def apply_proj(self, verbose=None): from ..epochs import BaseEpochs from ..evoked import Evoked from .base import BaseRaw - if self.info['projs'] is None or len(self.info['projs']) == 0: - logger.info('No projector specified for this dataset. ' - 'Please consider the method self.add_proj.') + + if self.info["projs"] is None or len(self.info["projs"]) == 0: + logger.info( + "No projector specified for this dataset. " + "Please consider the method self.add_proj." + ) return self # Exit delayed mode if you apply proj if isinstance(self, BaseEpochs) and self._do_delayed_proj: - logger.info('Leaving delayed SSP mode.') + logger.info("Leaving delayed SSP mode.") self._do_delayed_proj = False - if all(p['active'] for p in self.info['projs']): - logger.info('Projections have already been applied. ' - 'Setting proj attribute to True.') + if all(p["active"] for p in self.info["projs"]): + logger.info( + "Projections have already been applied. " + "Setting proj attribute to True." 
+ ) return self - _projector, info = setup_proj(deepcopy(self.info), add_eeg_ref=False, - activate=True) + _projector, info = setup_proj( + deepcopy(self.info), add_eeg_ref=False, activate=True + ) # let's not raise a RuntimeError here, otherwise interactive plotting if _projector is None: # won't be fun. - logger.info('The projections don\'t apply to these data.' - ' Doing nothing.') + logger.info("The projections don't apply to these data." " Doing nothing.") return self self._projector, self.info = _projector, info if isinstance(self, (BaseRaw, Evoked)): @@ -284,10 +345,10 @@ def apply_proj(self, verbose=None): self._data[ii] = self._project_epoch(e) else: self.load_data() # will automatically apply - logger.info('SSP projectors applied...') + logger.info("SSP projectors applied...") return self - def del_proj(self, idx='all'): + def del_proj(self, idx="all"): """Remove SSP projection vector. .. note:: The projection vector can only be removed if it is inactive @@ -304,34 +365,52 @@ def del_proj(self, idx='all'): self : instance of Raw | Epochs | Evoked The instance. """ - if isinstance(idx, str) and idx == 'all': - idx = list(range(len(self.info['projs']))) + if isinstance(idx, str) and idx == "all": + idx = list(range(len(self.info["projs"]))) idx = np.atleast_1d(np.array(idx, int)).ravel() for ii in idx: - proj = self.info['projs'][ii] - if (proj['active'] and - set(self.info['ch_names']) & - set(proj['data']['col_names'])): - msg = (f'Cannot remove projector that has already been ' - f'applied, unless you first remove all channels it ' - f'applies to. The problematic projector is: {proj}') + proj = self.info["projs"][ii] + if proj["active"] and set(self.info["ch_names"]) & set( + proj["data"]["col_names"] + ): + msg = ( + f"Cannot remove projector that has already been " + f"applied, unless you first remove all channels it " + f"applies to. The problematic projector is: {proj}" + ) raise ValueError(msg) - keep = np.ones(len(self.info['projs'])) + keep = np.ones(len(self.info["projs"])) keep[idx] = False # works with negative indexing and does checks with self.info._unlock(): - self.info['projs'] = [p for p, k in zip(self.info['projs'], keep) - if k] + self.info["projs"] = [p for p, k in zip(self.info["projs"], keep) if k] return self @fill_doc def plot_projs_topomap( - self, ch_type=None, *, sensors=True, show_names=False, contours=6, - outlines='head', sphere=None, image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, border=_BORDER_DEFAULT, res=64, - size=1, cmap=None, vlim=(None, None), cnorm=None, colorbar=False, - cbar_fmt='%3.1f', units=None, axes=None, show=True): + self, + ch_type=None, + *, + sensors=True, + show_names=False, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap=None, + vlim=(None, None), + cnorm=None, + colorbar=False, + cbar_fmt="%3.1f", + units=None, + axes=None, + show=True, + ): """Plot SSP vector. Parameters @@ -372,49 +451,69 @@ def plot_projs_topomap( fig : instance of Figure Figure distributing one image per channel across sensor topography. 
""" - _projs = [deepcopy(_proj) for _proj in self.info['projs']] + _projs = [deepcopy(_proj) for _proj in self.info["projs"]] if _projs is None or len(_projs) == 0: - raise ValueError('No projectors in Info; nothing to plot.') + raise ValueError("No projectors in Info; nothing to plot.") if ch_type is not None: # make sure the requested channel type(s) exist - _validate_type(ch_type, (str, list, tuple), 'ch_type') + _validate_type(ch_type, (str, list, tuple), "ch_type") if isinstance(ch_type, str): ch_type = [ch_type] bad_ch_types = [_type not in self for _type in ch_type] if any(bad_ch_types): - raise ValueError(f'ch_type {ch_type[bad_ch_types]} not ' - f'present in {self.__class__.__name__}.') + raise ValueError( + f"ch_type {ch_type[bad_ch_types]} not " + f"present in {self.__class__.__name__}." + ) # remove projs from unrequested channel types. This is a bit # convoluted because Projection objects don't store channel types, # only channel names available_ch_types = np.array(self.get_channel_types()) for _proj in _projs[::-1]: - idx = np.isin(self.ch_names, _proj['data']['col_names']) + idx = np.isin(self.ch_names, _proj["data"]["col_names"]) proj_ch_type = np.unique(available_ch_types[idx]) - err_msg = 'Projector contains multiple channel types' + err_msg = "Projector contains multiple channel types" assert len(proj_ch_type) == 1, err_msg if proj_ch_type[0] != ch_type: _projs.remove(_proj) if len(_projs) == 0: - raise ValueError('Nothing to plot (no projectors for channel ' - f'type {ch_type}).') + raise ValueError( + "Nothing to plot (no projectors for channel " f"type {ch_type})." + ) # now we have non-empty _projs list with correct channel type(s) from ..viz.topomap import plot_projs_topomap + fig = plot_projs_topomap( - _projs, self.info, sensors=sensors, show_names=show_names, - contours=contours, outlines=outlines, sphere=sphere, - image_interp=image_interp, extrapolate=extrapolate, - border=border, res=res, size=size, cmap=cmap, vlim=vlim, - cnorm=cnorm, colorbar=colorbar, cbar_fmt=cbar_fmt, - units=units, axes=axes, show=show) + _projs, + self.info, + sensors=sensors, + show_names=show_names, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + show=show, + ) return fig - def _reconstruct_proj(self, mode='accurate', origin='auto'): + def _reconstruct_proj(self, mode="accurate", origin="auto"): from ..forward import _map_meg_or_eeg_channels - if len(self.info['projs']) == 0: + + if len(self.info["projs"]) == 0: return self self.apply_proj() - for kind in ('meg', 'eeg'): + for kind in ("meg", "eeg"): kwargs = dict(meg=False) kwargs[kind] = True picks = pick_types(self.info, **kwargs) @@ -423,27 +522,30 @@ def _reconstruct_proj(self, mode='accurate', origin='auto'): info_from = pick_info(self.info, picks) info_to = info_from.copy() with info_to._unlock(): - info_to['projs'] = [] - if kind == 'eeg' and _has_eeg_average_ref_proj(info_from): - info_to['projs'] = [ - make_eeg_average_ref_proj(info_to, verbose=False)] + info_to["projs"] = [] + if kind == "eeg" and _has_eeg_average_ref_proj(info_from): + info_to["projs"] = [ + make_eeg_average_ref_proj(info_to, verbose=False) + ] mapping = _map_meg_or_eeg_channels( - info_from, info_to, mode=mode, origin=origin) - self.data[..., picks, :] = np.matmul( - mapping, self.data[..., picks, :]) + info_from, info_to, mode=mode, 
origin=origin + ) + self.data[..., picks, :] = np.matmul(mapping, self.data[..., picks, :]) return self def _proj_equal(a, b, check_active=True): """Test if two projectors are equal.""" - equal = ((a['active'] == b['active'] or not check_active) and - a['kind'] == b['kind'] and - a['desc'] == b['desc'] and - a['data']['col_names'] == b['data']['col_names'] and - a['data']['row_names'] == b['data']['row_names'] and - a['data']['ncol'] == b['data']['ncol'] and - a['data']['nrow'] == b['data']['nrow'] and - np.all(a['data']['data'] == b['data']['data'])) + equal = ( + (a["active"] == b["active"] or not check_active) + and a["kind"] == b["kind"] + and a["desc"] == b["desc"] + and a["data"]["col_names"] == b["data"]["col_names"] + and a["data"]["row_names"] == b["data"]["row_names"] + and a["data"]["ncol"] == b["data"]["ncol"] + and a["data"]["nrow"] == b["data"]["nrow"] + and np.all(a["data"]["data"] == b["data"]["data"]) + ) return equal @@ -483,31 +585,31 @@ def _read_proj(fid, node, *, ch_names_mapping=None, verbose=None): if tag is not None: desc = tag.data else: - raise ValueError('Projection item description missing') + raise ValueError("Projection item description missing") tag = find_tag(fid, item, FIFF.FIFF_PROJ_ITEM_KIND) if tag is not None: kind = int(tag.data.item()) else: - raise ValueError('Projection item kind missing') + raise ValueError("Projection item kind missing") tag = find_tag(fid, item, FIFF.FIFF_PROJ_ITEM_NVEC) if tag is not None: nvec = int(tag.data.item()) else: - raise ValueError('Number of projection vectors not specified') + raise ValueError("Number of projection vectors not specified") tag = find_tag(fid, item, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST) if tag is not None: - names = _safe_name_list(tag.data, 'read', 'names') + names = _safe_name_list(tag.data, "read", "names") else: - raise ValueError('Projection item channel list missing') + raise ValueError("Projection item channel list missing") tag = find_tag(fid, item, FIFF.FIFF_PROJ_ITEM_VECTORS) if tag is not None: data = tag.data else: - raise ValueError('Projection item data missing') + raise ValueError("Projection item data missing") tag = find_tag(fid, item, FIFF.FIFF_MNE_PROJ_ITEM_ACTIVE) if tag is not None: @@ -526,28 +628,36 @@ def _read_proj(fid, node, *, ch_names_mapping=None, verbose=None): data = data.T if data.shape[1] != len(names): - raise ValueError('Number of channel names does not match the ' - 'size of data matrix') + raise ValueError( + "Number of channel names does not match the " "size of data matrix" + ) # just always use this, we used to have bugs with writing the # number correctly... 
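
For orientation, the Projection assembled just below can also be built by
hand; a minimal sketch with made-up channel names (import path as of this
tree):

    import numpy as np
    from mne.io.proj import Projection

    names = ["EEG 001", "EEG 002"]
    vec = np.ones((1, len(names))) / np.sqrt(len(names))  # a 1 x nchan vector
    proj = Projection(
        data=dict(nrow=1, ncol=len(names), row_names=None, col_names=names, data=vec),
        desc="toy average reference",
    )
    assert not proj["active"] and proj["data"]["ncol"] == 2
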
nchan = len(names) names[:] = _rename_list(names, ch_names_mapping) # Use exactly the same fields in data as in a named matrix - one = Projection(kind=kind, active=active, desc=desc, - data=dict(nrow=nvec, ncol=nchan, row_names=None, - col_names=names, data=data), - explained_var=explained_var) + one = Projection( + kind=kind, + active=active, + desc=desc, + data=dict( + nrow=nvec, ncol=nchan, row_names=None, col_names=names, data=data + ), + explained_var=explained_var, + ) projs.append(one) if len(projs) > 0: - logger.info(' Read a total of %d projection items:' % len(projs)) + logger.info(" Read a total of %d projection items:" % len(projs)) for proj in projs: - misc = 'active' if proj['active'] else ' idle' - logger.info(f' {proj["desc"]} ' - f'({proj["data"]["nrow"]} x ' - f'{len(proj["data"]["col_names"])}) {misc}') + misc = "active" if proj["active"] else " idle" + logger.info( + f' {proj["desc"]} ' + f'({proj["data"]["nrow"]} x ' + f'{len(proj["data"]["col_names"])}) {misc}' + ) return projs @@ -555,6 +665,7 @@ def _read_proj(fid, node, *, ch_names_mapping=None, verbose=None): ############################################################################### # Write + def _write_proj(fid, projs, *, ch_names_mapping=None): """Write a projection operator to a file. @@ -570,30 +681,29 @@ def _write_proj(fid, projs, *, ch_names_mapping=None): ch_names_mapping = dict() if ch_names_mapping is None else ch_names_mapping # validation - _validate_type(projs, (list, tuple), 'projs') + _validate_type(projs, (list, tuple), "projs") for pi, proj in enumerate(projs): - _validate_type(proj, Projection, f'projs[{pi}]') + _validate_type(proj, Projection, f"projs[{pi}]") start_block(fid, FIFF.FIFFB_PROJ) for proj in projs: start_block(fid, FIFF.FIFFB_PROJ_ITEM) - write_int(fid, FIFF.FIFF_NCHAN, len(proj['data']['col_names'])) - names = _rename_list(proj['data']['col_names'], ch_names_mapping) + write_int(fid, FIFF.FIFF_NCHAN, len(proj["data"]["col_names"])) + names = _rename_list(proj["data"]["col_names"], ch_names_mapping) write_name_list_sanitized( - fid, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST, names, 'col_names') - write_string(fid, FIFF.FIFF_NAME, proj['desc']) - write_int(fid, FIFF.FIFF_PROJ_ITEM_KIND, proj['kind']) - if proj['kind'] == FIFF.FIFFV_PROJ_ITEM_FIELD: + fid, FIFF.FIFF_PROJ_ITEM_CH_NAME_LIST, names, "col_names" + ) + write_string(fid, FIFF.FIFF_NAME, proj["desc"]) + write_int(fid, FIFF.FIFF_PROJ_ITEM_KIND, proj["kind"]) + if proj["kind"] == FIFF.FIFFV_PROJ_ITEM_FIELD: write_float(fid, FIFF.FIFF_PROJ_ITEM_TIME, 0.0) - write_int(fid, FIFF.FIFF_PROJ_ITEM_NVEC, proj['data']['nrow']) - write_int(fid, FIFF.FIFF_MNE_PROJ_ITEM_ACTIVE, proj['active']) - write_float_matrix(fid, FIFF.FIFF_PROJ_ITEM_VECTORS, - proj['data']['data']) - if proj['explained_var'] is not None: - write_float(fid, FIFF.FIFF_MNE_ICA_PCA_EXPLAINED_VAR, - proj['explained_var']) + write_int(fid, FIFF.FIFF_PROJ_ITEM_NVEC, proj["data"]["nrow"]) + write_int(fid, FIFF.FIFF_MNE_PROJ_ITEM_ACTIVE, proj["active"]) + write_float_matrix(fid, FIFF.FIFF_PROJ_ITEM_VECTORS, proj["data"]["data"]) + if proj["explained_var"] is not None: + write_float(fid, FIFF.FIFF_MNE_ICA_PCA_EXPLAINED_VAR, proj["explained_var"]) end_block(fid, FIFF.FIFFB_PROJ_ITEM) end_block(fid, FIFF.FIFFB_PROJ) @@ -602,16 +712,17 @@ def _write_proj(fid, projs, *, ch_names_mapping=None): ############################################################################### # Utils + def _check_projs(projs, copy=True): """Check that projs is a list of Projection.""" if not 
isinstance(projs, (list, tuple)): - raise TypeError('projs must be a list or tuple, got %s' - % (type(projs),)) + raise TypeError("projs must be a list or tuple, got %s" % (type(projs),)) for pi, p in enumerate(projs): if not isinstance(p, Projection): - raise TypeError('All entries in projs list must be Projection ' - 'instances, but projs[%d] is type %s' - % (pi, type(p))) + raise TypeError( + "All entries in projs list must be Projection " + "instances, but projs[%d] is type %s" % (pi, type(p)) + ) return deepcopy(projs) if copy else projs @@ -644,8 +755,7 @@ def make_projector(projs, ch_names, bads=(), include_active=True): return _make_projector(projs, ch_names, bads, include_active) -def _make_projector(projs, ch_names, bads=(), include_active=True, - inplace=False): +def _make_projector(projs, ch_names, bads=(), include_active=True, inplace=False): """Subselect projs based on ch_names and bads. Use inplace=True mode to modify ``projs`` inplace so that no @@ -653,9 +763,10 @@ def _make_projector(projs, ch_names, bads=(), include_active=True, the given inputs. If inplace=True, no meaningful data are returned. """ from scipy import linalg + nchan = len(ch_names) if nchan == 0: - raise ValueError('No channel names specified') + raise ValueError("No channel names specified") default_return = (np.eye(nchan, nchan), 0, np.empty((nchan, 0))) @@ -666,9 +777,9 @@ def _make_projector(projs, ch_names, bads=(), include_active=True, nvec = 0 nproj = 0 for p in projs: - if not p['active'] or include_active: + if not p["active"] or include_active: nproj += 1 - nvec += p['data']['nrow'] + nvec += p["data"]["nrow"] if nproj == 0: return default_return @@ -679,65 +790,71 @@ def _make_projector(projs, ch_names, bads=(), include_active=True, nonzero = 0 bads = set(bads) for k, p in enumerate(projs): - if not p['active'] or include_active: - if (len(p['data']['col_names']) != - len(np.unique(p['data']['col_names']))): - raise ValueError('Channel name list in projection item %d' - ' contains duplicate items' % k) + if not p["active"] or include_active: + if len(p["data"]["col_names"]) != len(np.unique(p["data"]["col_names"])): + raise ValueError( + "Channel name list in projection item %d" + " contains duplicate items" % k + ) # Get the two selection vectors to pick correct elements from # the projection vectors omitting bad channels sel = [] vecsel = [] - p_set = set(p['data']['col_names']) # faster membership access + p_set = set(p["data"]["col_names"]) # faster membership access for c, name in enumerate(ch_names): if name not in bads and name in p_set: sel.append(c) - vecsel.append(p['data']['col_names'].index(name)) + vecsel.append(p["data"]["col_names"].index(name)) # If there is something to pick, pickit - nrow = p['data']['nrow'] - this_vecs = vecs[:, nvec:nvec + nrow] + nrow = p["data"]["nrow"] + this_vecs = vecs[:, nvec : nvec + nrow] if len(sel) > 0: - this_vecs[sel] = p['data']['data'][:, vecsel].T + this_vecs[sel] = p["data"]["data"][:, vecsel].T # Rescale for better detection of small singular values - for v in range(p['data']['nrow']): + for v in range(p["data"]["nrow"]): psize = np.linalg.norm(this_vecs[:, v]) if psize > 0: - orig_n = p['data']['data'].any(axis=0).sum() + orig_n = p["data"]["data"].any(axis=0).sum() # Average ref still works if channels are removed # Use relative power to determine if we're in trouble. # 10% loss is hopefully a reasonable threshold. 
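(Aside, not part of the patch: the 10% rule in the comment above is easy to see numerically. Each projection vector is stored with unit norm, so the norm of its restriction to the kept channels is exactly the fraction of the vector that survives subselection. A standalone sketch with an invented 10-channel vector:)

import numpy as np

n_orig, n_kept = 10, 7
vec = np.ones(n_orig) / np.sqrt(n_orig)  # unit-norm projection vector
psize = np.linalg.norm(vec[:n_kept])     # magnitude left after dropping 3 channels
print(round(psize, 3))  # 0.837, below the 0.9 cutoff, so the warning fires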
- if psize < 0.9 and not inplace and \ - (p['kind'] != FIFF.FIFFV_PROJ_ITEM_EEG_AVREF or - len(vecsel) == 1): + if ( + psize < 0.9 + and not inplace + and ( + p["kind"] != FIFF.FIFFV_PROJ_ITEM_EEG_AVREF + or len(vecsel) == 1 + ) + ): warn( f'Projection vector {repr(p["desc"])} has been ' - f'reduced to {100 * psize:0.2f}% of its ' - 'original magnitude by subselecting ' - f'{len(vecsel)}/{orig_n} of the original ' - 'channels. If the ignored channels were bad ' - 'during SSP computation, we recommend ' - 'recomputing proj (via compute_proj_raw ' - 'or related functions) with the bad channels ' - 'properly marked, because computing SSP with bad ' - 'channels present in the data but unmarked is ' - 'dangerous (it can bias the PCA used by SSP). ' - 'On the other hand, if you know that all channels ' - 'were good during SSP computation, you can safely ' - 'use info.normalize_proj() to suppress this ' - 'warning during projection.') + f"reduced to {100 * psize:0.2f}% of its " + "original magnitude by subselecting " + f"{len(vecsel)}/{orig_n} of the original " + "channels. If the ignored channels were bad " + "during SSP computation, we recommend " + "recomputing proj (via compute_proj_raw " + "or related functions) with the bad channels " + "properly marked, because computing SSP with bad " + "channels present in the data but unmarked is " + "dangerous (it can bias the PCA used by SSP). " + "On the other hand, if you know that all channels " + "were good during SSP computation, you can safely " + "use info.normalize_proj() to suppress this " + "warning during projection." + ) this_vecs[:, v] /= psize nonzero += 1 # If doing "inplace" mode, "fix" the projectors to only operate # on this subset of channels. if inplace: - p['data']['data'] = this_vecs[sel].T - p['data']['col_names'] = [p['data']['col_names'][ii] - for ii in vecsel] - p['data']['ncol'] = len(p['data']['col_names']) - nvec += p['data']['nrow'] + p["data"]["data"] = this_vecs[sel].T + p["data"]["col_names"] = [p["data"]["col_names"][ii] for ii in vecsel] + p["data"]["ncol"] = len(p["data"]["col_names"]) + nvec += p["data"]["nrow"] # Check whether all of the vectors are exactly zero if nonzero == 0 or inplace: @@ -753,8 +870,10 @@ def _make_projector(projs, ch_names, bads=(), include_active=True, # Here is the celebrated result proj = np.eye(nchan, nchan) - np.dot(U, U.T) if nproj >= nchan: # e.g., 3 channels and 3 projectors - raise RuntimeError('Application of %d projectors for %d channels ' - 'will yield no components.' % (nproj, nchan)) + raise RuntimeError( + "Application of %d projectors for %d channels " + "will yield no components." % (nproj, nchan) + ) return proj, nproj, U @@ -767,8 +886,13 @@ def _normalize_proj(info): with picks. """ # Here we do info.get b/c info can actually be a noise cov - _make_projector(info['projs'], info.get('ch_names', info.get('names')), - info['bads'], include_active=True, inplace=True) + _make_projector( + info["projs"], + info.get("ch_names", info.get("names")), + info["bads"], + include_active=True, + inplace=True, + ) @fill_doc @@ -790,8 +914,9 @@ def make_projector_info(info, include_active=True): nproj : int How many items in the projector. 
""" - proj, nproj, _ = make_projector(info['projs'], info['ch_names'], - info['bads'], include_active) + proj, nproj, _ = make_projector( + info["projs"], info["ch_names"], info["bads"], include_active + ) return proj, nproj @@ -819,9 +944,9 @@ def activate_proj(projs, copy=True, verbose=None): # Activate the projection items for proj in projs: - proj['active'] = True + proj["active"] = True - logger.info('%d projection items activated' % len(projs)) + logger.info("%d projection items activated" % len(projs)) return projs @@ -850,9 +975,9 @@ def deactivate_proj(projs, copy=True, verbose=None): # Deactivate the projection items for proj in projs: - proj['active'] = False + proj["active"] = False - logger.info('%d projection items deactivated' % len(projs)) + logger.info("%d projection items deactivated" % len(projs)) return projs @@ -862,8 +987,7 @@ def deactivate_proj(projs, copy=True, verbose=None): @verbose -def make_eeg_average_ref_proj(info, activate=True, *, ch_type='eeg', - verbose=None): +def make_eeg_average_ref_proj(info, activate=True, *, ch_type="eeg", verbose=None): """Create an EEG average reference SSP projection vector. Parameters @@ -883,82 +1007,94 @@ def make_eeg_average_ref_proj(info, activate=True, *, ch_type='eeg', proj: instance of Projection The SSP/PCA projector. """ - if info.get('custom_ref_applied', False): - raise RuntimeError('A custom reference has been applied to the ' - 'data earlier. Please use the ' - 'mne.io.set_eeg_reference function to move from ' - 'one EEG reference to another.') - - _validate_type(ch_type, (list, tuple, str), 'ch_type') + if info.get("custom_ref_applied", False): + raise RuntimeError( + "A custom reference has been applied to the " + "data earlier. Please use the " + "mne.io.set_eeg_reference function to move from " + "one EEG reference to another." 
+ ) + + _validate_type(ch_type, (list, tuple, str), "ch_type") singleton = False if isinstance(ch_type, str): ch_type = [ch_type] singleton = True for ci, this_ch_type in enumerate(ch_type): - _check_option('ch_type' + ('' if singleton else f'[{ci}]'), - this_ch_type, list(_EEG_AVREF_PICK_DICT)) + _check_option( + "ch_type" + ("" if singleton else f"[{ci}]"), + this_ch_type, + list(_EEG_AVREF_PICK_DICT), + ) - ch_type_name = '/'.join(c.upper() for c in ch_type) + ch_type_name = "/".join(c.upper() for c in ch_type) logger.info(f"Adding average {ch_type_name} reference projection.") ch_dict = {c: True for c in ch_type} for c in ch_type: - one_picks = pick_types(info, exclude='bads', **{c: True}) + one_picks = pick_types(info, exclude="bads", **{c: True}) if len(one_picks) == 0: - raise ValueError(f'Cannot create {ch_type_name} average reference ' - f'projector (no {c.upper()} data found)') + raise ValueError( + f"Cannot create {ch_type_name} average reference " + f"projector (no {c.upper()} data found)" + ) del ch_type - ch_sel = pick_types(info, **ch_dict, exclude='bads') - ch_names = info['ch_names'] + ch_sel = pick_types(info, **ch_dict, exclude="bads") + ch_names = info["ch_names"] ch_names = [ch_names[k] for k in ch_sel] n_chs = len(ch_sel) vec = np.ones((1, n_chs)) vec /= np.sqrt(n_chs) explained_var = None - proj_data = dict(col_names=ch_names, row_names=None, - data=vec, nrow=1, ncol=n_chs) + proj_data = dict(col_names=ch_names, row_names=None, data=vec, nrow=1, ncol=n_chs) proj = Projection( - active=activate, data=proj_data, explained_var=explained_var, - desc=f'Average {ch_type_name} reference', - kind=FIFF.FIFFV_PROJ_ITEM_EEG_AVREF) + active=activate, + data=proj_data, + explained_var=explained_var, + desc=f"Average {ch_type_name} reference", + kind=FIFF.FIFFV_PROJ_ITEM_EEG_AVREF, + ) return proj @verbose def _has_eeg_average_ref_proj( - info, *, projs=None, check_active=False, ch_type=None, verbose=None): + info, *, projs=None, check_active=False, ch_type=None, verbose=None +): """Determine if a list of projectors has an average EEG ref. Optionally, set check_active=True to additionally check if the CAR has already been applied. 
""" from .meas_info import Info - _validate_type(info, Info, 'info') - projs = info.get('projs', []) if projs is None else projs + + _validate_type(info, Info, "info") + projs = info.get("projs", []) if projs is None else projs if ch_type is None: pick_kwargs = _EEG_AVREF_PICK_DICT else: ch_type = [ch_type] if isinstance(ch_type, str) else ch_type pick_kwargs = {ch_type: True for ch_type in ch_type} - ch_type = '/'.join(c.upper() for c in pick_kwargs) + ch_type = "/".join(c.upper() for c in pick_kwargs) want_names = [ - info['ch_names'][pick] for pick in pick_types( - info, exclude='bads', **pick_kwargs)] + info["ch_names"][pick] + for pick in pick_types(info, exclude="bads", **pick_kwargs) + ] if not want_names: return False found_names = list() for proj in projs: - if (proj['kind'] == FIFF.FIFFV_PROJ_ITEM_EEG_AVREF or - re.match('^Average .* reference$', proj['desc'])): - if not check_active or proj['active']: - found_names.extend(proj['data']['col_names']) + if proj["kind"] == FIFF.FIFFV_PROJ_ITEM_EEG_AVREF or re.match( + "^Average .* reference$", proj["desc"] + ): + if not check_active or proj["active"]: + found_names.extend(proj["data"]["col_names"]) # If some are missing we have a problem (keep order for the message, # otherwise we could use set logic) missing = [name for name in want_names if name not in found_names] if missing: if found_names: # found some but not all: warn - warn(f'Incomplete {ch_type} projector, ' - f'missing channel(s) {missing}') + warn(f"Incomplete {ch_type} projector, " f"missing channel(s) {missing}") return False return True @@ -969,7 +1105,7 @@ def _needs_eeg_average_ref_proj(info): This returns True if no custom reference has been applied and no average reference projection is present in the list of projections. """ - if info['custom_ref_applied']: + if info["custom_ref_applied"]: return False if not _electrode_types(info): return False @@ -979,8 +1115,9 @@ def _needs_eeg_average_ref_proj(info): @verbose -def setup_proj(info, add_eeg_ref=True, activate=True, *, eeg_ref_ch_type='eeg', - verbose=None): +def setup_proj( + info, add_eeg_ref=True, activate=True, *, eeg_ref_ch_type="eeg", verbose=None +): """Set up projection for Raw and Epochs. 
Parameters @@ -1008,24 +1145,23 @@ def setup_proj(info, add_eeg_ref=True, activate=True, *, eeg_ref_ch_type='eeg', # Add EEG ref reference proj if necessary if add_eeg_ref and _needs_eeg_average_ref_proj(info): eeg_proj = make_eeg_average_ref_proj( - info, activate=activate, ch_type=eeg_ref_ch_type) - info['projs'].append(eeg_proj) + info, activate=activate, ch_type=eeg_ref_ch_type + ) + info["projs"].append(eeg_proj) # Create the projector projector, nproj = make_projector_info(info) if nproj == 0: if verbose: - logger.info('The projection vectors do not apply to these ' - 'channels') + logger.info("The projection vectors do not apply to these " "channels") projector = None else: - logger.info('Created an SSP operator (subspace dimension = %d)' - % nproj) + logger.info("Created an SSP operator (subspace dimension = %d)" % nproj) # The projection items have been activated if activate: with info._unlock(): - info['projs'] = activate_proj(info['projs'], copy=False) + info["projs"] = activate_proj(info["projs"], copy=False) return projector, info @@ -1041,11 +1177,11 @@ def _uniquify_projs(projs, check_active=True, sort=True): def sorter(x): """Sort in a nice way.""" - digits = [s for s in x['desc'] if s.isdigit()] + digits = [s for s in x["desc"] if s.isdigit()] if digits: sort_idx = int(digits[-1]) else: sort_idx = next(my_count) - return (sort_idx, x['desc']) + return (sort_idx, x["desc"]) return sorted(final_projs, key=sorter) if sort else final_projs diff --git a/mne/io/reference.py b/mne/io/reference.py index f62c5637140..a948313ec62 100644 --- a/mne/io/reference.py +++ b/mne/io/reference.py @@ -10,14 +10,21 @@ from .meas_info import _check_ch_keys from .proj import _has_eeg_average_ref_proj, make_eeg_average_ref_proj from .proj import setup_proj -from .pick import (pick_types, pick_channels, pick_channels_forward, - _ELECTRODE_CH_TYPES) +from .pick import pick_types, pick_channels, pick_channels_forward, _ELECTRODE_CH_TYPES from .base import BaseRaw from ..evoked import Evoked from ..epochs import BaseEpochs from ..fixes import pinv -from ..utils import (logger, warn, verbose, _validate_type, _check_preload, - _check_option, fill_doc, _on_missing) +from ..utils import ( + logger, + warn, + verbose, + _validate_type, + _check_preload, + _check_option, + fill_doc, + _on_missing, +) from ..defaults import DEFAULTS @@ -54,62 +61,63 @@ def _check_before_reference(inst, ref_from, ref_to, ch_type): _check_preload(inst, "Applying a reference") ch_type = _get_ch_type(inst, ch_type) - ch_dict = {**{type_: True for type_ in ch_type}, - 'meg': False, 'ref_meg': False} + ch_dict = {**{type_: True for type_ in ch_type}, "meg": False, "ref_meg": False} eeg_idx = pick_types(inst.info, **ch_dict) if ref_to is None: ref_to = [inst.ch_names[i] for i in eeg_idx] - extra = 'EEG channels found' + extra = "EEG channels found" else: - extra = 'channels supplied' + extra = "channels supplied" if len(ref_to) == 0: - raise ValueError('No %s to apply the reference to' % (extra,)) + raise ValueError("No %s to apply the reference to" % (extra,)) # After referencing, existing SSPs might not be valid anymore. 
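(Aside, not part of the patch: the loop that follows collects the indices of existing average-reference projectors and deletes them before the new reference is applied. A mocked standalone sketch of that selection; real entries are mne Projection objects, and the kind test uses FIFF.FIFFV_PROJ_ITEM_EEG_AVREF rather than the literal 10 assumed here:)

projs = [
    {"desc": "Average EEG reference", "kind": 10},  # 10 stands in for the avref kind
    {"desc": "PCA-v1", "kind": 1},
]
to_remove = [i for i, p in enumerate(projs)
             if p["desc"] == "Average EEG reference" or p["kind"] == 10]
for i in reversed(to_remove):  # delete back to front so indices stay valid
    del projs[i]
print([p["desc"] for p in projs])  # ['PCA-v1']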
projs_to_remove = [] - for i, proj in enumerate(inst.info['projs']): + for i, proj in enumerate(inst.info["projs"]): # Remove any average reference projections - if proj['desc'] == 'Average EEG reference' or \ - proj['kind'] == FIFF.FIFFV_PROJ_ITEM_EEG_AVREF: - logger.info('Removing existing average EEG reference ' - 'projection.') + if ( + proj["desc"] == "Average EEG reference" + or proj["kind"] == FIFF.FIFFV_PROJ_ITEM_EEG_AVREF + ): + logger.info("Removing existing average EEG reference " "projection.") # Don't remove the projection right away, but do this at the end of # this loop. projs_to_remove.append(i) # Inactive SSPs may block re-referencing - elif (not proj['active'] and - len([ch for ch in (ref_from + ref_to) - if ch in proj['data']['col_names']]) > 0): - + elif ( + not proj["active"] + and len( + [ch for ch in (ref_from + ref_to) if ch in proj["data"]["col_names"]] + ) + > 0 + ): raise RuntimeError( - 'Inactive signal space projection (SSP) operators are ' - 'present that operate on sensors involved in the desired ' - 'referencing scheme. These projectors need to be applied ' - 'using the apply_proj() method function before the desired ' - 'reference can be set.' + "Inactive signal space projection (SSP) operators are " + "present that operate on sensors involved in the desired " + "referencing scheme. These projectors need to be applied " + "using the apply_proj() method function before the desired " + "reference can be set." ) for i in projs_to_remove: - del inst.info['projs'][i] + del inst.info["projs"][i] # Need to call setup_proj after changing the projs: - inst._projector, _ = \ - setup_proj(inst.info, add_eeg_ref=False, activate=False) + inst._projector, _ = setup_proj(inst.info, add_eeg_ref=False, activate=False) # If the reference touches EEG/ECoG/sEEG/DBS electrodes, note in the # info that a non-CAR has been applied. ref_to_channels = pick_channels(inst.ch_names, ref_to, ordered=True) if len(np.intersect1d(ref_to_channels, eeg_idx)) > 0: with inst.info._unlock(): - inst.info['custom_ref_applied'] = FIFF.FIFFV_MNE_CUSTOM_REF_ON + inst.info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_ON return ref_to -def _apply_reference(inst, ref_from, ref_to=None, forward=None, - ch_type='auto'): +def _apply_reference(inst, ref_from, ref_to=None, forward=None, ch_type="auto"): """Apply a custom EEG referencing scheme.""" ref_to = _check_before_reference(inst, ref_from, ref_to, ch_type) @@ -133,7 +141,7 @@ def _apply_reference(inst, ref_from, ref_to=None, forward=None, # use ch_sel and the given forward forward = pick_channels_forward(forward, ref_names, ordered=True) # 1-3. Compute a forward (G) and avg-ref'ed data (done above) - G = forward['sol']['data'] + G = forward["sol"]["data"] assert G.shape[0] == len(ref_names) # 4. Compute the forward (G) and average-reference it (Ga): Ga = G - np.mean(G, axis=0, keepdims=True) @@ -173,17 +181,17 @@ def add_reference_channels(inst, ref_channels, copy=True): Data with added EEG reference channels. """ # Check to see that data is preloaded - _check_preload(inst, 'add_reference_channels') - _validate_type(ref_channels, (list, tuple, str), 'ref_channels') + _check_preload(inst, "add_reference_channels") + _validate_type(ref_channels, (list, tuple, str), "ref_channels") if isinstance(ref_channels, str): ref_channels = [ref_channels] for ch in ref_channels: - if ch in inst.info['ch_names']: + if ch in inst.info["ch_names"]: raise ValueError("Channel %s already specified in inst." 
% ch) # Once CAR is applied (active), don't allow adding channels if _has_eeg_average_ref_proj(inst.info, check_active=True): - raise RuntimeError('Average reference already applied to data.') + raise RuntimeError("Average reference already applied to data.") if copy: inst = inst.copy() @@ -197,87 +205,104 @@ def add_reference_channels(inst, ref_channels, copy=True): data = inst._data x, y, z = data.shape refs = np.zeros((x * len(ref_channels), z)) - data = np.vstack((data.reshape((x * y, z), order='F'), refs)) - data = data.reshape(x, y + len(ref_channels), z, order='F') + data = np.vstack((data.reshape((x * y, z), order="F"), refs)) + data = data.reshape(x, y + len(ref_channels), z, order="F") inst._data = data else: - raise TypeError("inst should be Raw, Epochs, or Evoked instead of %s." - % type(inst)) - nchan = len(inst.info['ch_names']) + raise TypeError( + "inst should be Raw, Epochs, or Evoked instead of %s." % type(inst) + ) + nchan = len(inst.info["ch_names"]) # only do this if we actually have digitisation points - if inst.info.get('dig', None) is not None: + if inst.info.get("dig", None) is not None: # "zeroth" EEG electrode dig points is reference - ref_dig_loc = [dl for dl in inst.info['dig'] if ( - dl['kind'] == FIFF.FIFFV_POINT_EEG and - dl['ident'] == 0)] + ref_dig_loc = [ + dl + for dl in inst.info["dig"] + if (dl["kind"] == FIFF.FIFFV_POINT_EEG and dl["ident"] == 0) + ] if len(ref_channels) > 1 or len(ref_dig_loc) != len(ref_channels): ref_dig_array = np.full(12, np.nan) - warn('The locations of multiple reference channels are ignored.') + warn("The locations of multiple reference channels are ignored.") else: # n_ref_channels == 1 and a single ref digitization exists - ref_dig_array = np.concatenate((ref_dig_loc[0]['r'], - ref_dig_loc[0]['r'], np.zeros(6))) + ref_dig_array = np.concatenate( + (ref_dig_loc[0]["r"], ref_dig_loc[0]["r"], np.zeros(6)) + ) # Replace the (possibly new) Ref location for each channel for idx in pick_types(inst.info, meg=False, eeg=True, exclude=[]): - inst.info['chs'][idx]['loc'][3:6] = ref_dig_loc[0]['r'] + inst.info["chs"][idx]["loc"][3:6] = ref_dig_loc[0]["r"] else: # Ideally we'd fall back on getting the location from a montage, but # locations for non-present channels aren't stored, so location is # unknown. Users can call set_montage() again if needed. ref_dig_array = np.full(12, np.nan) - logger.info('Location for this channel is unknown; consider calling ' - 'set_montage() again if needed.') + logger.info( + "Location for this channel is unknown; consider calling " + "set_montage() again if needed." 
+ ) for ch in ref_channels: - chan_info = {'ch_name': ch, - 'coil_type': FIFF.FIFFV_COIL_EEG, - 'kind': FIFF.FIFFV_EEG_CH, - 'logno': nchan + 1, - 'scanno': nchan + 1, - 'cal': 1, - 'range': 1., - 'unit_mul': 0., - 'unit': FIFF.FIFF_UNIT_V, - 'coord_frame': FIFF.FIFFV_COORD_HEAD, - 'loc': ref_dig_array} - inst.info['chs'].append(chan_info) + chan_info = { + "ch_name": ch, + "coil_type": FIFF.FIFFV_COIL_EEG, + "kind": FIFF.FIFFV_EEG_CH, + "logno": nchan + 1, + "scanno": nchan + 1, + "cal": 1, + "range": 1.0, + "unit_mul": 0.0, + "unit": FIFF.FIFF_UNIT_V, + "coord_frame": FIFF.FIFFV_COORD_HEAD, + "loc": ref_dig_array, + } + inst.info["chs"].append(chan_info) inst.info._update_redundant() range_ = np.arange(1, len(ref_channels) + 1) if isinstance(inst, BaseRaw): inst._cals = np.hstack((inst._cals, [1] * len(ref_channels))) for pi, picks in enumerate(inst._read_picks): - inst._read_picks[pi] = np.concatenate( - [picks, np.max(picks) + range_]) + inst._read_picks[pi] = np.concatenate([picks, np.max(picks) + range_]) elif isinstance(inst, BaseEpochs): picks = inst.picks - inst.picks = np.concatenate( - [picks, np.max(picks) + range_]) + inst.picks = np.concatenate([picks, np.max(picks) + range_]) inst.info._check_consistency() - set_eeg_reference(inst, ref_channels=ref_channels, copy=False, - verbose=False) + set_eeg_reference(inst, ref_channels=ref_channels, copy=False, verbose=False) return inst _ref_dict = { - FIFF.FIFFV_MNE_CUSTOM_REF_ON: 'on', - FIFF.FIFFV_MNE_CUSTOM_REF_OFF: 'off', - FIFF.FIFFV_MNE_CUSTOM_REF_CSD: 'CSD', + FIFF.FIFFV_MNE_CUSTOM_REF_ON: "on", + FIFF.FIFFV_MNE_CUSTOM_REF_OFF: "off", + FIFF.FIFFV_MNE_CUSTOM_REF_CSD: "CSD", } def _check_can_reref(inst): _validate_type(inst, (BaseRaw, BaseEpochs, Evoked), "Instance") - current_custom = inst.info['custom_ref_applied'] - if current_custom not in (FIFF.FIFFV_MNE_CUSTOM_REF_ON, - FIFF.FIFFV_MNE_CUSTOM_REF_OFF): - raise RuntimeError('Cannot set new reference on data with custom ' - 'reference type %r' % (_ref_dict[current_custom],)) + current_custom = inst.info["custom_ref_applied"] + if current_custom not in ( + FIFF.FIFFV_MNE_CUSTOM_REF_ON, + FIFF.FIFFV_MNE_CUSTOM_REF_OFF, + ): + raise RuntimeError( + "Cannot set new reference on data with custom " + "reference type %r" % (_ref_dict[current_custom],) + ) @verbose -def set_eeg_reference(inst, ref_channels='average', copy=True, - projection=False, ch_type='auto', forward=None, - *, joint=False, verbose=None): +def set_eeg_reference( + inst, + ref_channels="average", + copy=True, + projection=False, + ch_type="auto", + forward=None, + *, + joint=False, + verbose=None, +): """Specify which reference to use for EEG data. Use this function to explicitly specify the desired reference for EEG. @@ -315,111 +340,128 @@ def set_eeg_reference(inst, ref_channels='average', copy=True, %(set_eeg_reference_see_also_notes)s """ from ..forward import Forward + _check_can_reref(inst) ch_type = _get_ch_type(inst, ch_type) if projection: # average reference projector - if ref_channels != 'average': - raise ValueError('Setting projection=True is only supported for ' - 'ref_channels="average", got %r.' - % (ref_channels,)) + if ref_channels != "average": + raise ValueError( + "Setting projection=True is only supported for " + 'ref_channels="average", got %r.' % (ref_channels,) + ) # We need verbose='error' here in case we add projs sequentially - if _has_eeg_average_ref_proj( - inst.info, ch_type=ch_type, verbose='error'): - warn('An average reference projection was already added. 
The data ' - 'has been left untouched.') + if _has_eeg_average_ref_proj(inst.info, ch_type=ch_type, verbose="error"): + warn( + "An average reference projection was already added. The data " + "has been left untouched." + ) else: # Creating an average reference may fail. In this case, make # sure that the custom_ref_applied flag is left untouched. - custom_ref_applied = inst.info['custom_ref_applied'] + custom_ref_applied = inst.info["custom_ref_applied"] try: with inst.info._unlock(): - inst.info['custom_ref_applied'] = \ - FIFF.FIFFV_MNE_CUSTOM_REF_OFF + inst.info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_OFF if joint: inst.add_proj( make_eeg_average_ref_proj( - inst.info, ch_type=ch_type, activate=False)) + inst.info, ch_type=ch_type, activate=False + ) + ) else: for this_ch_type in ch_type: inst.add_proj( make_eeg_average_ref_proj( - inst.info, ch_type=this_ch_type, - activate=False)) + inst.info, ch_type=this_ch_type, activate=False + ) + ) except Exception: with inst.info._unlock(): - inst.info['custom_ref_applied'] = custom_ref_applied + inst.info["custom_ref_applied"] = custom_ref_applied raise # If the data has been preloaded, projections will no # longer be automatically applied. if inst.preload: - logger.info('Average reference projection was added, ' - 'but has not been applied yet. Use the ' - 'apply_proj method to apply it.') + logger.info( + "Average reference projection was added, " + "but has not been applied yet. Use the " + "apply_proj method to apply it." + ) return inst, None del projection # not used anymore inst = inst.copy() if copy else inst - ch_dict = {**{type_: True for type_ in ch_type}, - 'meg': False, 'ref_meg': False} + ch_dict = {**{type_: True for type_ in ch_type}, "meg": False, "ref_meg": False} ch_sel = [inst.ch_names[i] for i in pick_types(inst.info, **ch_dict)] - if ref_channels == 'REST': + if ref_channels == "REST": _validate_type(forward, Forward, 'forward when ref_channels="REST"') else: forward = None # signal to _apply_reference not to do REST - if ref_channels in ('average', 'REST'): - logger.info(f'Applying {ref_channels} reference.') + if ref_channels in ("average", "REST"): + logger.info(f"Applying {ref_channels} reference.") ref_channels = ch_sel if ref_channels == []: - logger.info('EEG data marked as already having the desired reference.') + logger.info("EEG data marked as already having the desired reference.") else: logger.info( - 'Applying a custom ' + "Applying a custom " f"{tuple(DEFAULTS['titles'][type_] for type_ in ch_type)} " - 'reference.') + "reference." 
+ ) - return _apply_reference(inst, ref_channels, ch_sel, forward, - ch_type=ch_type) + return _apply_reference(inst, ref_channels, ch_sel, forward, ch_type=ch_type) def _get_ch_type(inst, ch_type): - _validate_type(ch_type, (str, list, tuple), 'ch_type') - valid_ch_types = ('auto',) + _ELECTRODE_CH_TYPES + _validate_type(ch_type, (str, list, tuple), "ch_type") + valid_ch_types = ("auto",) + _ELECTRODE_CH_TYPES if isinstance(ch_type, str): - _check_option('ch_type', ch_type, valid_ch_types) - if ch_type != 'auto': + _check_option("ch_type", ch_type, valid_ch_types) + if ch_type != "auto": ch_type = [ch_type] elif isinstance(ch_type, (list, tuple)): for type_ in ch_type: - _validate_type(type_, str, 'ch_type') - _check_option('ch_type', type_, valid_ch_types[1:]) + _validate_type(type_, str, "ch_type") + _check_option("ch_type", type_, valid_ch_types[1:]) ch_type = list(ch_type) # if ch_type is 'auto', search through list to find first reasonable # reference-able channel type. - if ch_type == 'auto': + if ch_type == "auto": for type_ in _ELECTRODE_CH_TYPES: if type_ in inst: ch_type = [type_] - logger.info('%s channel type selected for ' - 're-referencing' % DEFAULTS['titles'][type_]) + logger.info( + "%s channel type selected for " + "re-referencing" % DEFAULTS["titles"][type_] + ) break # if auto comes up empty, or the user specifies a bad ch_type. else: - raise ValueError('No EEG, ECoG, sEEG or DBS channels found ' - 'to rereference.') + raise ValueError( + "No EEG, ECoG, sEEG or DBS channels found " "to rereference." + ) return ch_type @verbose -def set_bipolar_reference(inst, anode, cathode, ch_name=None, ch_info=None, - drop_refs=True, copy=True, on_bad="warn", - verbose=None): +def set_bipolar_reference( + inst, + anode, + cathode, + ch_name=None, + ch_info=None, + drop_refs=True, + copy=True, + on_bad="warn", + verbose=None, +): """Re-reference selected channels using a bipolar referencing scheme. A bipolar reference takes the difference between two channels (the anode @@ -496,38 +538,47 @@ def set_bipolar_reference(inst, anode, cathode, ch_name=None, ch_info=None, cathode = [cathode] if len(anode) != len(cathode): - raise ValueError('Number of anodes (got %d) must equal the number ' - 'of cathodes (got %d).' % (len(anode), len(cathode))) + raise ValueError( + "Number of anodes (got %d) must equal the number " + "of cathodes (got %d)." % (len(anode), len(cathode)) + ) if ch_name is None: - ch_name = [f'{a}-{c}' for (a, c) in zip(anode, cathode)] + ch_name = [f"{a}-{c}" for (a, c) in zip(anode, cathode)] elif not isinstance(ch_name, list): ch_name = [ch_name] if len(ch_name) != len(anode): - raise ValueError('Number of channel names must equal the number of ' - 'anodes/cathodes (got %d).' % len(ch_name)) + raise ValueError( + "Number of channel names must equal the number of " + "anodes/cathodes (got %d)." % len(ch_name) + ) # Check for duplicate channel names (it is allowed to give the name of the # anode or cathode channel, as they will be replaced). for ch, a, c in zip(ch_name, anode, cathode): if ch not in [a, c] and ch in inst.ch_names: - raise ValueError('There is already a channel named "%s", please ' - 'specify a different name for the bipolar ' - 'channel using the ch_name parameter.' % ch) + raise ValueError( + 'There is already a channel named "%s", please ' + "specify a different name for the bipolar " + "channel using the ch_name parameter." 
% ch + ) if ch_info is None: ch_info = [{} for _ in anode] elif not isinstance(ch_info, list): ch_info = [ch_info] if len(ch_info) != len(anode): - raise ValueError('Number of channel info dictionaries must equal the ' - 'number of anodes/cathodes.') + raise ValueError( + "Number of channel info dictionaries must equal the " + "number of anodes/cathodes." + ) if copy: inst = inst.copy() - anode = _check_before_reference(inst, ref_from=cathode, - ref_to=anode, ch_type='auto') + anode = _check_before_reference( + inst, ref_from=cathode, ref_to=anode, ch_type="auto" + ) # Create bipolar reference channels by multiplying the data # (channels x time) with a matrix (n_virtual_channels x channels) @@ -537,25 +588,30 @@ def set_bipolar_reference(inst, anode, cathode, ch_name=None, ch_info=None, multiplier[idx, inst.ch_names.index(a)] = 1 multiplier[idx, inst.ch_names.index(c)] = -1 - ref_info = create_info(ch_names=ch_name, sfreq=inst.info['sfreq'], - ch_types=inst.get_channel_types(picks=anode)) + ref_info = create_info( + ch_names=ch_name, + sfreq=inst.info["sfreq"], + ch_types=inst.get_channel_types(picks=anode), + ) # Update "chs" in Reference-Info. for ch_idx, (an, info) in enumerate(zip(anode, ch_info)): - _check_ch_keys(info, ch_idx, name='ch_info', check_min=False) + _check_ch_keys(info, ch_idx, name="ch_info", check_min=False) an_idx = inst.ch_names.index(an) # Copy everything from anode (except ch_name). - an_chs = {k: v for k, v in inst.info['chs'][an_idx].items() - if k != 'ch_name'} - ref_info['chs'][ch_idx].update(an_chs) + an_chs = {k: v for k, v in inst.info["chs"][an_idx].items() if k != "ch_name"} + ref_info["chs"][ch_idx].update(an_chs) # Set coil-type to bipolar. - ref_info['chs'][ch_idx]['coil_type'] = FIFF.FIFFV_COIL_EEG_BIPOLAR + ref_info["chs"][ch_idx]["coil_type"] = FIFF.FIFFV_COIL_EEG_BIPOLAR # Update with info from ch_info-parameter. - ref_info['chs'][ch_idx].update(info) + ref_info["chs"][ch_idx].update(info) # Set other info-keys from original instance. - pick_info = {k: v for k, v in inst.info.items() if k not in - ['chs', 'ch_names', 'bads', 'nchan', 'sfreq']} + pick_info = { + k: v + for k, v in inst.info.items() + if k not in ["chs", "ch_names", "bads", "nchan", "sfreq"] + } with ref_info._unlock(): ref_info.update(pick_info) @@ -564,16 +620,25 @@ def set_bipolar_reference(inst, anode, cathode, ch_name=None, ch_info=None, ref_data = multiplier @ inst._data if isinstance(inst, BaseRaw): - ref_inst = RawArray(ref_data, ref_info, first_samp=inst.first_samp, - copy=None) + ref_inst = RawArray(ref_data, ref_info, first_samp=inst.first_samp, copy=None) elif isinstance(inst, BaseEpochs): - ref_inst = EpochsArray(ref_data, ref_info, events=inst.events, - tmin=inst.tmin, event_id=inst.event_id, - metadata=inst.metadata) + ref_inst = EpochsArray( + ref_data, + ref_info, + events=inst.events, + tmin=inst.tmin, + event_id=inst.event_id, + metadata=inst.metadata, + ) else: - ref_inst = EvokedArray(ref_data, ref_info, tmin=inst.tmin, - comment=inst.comment, nave=inst.nave, - kind='average') + ref_inst = EvokedArray( + ref_data, + ref_info, + tmin=inst.tmin, + comment=inst.comment, + nave=inst.nave, + kind="average", + ) # Add referenced instance to original instance. inst.add_channels([ref_inst], force_update_info=True) @@ -581,19 +646,19 @@ def set_bipolar_reference(inst, anode, cathode, ch_name=None, ch_info=None, # Handle bad channels. 
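(Aside, not part of the patch: the bipolar step above reduces to a single matrix product, with +1 in each virtual channel's anode column and -1 in its cathode column. A toy standalone example; channel names and data are invented:)

import numpy as np

ch_names = ["EEG1", "EEG2", "EEG3"]
data = np.array([[1.0, 2.0], [0.5, 1.0], [0.0, -1.0]])  # channels x time
anode, cathode = ["EEG1"], ["EEG2"]

multiplier = np.zeros((len(anode), len(ch_names)))
for idx, (a, c) in enumerate(zip(anode, cathode)):
    multiplier[idx, ch_names.index(a)] = 1
    multiplier[idx, ch_names.index(c)] = -1

print(multiplier @ data)  # [[0.5 1.]], i.e. EEG1 - EEG2 at each sample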
bad_bipolar_chs = [] for ch_idx, (a, c) in enumerate(zip(anode, cathode)): - if a in inst.info['bads'] or c in inst.info['bads']: + if a in inst.info["bads"] or c in inst.info["bads"]: bad_bipolar_chs.append(ch_name[ch_idx]) # Add warnings if bad channels are present. if bad_bipolar_chs: - msg = f'Bipolar channels are based on bad channels: {bad_bipolar_chs}.' + msg = f"Bipolar channels are based on bad channels: {bad_bipolar_chs}." _on_missing(on_bad, msg) - inst.info['bads'] += bad_bipolar_chs + inst.info["bads"] += bad_bipolar_chs - added_channels = ', '.join([name for name in ch_name]) - logger.info(f'Added the following bipolar channels:\n{added_channels}') + added_channels = ", ".join([name for name in ch_name]) + logger.info(f"Added the following bipolar channels:\n{added_channels}") - for attr_name in ['picks', '_projector']: + for attr_name in ["picks", "_projector"]: setattr(inst, attr_name, None) # Drop remaining channels. diff --git a/mne/io/snirf/_snirf.py b/mne/io/snirf/_snirf.py index 1b1b20fd531..2bde9fcf263 100644 --- a/mne/io/snirf/_snirf.py +++ b/mne/io/snirf/_snirf.py @@ -10,8 +10,7 @@ from ..meas_info import create_info, _format_dig_points from ..utils import _mult_cal_one from ...annotations import Annotations -from ...utils import (logger, verbose, fill_doc, warn, _check_fname, - _import_h5py) +from ...utils import logger, verbose, fill_doc, warn, _check_fname, _import_h5py from ..constants import FIFF from .._digitization import _make_dig_points from ...transforms import _frame_to_str, apply_trans @@ -56,7 +55,7 @@ def read_raw_snirf(fname, optode_frame="unknown", preload=False, verbose=None): def _open(fname): - return open(fname, 'r', encoding='latin-1') + return open(fname, "r", encoding="latin-1") @fill_doc @@ -81,38 +80,42 @@ class RawSNIRF(BaseRaw): """ @verbose - def __init__(self, fname, optode_frame="unknown", - preload=False, verbose=None): + def __init__(self, fname, optode_frame="unknown", preload=False, verbose=None): # Must be here due to circular import error from ...preprocessing.nirs import _validate_nirs_info + h5py = _import_h5py() fname = str(_check_fname(fname, "read", True, "fname")) - logger.info('Loading %s' % fname) - - with h5py.File(fname, 'r') as dat: - - if 'data2' in dat['nirs']: - warn("File contains multiple recordings. " - "MNE does not support this feature. " - "Only the first dataset will be processed.") + logger.info("Loading %s" % fname) + + with h5py.File(fname, "r") as dat: + if "data2" in dat["nirs"]: + warn( + "File contains multiple recordings. " + "MNE does not support this feature. " + "Only the first dataset will be processed." + ) manufacturer = _get_metadata_str(dat, "ManufacturerName") if (optode_frame == "unknown") & (manufacturer == "Gowerlabs"): optode_frame = "head" - snirf_data_type = np.array(dat.get('nirs/data1/measurementList1' - '/dataType')).item() + snirf_data_type = np.array( + dat.get("nirs/data1/measurementList1" "/dataType") + ).item() if snirf_data_type not in [1, 99999]: # 1 = Continuous Wave # 99999 = Processed - raise RuntimeError('MNE only supports reading continuous' - ' wave amplitude and processed haemoglobin' - ' SNIRF files. Expected type' - ' code 1 or 99999 but received type ' - f'code {snirf_data_type}') + raise RuntimeError( + "MNE only supports reading continuous" + " wave amplitude and processed haemoglobin" + " SNIRF files. 
Expected type" + " code 1 or 99999 but received type " + f"code {snirf_data_type}" + ) - last_samps = dat.get('/nirs/data1/dataTimeSeries').shape[0] - 1 + last_samps = dat.get("/nirs/data1/dataTimeSeries").shape[0] - 1 sampling_rate = _extract_sampling_rate(dat) @@ -120,86 +123,102 @@ def __init__(self, fname, optode_frame="unknown", warn("Unable to extract sample rate from SNIRF file.") # Extract wavelengths - fnirs_wavelengths = np.array(dat.get('nirs/probe/wavelengths')) + fnirs_wavelengths = np.array(dat.get("nirs/probe/wavelengths")) fnirs_wavelengths = [int(w) for w in fnirs_wavelengths] if len(fnirs_wavelengths) != 2: - raise RuntimeError(f'The data contains ' - f'{len(fnirs_wavelengths)}' - f' wavelengths: {fnirs_wavelengths}. ' - f'MNE only supports reading continuous' - ' wave amplitude SNIRF files ' - 'with two wavelengths.') + raise RuntimeError( + f"The data contains " + f"{len(fnirs_wavelengths)}" + f" wavelengths: {fnirs_wavelengths}. " + f"MNE only supports reading continuous" + " wave amplitude SNIRF files " + "with two wavelengths." + ) # Extract channels def atoi(text): return int(text) if text.isdigit() else text def natural_keys(text): - return [atoi(c) for c in re.split(r'(\d+)', text)] + return [atoi(c) for c in re.split(r"(\d+)", text)] - channels = np.array([name for name in dat['nirs']['data1'].keys()]) - channels_idx = np.array(['measurementList' in n for n in channels]) + channels = np.array([name for name in dat["nirs"]["data1"].keys()]) + channels_idx = np.array(["measurementList" in n for n in channels]) channels = channels[channels_idx] channels = sorted(channels, key=natural_keys) # Source and detector labels are optional fields. # Use S1, S2, S3, etc if not specified. - if 'sourceLabels_disabled' in dat['nirs/probe']: + if "sourceLabels_disabled" in dat["nirs/probe"]: # This is disabled as # MNE-Python does not currently support custom source names. # Instead, sources must be integer values. - sources = np.array(dat.get('nirs/probe/sourceLabels')) - sources = [s.decode('UTF-8') for s in sources] + sources = np.array(dat.get("nirs/probe/sourceLabels")) + sources = [s.decode("UTF-8") for s in sources] else: - sources = np.unique([_correct_shape(np.array(dat.get( - 'nirs/data1/' + c + '/sourceIndex')))[0] - for c in channels]) + sources = np.unique( + [ + _correct_shape( + np.array(dat.get("nirs/data1/" + c + "/sourceIndex")) + )[0] + for c in channels + ] + ) sources = [f"S{int(s)}" for s in sources] - if 'detectorLabels_disabled' in dat['nirs/probe']: + if "detectorLabels_disabled" in dat["nirs/probe"]: # This is disabled as # MNE-Python does not currently support custom detector names. # Instead, detector must be integer values. - detectors = np.array(dat.get('nirs/probe/detectorLabels')) - detectors = [d.decode('UTF-8') for d in detectors] + detectors = np.array(dat.get("nirs/probe/detectorLabels")) + detectors = [d.decode("UTF-8") for d in detectors] else: - detectors = np.unique([_correct_shape(np.array(dat.get( - 'nirs/data1/' + c + '/detectorIndex')))[0] - for c in channels]) + detectors = np.unique( + [ + _correct_shape( + np.array(dat.get("nirs/data1/" + c + "/detectorIndex")) + )[0] + for c in channels + ] + ) detectors = [f"D{int(d)}" for d in detectors] # Extract source and detector locations # 3D positions are optional in SNIRF, # but highly recommended in MNE. 
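(Aside, not part of the patch: the natural_keys helper above exists because plain lexicographic sorting would order 'measurementList10' before 'measurementList2'. A standalone illustration of the difference:)

import re

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [atoi(c) for c in re.split(r"(\d+)", text)]

names = ["measurementList10", "measurementList2", "measurementList1"]
print(sorted(names))                    # lexicographic: ...1, ...10, ...2
print(sorted(names, key=natural_keys))  # natural: ...1, ...2, ...10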
- if ('detectorPos3D' in dat['nirs/probe']) &\ - ('sourcePos3D' in dat['nirs/probe']): + if ("detectorPos3D" in dat["nirs/probe"]) & ( + "sourcePos3D" in dat["nirs/probe"] + ): # If 3D positions are available they are used even if 2D exists - detPos3D = np.array(dat.get('nirs/probe/detectorPos3D')) - srcPos3D = np.array(dat.get('nirs/probe/sourcePos3D')) - elif ('detectorPos2D' in dat['nirs/probe']) &\ - ('sourcePos2D' in dat['nirs/probe']): - warn('The data only contains 2D location information for the ' - 'optode positions. ' - 'It is highly recommended that data is used ' - 'which contains 3D location information for the ' - 'optode positions. With only 2D locations it can not be ' - 'guaranteed that MNE functions will behave correctly ' - 'and produce accurate results. If it is not possible to ' - 'include 3D positions in your data, please consider ' - 'using the set_montage() function.') - - detPos2D = np.array(dat.get('nirs/probe/detectorPos2D')) - srcPos2D = np.array(dat.get('nirs/probe/sourcePos2D')) + detPos3D = np.array(dat.get("nirs/probe/detectorPos3D")) + srcPos3D = np.array(dat.get("nirs/probe/sourcePos3D")) + elif ("detectorPos2D" in dat["nirs/probe"]) & ( + "sourcePos2D" in dat["nirs/probe"] + ): + warn( + "The data only contains 2D location information for the " + "optode positions. " + "It is highly recommended that data is used " + "which contains 3D location information for the " + "optode positions. With only 2D locations it can not be " + "guaranteed that MNE functions will behave correctly " + "and produce accurate results. If it is not possible to " + "include 3D positions in your data, please consider " + "using the set_montage() function." + ) + + detPos2D = np.array(dat.get("nirs/probe/detectorPos2D")) + srcPos2D = np.array(dat.get("nirs/probe/sourcePos2D")) # Set the third dimension to zero. See gh#9308 - detPos3D = np.append(detPos2D, - np.zeros((detPos2D.shape[0], 1)), axis=1) - srcPos3D = np.append(srcPos2D, - np.zeros((srcPos2D.shape[0], 1)), axis=1) + detPos3D = np.append(detPos2D, np.zeros((detPos2D.shape[0], 1)), axis=1) + srcPos3D = np.append(srcPos2D, np.zeros((srcPos2D.shape[0], 1)), axis=1) else: - raise RuntimeError('No optode location information is ' - 'provided. MNE requires at least 2D ' - 'location information') + raise RuntimeError( + "No optode location information is " + "provided. 
MNE requires at least 2D " + "location information" + ) assert len(sources) == srcPos3D.shape[0] assert len(detectors) == detPos3D.shape[0] @@ -207,68 +226,80 @@ def natural_keys(text): chnames = [] ch_types = [] for chan in channels: - src_idx = int(_correct_shape(np.array(dat.get('nirs/data1/' + - chan + '/sourceIndex')))[0]) - det_idx = int(_correct_shape(np.array(dat.get('nirs/data1/' + - chan + '/detectorIndex')))[0]) + src_idx = int( + _correct_shape( + np.array(dat.get("nirs/data1/" + chan + "/sourceIndex")) + )[0] + ) + det_idx = int( + _correct_shape( + np.array(dat.get("nirs/data1/" + chan + "/detectorIndex")) + )[0] + ) if snirf_data_type == 1: - wve_idx = int(_correct_shape(np.array( - dat.get('nirs/data1/' + chan + - '/wavelengthIndex')))[0]) - ch_name = sources[src_idx - 1] + '_' +\ - detectors[det_idx - 1] + ' ' +\ - str(fnirs_wavelengths[wve_idx - 1]) + wve_idx = int( + _correct_shape( + np.array(dat.get("nirs/data1/" + chan + "/wavelengthIndex")) + )[0] + ) + ch_name = ( + sources[src_idx - 1] + + "_" + + detectors[det_idx - 1] + + " " + + str(fnirs_wavelengths[wve_idx - 1]) + ) chnames.append(ch_name) - ch_types.append('fnirs_cw_amplitude') + ch_types.append("fnirs_cw_amplitude") elif snirf_data_type == 99999: dt_id = _correct_shape( - np.array(dat.get('nirs/data1/' + chan + - '/dataTypeLabel')))[0].decode('UTF-8') + np.array(dat.get("nirs/data1/" + chan + "/dataTypeLabel")) + )[0].decode("UTF-8") # Convert between SNIRF processed names and MNE type names dt_id = dt_id.lower().replace("dod", "fnirs_od") - ch_name = sources[src_idx - 1] + '_' + \ - detectors[det_idx - 1] + ch_name = sources[src_idx - 1] + "_" + detectors[det_idx - 1] if dt_id == "fnirs_od": - wve_idx = int(_correct_shape(np.array( - dat.get('nirs/data1/' + chan + - '/wavelengthIndex')))[0]) - suffix = ' ' + str(fnirs_wavelengths[wve_idx - 1]) + wve_idx = int( + _correct_shape( + np.array( + dat.get("nirs/data1/" + chan + "/wavelengthIndex") + ) + )[0] + ) + suffix = " " + str(fnirs_wavelengths[wve_idx - 1]) else: - suffix = ' ' + dt_id.lower() + suffix = " " + dt_id.lower() ch_name = ch_name + suffix chnames.append(ch_name) ch_types.append(dt_id) # Create mne structure - info = create_info(chnames, - sampling_rate, - ch_types=ch_types) + info = create_info(chnames, sampling_rate, ch_types=ch_types) subject_info = {} - names = np.array(dat.get('nirs/metaDataTags/SubjectID')) - subject_info['first_name'] = \ - _correct_shape(names)[0].decode('UTF-8') + names = np.array(dat.get("nirs/metaDataTags/SubjectID")) + subject_info["first_name"] = _correct_shape(names)[0].decode("UTF-8") # Read non standard (but allowed) custom metadata tags - if 'lastName' in dat.get('nirs/metaDataTags/'): - ln = dat.get('/nirs/metaDataTags/lastName')[0].decode('UTF-8') - subject_info['last_name'] = ln - if 'middleName' in dat.get('nirs/metaDataTags/'): - m = dat.get('/nirs/metaDataTags/middleName')[0].decode('UTF-8') - subject_info['middle_name'] = m - if 'sex' in dat.get('nirs/metaDataTags/'): - s = dat.get('/nirs/metaDataTags/sex')[0].decode('UTF-8') - if s in {'M', 'Male', '1', 'm'}: - subject_info['sex'] = FIFF.FIFFV_SUBJ_SEX_MALE - elif s in {'F', 'Female', '2', 'f'}: - subject_info['sex'] = FIFF.FIFFV_SUBJ_SEX_FEMALE - elif s in {'0', 'u'}: - subject_info['sex'] = FIFF.FIFFV_SUBJ_SEX_UNKNOWN + if "lastName" in dat.get("nirs/metaDataTags/"): + ln = dat.get("/nirs/metaDataTags/lastName")[0].decode("UTF-8") + subject_info["last_name"] = ln + if "middleName" in dat.get("nirs/metaDataTags/"): + m = 
dat.get("/nirs/metaDataTags/middleName")[0].decode("UTF-8") + subject_info["middle_name"] = m + if "sex" in dat.get("nirs/metaDataTags/"): + s = dat.get("/nirs/metaDataTags/sex")[0].decode("UTF-8") + if s in {"M", "Male", "1", "m"}: + subject_info["sex"] = FIFF.FIFFV_SUBJ_SEX_MALE + elif s in {"F", "Female", "2", "f"}: + subject_info["sex"] = FIFF.FIFFV_SUBJ_SEX_FEMALE + elif s in {"0", "u"}: + subject_info["sex"] = FIFF.FIFFV_SUBJ_SEX_UNKNOWN # End non standard name reading # Update info info.update(subject_info=subject_info) @@ -283,140 +314,163 @@ def natural_keys(text): # These are all in MNI or MEG coordinates, so let's transform # them to the Neuromag head coordinate frame srcPos3D, detPos3D, _, head_t = _convert_fnirs_to_head( - 'fsaverage', optode_frame, 'head', srcPos3D, detPos3D, []) + "fsaverage", optode_frame, "head", srcPos3D, detPos3D, [] + ) else: head_t = np.eye(4) if optode_frame in ["head", "mri", "meg"]: # Then the transformation to head was performed above coord_frame = FIFF.FIFFV_COORD_HEAD - elif 'MNE_coordFrame' in dat.get('nirs/metaDataTags/'): - coord_frame = int(dat.get('/nirs/metaDataTags/MNE_coordFrame') - [0]) + elif "MNE_coordFrame" in dat.get("nirs/metaDataTags/"): + coord_frame = int(dat.get("/nirs/metaDataTags/MNE_coordFrame")[0]) else: coord_frame = FIFF.FIFFV_COORD_UNKNOWN for idx, chan in enumerate(channels): - src_idx = int(_correct_shape(np.array(dat.get('nirs/data1/' + - chan + '/sourceIndex')))[0]) - det_idx = int(_correct_shape(np.array(dat.get('nirs/data1/' + - chan + '/detectorIndex')))[0]) + src_idx = int( + _correct_shape( + np.array(dat.get("nirs/data1/" + chan + "/sourceIndex")) + )[0] + ) + det_idx = int( + _correct_shape( + np.array(dat.get("nirs/data1/" + chan + "/detectorIndex")) + )[0] + ) - info['chs'][idx]['loc'][3:6] = srcPos3D[src_idx - 1, :] - info['chs'][idx]['loc'][6:9] = detPos3D[det_idx - 1, :] + info["chs"][idx]["loc"][3:6] = srcPos3D[src_idx - 1, :] + info["chs"][idx]["loc"][6:9] = detPos3D[det_idx - 1, :] # Store channel as mid point - midpoint = (info['chs'][idx]['loc'][3:6] + - info['chs'][idx]['loc'][6:9]) / 2 - info['chs'][idx]['loc'][0:3] = midpoint - info['chs'][idx]['coord_frame'] = coord_frame - - if (snirf_data_type in [1]) or \ - ((snirf_data_type == 99999) and - (ch_types[idx] == "fnirs_od")): - wve_idx = int(_correct_shape(np.array(dat.get( - 'nirs/data1/' + chan + '/wavelengthIndex')))[0]) - info['chs'][idx]['loc'][9] = fnirs_wavelengths[wve_idx - 1] - - if 'landmarkPos3D' in dat.get('nirs/probe/'): - diglocs = np.array(dat.get('/nirs/probe/landmarkPos3D')) + midpoint = ( + info["chs"][idx]["loc"][3:6] + info["chs"][idx]["loc"][6:9] + ) / 2 + info["chs"][idx]["loc"][0:3] = midpoint + info["chs"][idx]["coord_frame"] = coord_frame + + if (snirf_data_type in [1]) or ( + (snirf_data_type == 99999) and (ch_types[idx] == "fnirs_od") + ): + wve_idx = int( + _correct_shape( + np.array(dat.get("nirs/data1/" + chan + "/wavelengthIndex")) + )[0] + ) + info["chs"][idx]["loc"][9] = fnirs_wavelengths[wve_idx - 1] + + if "landmarkPos3D" in dat.get("nirs/probe/"): + diglocs = np.array(dat.get("/nirs/probe/landmarkPos3D")) diglocs /= length_scaling - digname = np.array(dat.get('/nirs/probe/landmarkLabels')) + digname = np.array(dat.get("/nirs/probe/landmarkLabels")) nasion, lpa, rpa, hpi = None, None, None, None extra_ps = dict() for idx, dign in enumerate(digname): dign = dign.lower() - if dign in [b'lpa', b'al']: + if dign in [b"lpa", b"al"]: lpa = diglocs[idx, :3] - elif dign in [b'nasion']: + elif dign in [b"nasion"]: nasion = 
diglocs[idx, :3] - elif dign in [b'rpa', b'ar']: + elif dign in [b"rpa", b"ar"]: rpa = diglocs[idx, :3] else: - extra_ps[f'EEG{len(extra_ps) + 1:03d}'] = \ - diglocs[idx, :3] + extra_ps[f"EEG{len(extra_ps) + 1:03d}"] = diglocs[idx, :3] add_missing_fiducials = ( - coord_frame == FIFF.FIFFV_COORD_HEAD and - lpa is None and rpa is None and nasion is None + coord_frame == FIFF.FIFFV_COORD_HEAD + and lpa is None + and rpa is None + and nasion is None ) dig = _make_dig_points( - nasion=nasion, lpa=lpa, rpa=rpa, hpi=hpi, + nasion=nasion, + lpa=lpa, + rpa=rpa, + hpi=hpi, dig_ch_pos=extra_ps, coord_frame=_frame_to_str[coord_frame], - add_missing_fiducials=add_missing_fiducials) + add_missing_fiducials=add_missing_fiducials, + ) else: - ch_locs = [info['chs'][idx]['loc'][0:3] - for idx in range(len(channels))] + ch_locs = [info["chs"][idx]["loc"][0:3] for idx in range(len(channels))] # Set up digitization - dig = get_mni_fiducials('fsaverage', verbose=False) + dig = get_mni_fiducials("fsaverage", verbose=False) for fid in dig: - fid['r'] = apply_trans(head_t, fid['r']) - fid['coord_frame'] = FIFF.FIFFV_COORD_HEAD + fid["r"] = apply_trans(head_t, fid["r"]) + fid["coord_frame"] = FIFF.FIFFV_COORD_HEAD for ii, ch_loc in enumerate(ch_locs, 1): - dig.append(dict( - kind=FIFF.FIFFV_POINT_EEG, # misnomer prob okay - r=ch_loc, - ident=ii, - coord_frame=FIFF.FIFFV_COORD_HEAD, - )) + dig.append( + dict( + kind=FIFF.FIFFV_POINT_EEG, # misnomer prob okay + r=ch_loc, + ident=ii, + coord_frame=FIFF.FIFFV_COORD_HEAD, + ) + ) dig = _format_dig_points(dig) del head_t with info._unlock(): - info['dig'] = dig - - str_date = _correct_shape(np.array((dat.get( - '/nirs/metaDataTags/MeasurementDate'))))[0].decode('UTF-8') - str_time = _correct_shape(np.array((dat.get( - '/nirs/metaDataTags/MeasurementTime'))))[0].decode('UTF-8') + info["dig"] = dig + + str_date = _correct_shape( + np.array((dat.get("/nirs/metaDataTags/MeasurementDate"))) + )[0].decode("UTF-8") + str_time = _correct_shape( + np.array((dat.get("/nirs/metaDataTags/MeasurementTime"))) + )[0].decode("UTF-8") str_datetime = str_date + str_time # Several formats have been observed so we try each in turn - for dt_code in ['%Y-%m-%d%H:%M:%SZ', - '%Y-%m-%d%H:%M:%S']: + for dt_code in ["%Y-%m-%d%H:%M:%SZ", "%Y-%m-%d%H:%M:%S"]: try: - meas_date = datetime.datetime.strptime( - str_datetime, dt_code) + meas_date = datetime.datetime.strptime(str_datetime, dt_code) except ValueError: pass else: break else: - warn("Extraction of measurement date from SNIRF file failed. " - "The date is being set to January 1st, 2000, " - f"instead of {str_datetime}") + warn( + "Extraction of measurement date from SNIRF file failed. 
" + "The date is being set to January 1st, 2000, " + f"instead of {str_datetime}" + ) meas_date = datetime.datetime(2000, 1, 1, 0, 0, 0) meas_date = meas_date.replace(tzinfo=datetime.timezone.utc) with info._unlock(): - info['meas_date'] = meas_date + info["meas_date"] = meas_date - if 'DateOfBirth' in dat.get('nirs/metaDataTags/'): - str_birth = np.array((dat.get('/nirs/metaDataTags/' - 'DateOfBirth')))[0].decode() - birth_matched = re.fullmatch(r'(\d+)-(\d+)-(\d+)', str_birth) + if "DateOfBirth" in dat.get("nirs/metaDataTags/"): + str_birth = np.array((dat.get("/nirs/metaDataTags/" "DateOfBirth")))[ + 0 + ].decode() + birth_matched = re.fullmatch(r"(\d+)-(\d+)-(\d+)", str_birth) if birth_matched is not None: - birthday = (int(birth_matched.groups()[0]), - int(birth_matched.groups()[1]), - int(birth_matched.groups()[2])) + birthday = ( + int(birth_matched.groups()[0]), + int(birth_matched.groups()[1]), + int(birth_matched.groups()[2]), + ) with info._unlock(): - info["subject_info"]['birthday'] = birthday + info["subject_info"]["birthday"] = birthday - super(RawSNIRF, self).__init__(info, preload, filenames=[fname], - last_samps=[last_samps], - verbose=verbose) + super(RawSNIRF, self).__init__( + info, + preload, + filenames=[fname], + last_samps=[last_samps], + verbose=verbose, + ) # Extract annotations # As described at https://github.com/fNIRS/snirf/ # blob/master/snirf_specification.md#nirsistimjdata annot = Annotations([], [], []) - for key in dat['nirs']: - if 'stim' in key: - data = np.atleast_2d(np.array( - dat.get('/nirs/' + key + '/data'))) + for key in dat["nirs"]: + if "stim" in key: + data = np.atleast_2d(np.array(dat.get("/nirs/" + key + "/data"))) if data.size > 0: - desc = _correct_shape(np.array(dat.get( - '/nirs/' + key + '/name')))[0] - annot.append(data[:, 0], - data[:, 1], - desc.decode('UTF-8')) + desc = _correct_shape( + np.array(dat.get("/nirs/" + key + "/name")) + )[0] + annot.append(data[:, 0], data[:, 1], desc.decode("UTF-8")) self.set_annotations(annot, emit_warning=False) # Validate that the fNIRS info is correctly formatted @@ -426,8 +480,8 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a segment of data from a file.""" import h5py - with h5py.File(self._filenames[0], 'r') as dat: - one = dat['/nirs/data1/dataTimeSeries'][start:stop].T + with h5py.File(self._filenames[0], "r") as dat: + one = dat["/nirs/data1/dataTimeSeries"][start:stop].T _mult_cal_one(data, one, idx, cals, mult) @@ -441,42 +495,48 @@ def _correct_shape(arr): def _get_timeunit_scaling(time_unit): """MNE expects time in seconds, return required scaling.""" - scalings = {'ms': 1000, 's': 1, 'unknown': 1} + scalings = {"ms": 1000, "s": 1, "unknown": 1} if time_unit in scalings: return scalings[time_unit] else: - raise RuntimeError(f'The time unit {time_unit} is not supported by ' - 'MNE. Please report this error as a GitHub ' - 'issue to inform the developers.') + raise RuntimeError( + f"The time unit {time_unit} is not supported by " + "MNE. Please report this error as a GitHub " + "issue to inform the developers." + ) def _get_lengthunit_scaling(length_unit): """MNE expects distance in m, return required scaling.""" - scalings = {'m': 1, 'cm': 100, 'mm': 1000} + scalings = {"m": 1, "cm": 100, "mm": 1000} if length_unit in scalings: return scalings[length_unit] else: - raise RuntimeError(f'The length unit {length_unit} is not supported ' - 'by MNE. 
Please report this error as a GitHub ' - 'issue to inform the developers.') + raise RuntimeError( + f"The length unit {length_unit} is not supported " + "by MNE. Please report this error as a GitHub " + "issue to inform the developers." + ) def _extract_sampling_rate(dat): """Extract the sample rate from the time field.""" - time_data = np.array(dat.get('nirs/data1/time')) + time_data = np.array(dat.get("nirs/data1/time")) sampling_rate = 0 if len(time_data) == 2: # specified as onset, samplerate - sampling_rate = 1. / (time_data[1] - time_data[0]) + sampling_rate = 1.0 / (time_data[1] - time_data[0]) else: # specified as time points fs_diff = np.around(np.diff(time_data), decimals=4) if len(np.unique(fs_diff)) == 1: # Uniformly sampled data - sampling_rate = 1. / np.unique(fs_diff).item() + sampling_rate = 1.0 / np.unique(fs_diff).item() else: - warn("MNE does not currently support reading " - "SNIRF files with non-uniform sampled data.") + warn( + "MNE does not currently support reading " + "SNIRF files with non-uniform sampled data." + ) time_unit = _get_metadata_str(dat, "TimeUnit") time_unit_scaling = _get_timeunit_scaling(time_unit) @@ -486,9 +546,9 @@ def _extract_sampling_rate(dat): def _get_metadata_str(dat, field): - if field not in np.array(dat.get('nirs/metaDataTags')): + if field not in np.array(dat.get("nirs/metaDataTags")): return None - data = dat.get(f'/nirs/metaDataTags/{field}') + data = dat.get(f"/nirs/metaDataTags/{field}") data = _correct_shape(np.array(data)) - data = str(data[0], 'utf-8') + data = str(data[0], "utf-8") return data diff --git a/mne/io/snirf/tests/test_snirf.py b/mne/io/snirf/tests/test_snirf.py index b9475d69583..721802086bb 100644 --- a/mne/io/snirf/tests/test_snirf.py +++ b/mne/io/snirf/tests/test_snirf.py @@ -9,9 +9,13 @@ from mne.datasets.testing import data_path, requires_testing_data from mne.io import read_raw_snirf, read_raw_nirx from mne.io.tests.test_raw import _test_raw_reader -from mne.preprocessing.nirs import (optical_density, beer_lambert_law, - short_channels, source_detector_distances, - _reorder_nirx) +from mne.preprocessing.nirs import ( + optical_density, + beer_lambert_law, + short_channels, + source_detector_distances, + _reorder_nirx, +) from mne.transforms import apply_trans, _get_trans from mne.io.constants import FIFF @@ -40,40 +44,18 @@ # NIRSport2 files nirx_nirsport2_103 = ( - testing_path - / "SNIRF" - / "NIRx" - / "NIRSport2" - / "1.0.3" - / "2021-04-23_005.snirf" + testing_path / "SNIRF" / "NIRx" / "NIRSport2" / "1.0.3" / "2021-04-23_005.snirf" ) nirx_nirsport2_103_2 = ( - testing_path - / "SNIRF" - / "NIRx" - / "NIRSport2" - / "1.0.3" - / "2021-05-05_001.snirf" + testing_path / "SNIRF" / "NIRx" / "NIRSport2" / "1.0.3" / "2021-05-05_001.snirf" ) snirf_nirsport2_20219 = ( - testing_path - / "SNIRF" - / "NIRx" - / "NIRSport2" - / "2021.9" - / "2021-10-01_002.snirf" + testing_path / "SNIRF" / "NIRx" / "NIRSport2" / "2021.9" / "2021-10-01_002.snirf" ) nirx_nirsport2_20219 = testing_path / "NIRx" / "nirsport_v2" / "aurora_2021_9" # Kernel -kernel_hb = ( - testing_path - / "SNIRF" - / "Kernel" - / "Flow50" - / "Portal_2021_11" - / "hb.snirf" -) +kernel_hb = testing_path / "SNIRF" / "Kernel" / "Flow50" / "Portal_2021_11" / "hb.snirf" h5py = pytest.importorskip("h5py") # module-level @@ -85,43 +67,49 @@ def _get_loc(raw, ch_name): - return raw.copy().pick(ch_name).info['chs'][0]['loc'] + return raw.copy().pick(ch_name).info["chs"][0]["loc"] @requires_testing_data -@pytest.mark.filterwarnings('ignore:.*contains 2D location.*:') 
-@pytest.mark.filterwarnings('ignore:.*measurement date.*:') -@pytest.mark.parametrize('fname', ([sfnirs_homer_103_wShort, - nirx_nirsport2_103, - sfnirs_homer_103_153, - nirx_nirsport2_103, - nirx_nirsport2_103_2, - nirx_nirsport2_103_2, - kernel_hb, - lumo110 - ])) +@pytest.mark.filterwarnings("ignore:.*contains 2D location.*:") +@pytest.mark.filterwarnings("ignore:.*measurement date.*:") +@pytest.mark.parametrize( + "fname", + ( + [ + sfnirs_homer_103_wShort, + nirx_nirsport2_103, + sfnirs_homer_103_153, + nirx_nirsport2_103, + nirx_nirsport2_103_2, + nirx_nirsport2_103_2, + kernel_hb, + lumo110, + ] + ), +) def test_basic_reading_and_min_process(fname): """Test reading SNIRF files and minimum typical processing.""" raw = read_raw_snirf(fname, preload=True) # SNIRF data can contain several types, so only apply appropriate functions - if 'fnirs_cw_amplitude' in raw: + if "fnirs_cw_amplitude" in raw: raw = optical_density(raw) - if 'fnirs_od' in raw: + if "fnirs_od" in raw: raw = beer_lambert_law(raw, ppf=6) - assert 'hbo' in raw - assert 'hbr' in raw + assert "hbo" in raw + assert "hbr" in raw @requires_testing_data -@pytest.mark.filterwarnings('ignore:.*measurement date.*:') +@pytest.mark.filterwarnings("ignore:.*measurement date.*:") def test_snirf_gowerlabs(): """Test reading SNIRF files.""" raw = read_raw_snirf(lumo110, preload=True) assert raw._data.shape == (216, 274) - assert raw.info['dig'][0]['coord_frame'] == FIFF.FIFFV_COORD_HEAD + assert raw.info["dig"][0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD assert len(raw.ch_names) == 216 - assert_allclose(raw.info['sfreq'], 10.0) + assert_allclose(raw.info["sfreq"], 10.0) # we don't force them to be sorted according to a naive split assert raw.ch_names != sorted(raw.ch_names) # ... but this file does have a nice logical ordering already @@ -130,11 +118,12 @@ def test_snirf_gowerlabs(): raw.ch_names, # use a key which is (src triplet, freq, src, freq, det) key=lambda name: ( - (int(name.split()[0].split('_')[0][1:]) - 1) // 3, + (int(name.split()[0].split("_")[0][1:]) - 1) // 3, int(name.split()[1]), - int(name.split()[0].split('_')[0][1:]), - int(name.split()[0].split('_')[1][1:]) - )) + int(name.split()[0].split("_")[0][1:]), + int(name.split()[0].split("_")[1][1:]), + ), + ) @requires_testing_data @@ -144,40 +133,71 @@ def test_snirf_basic(): # Test data import assert raw._data.shape == (26, 145) - assert raw.info['sfreq'] == 12.5 + assert raw.info["sfreq"] == 12.5 # Test channel naming - assert raw.info['ch_names'][:4] == ["S1_D1 760", "S1_D9 760", - "S2_D3 760", "S2_D10 760"] - assert raw.info['ch_names'][24:26] == ['S5_D8 850', 'S5_D13 850'] + assert raw.info["ch_names"][:4] == [ + "S1_D1 760", + "S1_D9 760", + "S2_D3 760", + "S2_D10 760", + ] + assert raw.info["ch_names"][24:26] == ["S5_D8 850", "S5_D13 850"] # Test frequency encoding - assert raw.info['chs'][0]['loc'][9] == 760 - assert raw.info['chs'][24]['loc'][9] == 850 + assert raw.info["chs"][0]["loc"][9] == 760 + assert raw.info["chs"][24]["loc"][9] == 850 # Test source locations - assert_allclose([-8.6765 * 1e-2, 0.0049 * 1e-2, -2.6167 * 1e-2], - _get_loc(raw, 'S1_D1 760')[3:6], rtol=0.02) - assert_allclose([7.9579 * 1e-2, -2.7571 * 1e-2, -2.2631 * 1e-2], - _get_loc(raw, 'S2_D3 760')[3:6], rtol=0.02) - assert_allclose([-2.1387 * 1e-2, -8.8874 * 1e-2, 3.8393 * 1e-2], - _get_loc(raw, 'S3_D2 760')[3:6], rtol=0.02) - assert_allclose([1.8602 * 1e-2, 9.7164 * 1e-2, 1.7539 * 1e-2], - _get_loc(raw, 'S4_D4 760')[3:6], rtol=0.02) - assert_allclose([-0.1108 * 1e-2, 0.7066 * 1e-2, 
8.9883 * 1e-2], - _get_loc(raw, 'S5_D5 760')[3:6], rtol=0.02) + assert_allclose( + [-8.6765 * 1e-2, 0.0049 * 1e-2, -2.6167 * 1e-2], + _get_loc(raw, "S1_D1 760")[3:6], + rtol=0.02, + ) + assert_allclose( + [7.9579 * 1e-2, -2.7571 * 1e-2, -2.2631 * 1e-2], + _get_loc(raw, "S2_D3 760")[3:6], + rtol=0.02, + ) + assert_allclose( + [-2.1387 * 1e-2, -8.8874 * 1e-2, 3.8393 * 1e-2], + _get_loc(raw, "S3_D2 760")[3:6], + rtol=0.02, + ) + assert_allclose( + [1.8602 * 1e-2, 9.7164 * 1e-2, 1.7539 * 1e-2], + _get_loc(raw, "S4_D4 760")[3:6], + rtol=0.02, + ) + assert_allclose( + [-0.1108 * 1e-2, 0.7066 * 1e-2, 8.9883 * 1e-2], + _get_loc(raw, "S5_D5 760")[3:6], + rtol=0.02, + ) # Test detector locations - assert_allclose([-8.0409 * 1e-2, -2.9677 * 1e-2, -2.5415 * 1e-2], - _get_loc(raw, 'S1_D1 760')[6:9], rtol=0.02) - assert_allclose([-8.7329 * 1e-2, 0.7577 * 1e-2, -2.7980 * 1e-2], - _get_loc(raw, 'S1_D9 850')[6:9], rtol=0.02) - assert_allclose([9.2027 * 1e-2, 0.0161 * 1e-2, -2.8909 * 1e-2], - _get_loc(raw, 'S2_D3 850')[6:9], rtol=0.02) - assert_allclose([7.7548 * 1e-2, -3.5901 * 1e-2, -2.3179 * 1e-2], - _get_loc(raw, 'S2_D10 850')[6:9], rtol=0.02) + assert_allclose( + [-8.0409 * 1e-2, -2.9677 * 1e-2, -2.5415 * 1e-2], + _get_loc(raw, "S1_D1 760")[6:9], + rtol=0.02, + ) + assert_allclose( + [-8.7329 * 1e-2, 0.7577 * 1e-2, -2.7980 * 1e-2], + _get_loc(raw, "S1_D9 850")[6:9], + rtol=0.02, + ) + assert_allclose( + [9.2027 * 1e-2, 0.0161 * 1e-2, -2.8909 * 1e-2], + _get_loc(raw, "S2_D3 850")[6:9], + rtol=0.02, + ) + assert_allclose( + [7.7548 * 1e-2, -3.5901 * 1e-2, -2.3179 * 1e-2], + _get_loc(raw, "S2_D10 850")[6:9], + rtol=0.02, + ) - assert 'fnirs_cw_amplitude' in raw + assert "fnirs_cw_amplitude" in raw @requires_testing_data @@ -189,21 +209,25 @@ def test_snirf_against_nirx(): # Check annotations are the same assert_allclose(raw_homer.annotations.onset, raw_orig.annotations.onset) - assert_allclose([float(d) for d in raw_homer.annotations.description], - [float(d) for d in raw_orig.annotations.description]) + assert_allclose( + [float(d) for d in raw_homer.annotations.description], + [float(d) for d in raw_orig.annotations.description], + ) # Homer writes durations as 5s regardless of the true duration. # So we will not test that the nirx file stim durations equal # the homer file stim durations. 
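    # (Illustrative sketch, not part of the original test: if Homer preserved
    # the true durations, the check would be a one-liner, e.g.
    #     assert_allclose(raw_homer.annotations.duration,
    #                     raw_orig.annotations.duration)
    # but the hard-coded 5 s values make that comparison meaningless here.)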
# Check names are the same - assert raw_homer.info['ch_names'] == raw_orig.info['ch_names'] + assert raw_homer.info["ch_names"] == raw_orig.info["ch_names"] # Check frequencies are the same num_chans = len(raw_homer.ch_names) - new_chs = raw_homer.info['chs'] - ori_chs = raw_orig.info['chs'] - assert_allclose([new_chs[idx]['loc'][9] for idx in range(num_chans)], - [ori_chs[idx]['loc'][9] for idx in range(num_chans)]) + new_chs = raw_homer.info["chs"] + ori_chs = raw_orig.info["chs"] + assert_allclose( + [new_chs[idx]["loc"][9] for idx in range(num_chans)], + [ori_chs[idx]["loc"][9] for idx in range(num_chans)], + ) # Check data is the same assert_allclose(raw_homer.get_data(), raw_orig.get_data()) @@ -216,26 +240,21 @@ def test_snirf_nonstandard(tmp_path): fname = str(tmp_path) + "/mod.snirf" # Manually mark up the file to match MNE-NIRS custom tags with h5py.File(fname, "r+") as f: - f.create_dataset("nirs/metaDataTags/middleName", - data=['X'.encode('UTF-8')]) - f.create_dataset("nirs/metaDataTags/lastName", - data=['Y'.encode('UTF-8')]) - f.create_dataset("nirs/metaDataTags/sex", - data=['1'.encode('UTF-8')]) + f.create_dataset("nirs/metaDataTags/middleName", data=["X".encode("UTF-8")]) + f.create_dataset("nirs/metaDataTags/lastName", data=["Y".encode("UTF-8")]) + f.create_dataset("nirs/metaDataTags/sex", data=["1".encode("UTF-8")]) raw = read_raw_snirf(fname, preload=True) - assert raw.info["subject_info"]["middle_name"] == 'X' - assert raw.info["subject_info"]["last_name"] == 'Y' + assert raw.info["subject_info"]["middle_name"] == "X" + assert raw.info["subject_info"]["last_name"] == "Y" assert raw.info["subject_info"]["sex"] == 1 with h5py.File(fname, "r+") as f: - del f['nirs/metaDataTags/sex'] - f.create_dataset("nirs/metaDataTags/sex", - data=['2'.encode('UTF-8')]) + del f["nirs/metaDataTags/sex"] + f.create_dataset("nirs/metaDataTags/sex", data=["2".encode("UTF-8")]) raw = read_raw_snirf(fname, preload=True) assert raw.info["subject_info"]["sex"] == 2 with h5py.File(fname, "r+") as f: - del f['nirs/metaDataTags/sex'] - f.create_dataset("nirs/metaDataTags/sex", - data=['0'.encode('UTF-8')]) + del f["nirs/metaDataTags/sex"] + f.create_dataset("nirs/metaDataTags/sex", data=["0".encode("UTF-8")]) raw = read_raw_snirf(fname, preload=True) assert raw.info["subject_info"]["sex"] == 0 @@ -250,16 +269,20 @@ def test_snirf_nirsport2(): # Test data import assert raw._data.shape == (92, 84) - assert_almost_equal(raw.info['sfreq'], 7.6, decimal=1) + assert_almost_equal(raw.info["sfreq"], 7.6, decimal=1) # Test channel naming - assert raw.info['ch_names'][:4] == ['S1_D1 760', 'S1_D3 760', - 'S1_D9 760', 'S1_D16 760'] - assert raw.info['ch_names'][24:26] == ['S8_D15 760', 'S8_D20 760'] + assert raw.info["ch_names"][:4] == [ + "S1_D1 760", + "S1_D3 760", + "S1_D9 760", + "S1_D16 760", + ] + assert raw.info["ch_names"][24:26] == ["S8_D15 760", "S8_D20 760"] # Test frequency encoding - assert raw.info['chs'][0]['loc'][9] == 760 - assert raw.info['chs'][-1]['loc'][9] == 850 + assert raw.info["chs"][0]["loc"][9] == 760 + assert raw.info["chs"][-1]["loc"][9] == 850 assert sum(short_channels(raw.info)) == 16 @@ -267,38 +290,44 @@ def test_snirf_nirsport2(): @requires_testing_data def test_snirf_coordframe(): """Test reading SNIRF files.""" - raw = read_raw_snirf(nirx_nirsport2_103, optode_frame="head").\ - info['chs'][3]['coord_frame'] + raw = read_raw_snirf(nirx_nirsport2_103, optode_frame="head").info["chs"][3][ + "coord_frame" + ] assert raw == FIFF.FIFFV_COORD_HEAD - raw = 
read_raw_snirf(nirx_nirsport2_103, optode_frame="mri").\ - info['chs'][3]['coord_frame'] + raw = read_raw_snirf(nirx_nirsport2_103, optode_frame="mri").info["chs"][3][ + "coord_frame" + ] assert raw == FIFF.FIFFV_COORD_HEAD - raw = read_raw_snirf(nirx_nirsport2_103, optode_frame="unknown").\ - info['chs'][3]['coord_frame'] + raw = read_raw_snirf(nirx_nirsport2_103, optode_frame="unknown").info["chs"][3][ + "coord_frame" + ] assert raw == FIFF.FIFFV_COORD_UNKNOWN @requires_testing_data def test_snirf_nirsport2_w_positions(): """Test reading SNIRF files with known positions.""" - raw = read_raw_snirf(nirx_nirsport2_103_2, preload=True, - optode_frame="mri") + raw = read_raw_snirf(nirx_nirsport2_103_2, preload=True, optode_frame="mri") _reorder_nirx(raw) # Test data import assert raw._data.shape == (40, 128) - assert_almost_equal(raw.info['sfreq'], 10.2, decimal=1) + assert_almost_equal(raw.info["sfreq"], 10.2, decimal=1) # Test channel naming - assert raw.info['ch_names'][:4] == ['S1_D1 760', 'S1_D1 850', - 'S1_D6 760', 'S1_D6 850'] - assert raw.info['ch_names'][24:26] == ['S6_D4 760', 'S6_D4 850'] + assert raw.info["ch_names"][:4] == [ + "S1_D1 760", + "S1_D1 850", + "S1_D6 760", + "S1_D6 850", + ] + assert raw.info["ch_names"][24:26] == ["S6_D4 760", "S6_D4 850"] # Test frequency encoding - assert raw.info['chs'][0]['loc'][9] == 760 - assert raw.info['chs'][1]['loc'][9] == 850 + assert raw.info["chs"][0]["loc"][9] == 760 + assert raw.info["chs"][1]["loc"][9] == 850 assert sum(short_channels(raw.info)) == 16 @@ -306,54 +335,52 @@ def test_snirf_nirsport2_w_positions(): # nirsite https://github.com/mne-tools/mne-testing-data/pull/86 # figure 3 allowed_distance_error = 0.005 - assert_allclose(source_detector_distances(raw.copy(). - pick("S1_D1 760").info), - [0.0304], atol=allowed_distance_error) - assert_allclose(source_detector_distances(raw.copy(). - pick("S2_D2 760").info), - [0.0400], atol=allowed_distance_error) + assert_allclose( + source_detector_distances(raw.copy().pick("S1_D1 760").info), + [0.0304], + atol=allowed_distance_error, + ) + assert_allclose( + source_detector_distances(raw.copy().pick("S2_D2 760").info), + [0.0400], + atol=allowed_distance_error, + ) # Test location of detectors # The locations of detectors can be seen in the first # figure on this page... 
# https://github.com/mne-tools/mne-testing-data/pull/86 allowed_dist_error = 0.0002 - locs = [ch['loc'][6:9] for ch in raw.info['chs']] - head_mri_t, _ = _get_trans('fsaverage', 'head', 'mri') + locs = [ch["loc"][6:9] for ch in raw.info["chs"]] + head_mri_t, _ = _get_trans("fsaverage", "head", "mri") mni_locs = apply_trans(head_mri_t, locs) - assert raw.info['ch_names'][0][3:5] == 'D1' - assert_allclose( - mni_locs[0], [-0.0841, -0.0464, -0.0129], atol=allowed_dist_error) + assert raw.info["ch_names"][0][3:5] == "D1" + assert_allclose(mni_locs[0], [-0.0841, -0.0464, -0.0129], atol=allowed_dist_error) - assert raw.info['ch_names'][2][3:5] == 'D6' - assert_allclose( - mni_locs[2], [-0.0841, -0.0138, 0.0248], atol=allowed_dist_error) + assert raw.info["ch_names"][2][3:5] == "D6" + assert_allclose(mni_locs[2], [-0.0841, -0.0138, 0.0248], atol=allowed_dist_error) - assert raw.info['ch_names'][34][3:5] == 'D5' - assert_allclose( - mni_locs[34], [0.0845, -0.0451, -0.0123], atol=allowed_dist_error) + assert raw.info["ch_names"][34][3:5] == "D5" + assert_allclose(mni_locs[34], [0.0845, -0.0451, -0.0123], atol=allowed_dist_error) # Test location of sensors # The locations of sensors can be seen in the second # figure on this page... # https://github.com/mne-tools/mne-testing-data/pull/86 allowed_dist_error = 0.0002 - locs = [ch['loc'][3:6] for ch in raw.info['chs']] - head_mri_t, _ = _get_trans('fsaverage', 'head', 'mri') + locs = [ch["loc"][3:6] for ch in raw.info["chs"]] + head_mri_t, _ = _get_trans("fsaverage", "head", "mri") mni_locs = apply_trans(head_mri_t, locs) - assert raw.info['ch_names'][0][:2] == 'S1' - assert_allclose( - mni_locs[0], [-0.0848, -0.0162, -0.0163], atol=allowed_dist_error) + assert raw.info["ch_names"][0][:2] == "S1" + assert_allclose(mni_locs[0], [-0.0848, -0.0162, -0.0163], atol=allowed_dist_error) - assert raw.info['ch_names'][9][:2] == 'S2' - assert_allclose( - mni_locs[9], [-0.0, -0.1195, 0.0142], atol=allowed_dist_error) + assert raw.info["ch_names"][9][:2] == "S2" + assert_allclose(mni_locs[9], [-0.0, -0.1195, 0.0142], atol=allowed_dist_error) - assert raw.info['ch_names'][34][:2] == 'S8' - assert_allclose( - mni_locs[34], [0.0828, -0.046, 0.0285], atol=allowed_dist_error) + assert raw.info["ch_names"][34][:2] == "S8" + assert_allclose(mni_locs[34], [0.0828, -0.046, 0.0285], atol=allowed_dist_error) mon = raw.get_montage() assert len(mon.dig) == 27 @@ -366,14 +393,14 @@ def test_snirf_fieldtrip_od(): # Test data import assert raw._data.shape == (72, 500) - assert raw.copy().pick('fnirs')._data.shape == (72, 500) - assert raw.copy().pick('fnirs_od')._data.shape == (72, 500) - with pytest.raises(ValueError, match='not be interpreted as channel'): - raw.copy().pick('hbo') - with pytest.raises(ValueError, match='not be interpreted as channel'): - raw.copy().pick('hbr') + assert raw.copy().pick("fnirs")._data.shape == (72, 500) + assert raw.copy().pick("fnirs_od")._data.shape == (72, 500) + with pytest.raises(ValueError, match="not be interpreted as channel"): + raw.copy().pick("hbo") + with pytest.raises(ValueError, match="not be interpreted as channel"): + raw.copy().pick("hbr") - assert_allclose(raw.info['sfreq'], 50) + assert_allclose(raw.info["sfreq"], 50) @requires_testing_data @@ -383,10 +410,10 @@ def test_snirf_kernel_hb(): # Test data import assert raw._data.shape == (180 * 2, 14) - assert raw.copy().pick('hbo')._data.shape == (180, 14) - assert raw.copy().pick('hbr')._data.shape == (180, 14) + assert raw.copy().pick("hbo")._data.shape == (180, 14) + assert 
raw.copy().pick("hbr")._data.shape == (180, 14) - assert_allclose(raw.info['sfreq'], 8.257638) + assert_allclose(raw.info["sfreq"], 8.257638) bad_nans = np.isnan(raw.get_data()).any(axis=1) assert np.sum(bad_nans) == 20 @@ -399,25 +426,31 @@ def test_snirf_kernel_hb(): @requires_testing_data -@pytest.mark.parametrize('fname, boundary_decimal, test_scaling, test_rank', ( - [sfnirs_homer_103_wShort, 0, True, True], - [nirx_nirsport2_103, 0, True, False], # strange rank behavior - [nirx_nirsport2_103_2, 0, False, True], # weirdly small values - [snirf_nirsport2_20219, 0, True, True], -)) +@pytest.mark.parametrize( + "fname, boundary_decimal, test_scaling, test_rank", + ( + [sfnirs_homer_103_wShort, 0, True, True], + [nirx_nirsport2_103, 0, True, False], # strange rank behavior + [nirx_nirsport2_103_2, 0, False, True], # weirdly small values + [snirf_nirsport2_20219, 0, True, True], + ), +) def test_snirf_standard(fname, boundary_decimal, test_scaling, test_rank): """Test standard operations.""" - _test_raw_reader(read_raw_snirf, fname=fname, - boundary_decimal=boundary_decimal, - test_scaling=test_scaling, - test_rank=test_rank) # low fs + _test_raw_reader( + read_raw_snirf, + fname=fname, + boundary_decimal=boundary_decimal, + test_scaling=test_scaling, + test_rank=test_rank, + ) # low fs @requires_testing_data def test_annotation_description_from_stim_groups(): """Test annotation descriptions parsed from stim group names.""" raw = read_raw_snirf(nirx_nirsport2_103_2, preload=True) - expected_descriptions = ['1', '2', '6'] + expected_descriptions = ["1", "2", "6"] assert_equal(expected_descriptions, raw.annotations.description) @@ -432,5 +465,5 @@ def test_annotation_duration_from_stim_groups(): # which represents duration, will be all 10s. # from snirf import Snirf # a = Snirf(snirf_nirsport2_20219, "r+"); print(a.nirs[0].stim[0].data) - expected_durations = np.full((10,), 10.) + expected_durations = np.full((10,), 10.0) assert_equal(expected_durations, raw.annotations.duration) diff --git a/mne/io/tag.py b/mne/io/tag.py index 21077701192..291c02c59e0 100644 --- a/mne/io/tag.py +++ b/mne/io/tag.py @@ -8,15 +8,22 @@ import numpy as np -from .constants import (FIFF, _dig_kind_named, _dig_cardinal_named, - _ch_kind_named, _ch_coil_type_named, _ch_unit_named, - _ch_unit_mul_named) +from .constants import ( + FIFF, + _dig_kind_named, + _dig_cardinal_named, + _ch_kind_named, + _ch_coil_type_named, + _ch_unit_named, + _ch_unit_mul_named, +) from ..utils.numerics import _julian_to_cal ############################################################################## # HELPERS + class Tag: """Tag in FIF tree structure. 
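[Illustrative aside, not part of the patch: the hunk below reformats the code
that reads a FIF tag header. Judging from Tag(*struct.unpack(">iIii", s)) and
the Tag(kind, type_, size, next) constructor above, a tag header is 16
big-endian bytes: kind (int32), type (uint32), size (int32), next (int32).
A minimal round-trip sketch, with all values hypothetical:

    import struct

    # Pack a fake 16-byte tag header: kind=100, type=3, size=4, next=0.
    s = struct.pack(">iIii", 100, 3, 4, 0)
    kind, type_, size, next_ = struct.unpack(">iIii", s)
    assert (kind, type_, size, next_) == (100, 3, 4, 0)
]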
@@ -44,20 +51,27 @@ def __init__(self, kind, type_, size, next, pos=None): # noqa: D102 self.data = None def __repr__(self): # noqa: D105 - out = ("iIii', s)) + return Tag(*struct.unpack(">iIii", s)) _matrix_bit_dtype = { - FIFF.FIFFT_INT: (4, '>i4'), - FIFF.FIFFT_JULIAN: (4, '>i4'), - FIFF.FIFFT_FLOAT: (4, '>f4'), - FIFF.FIFFT_DOUBLE: (8, '>f8'), - FIFF.FIFFT_COMPLEX_FLOAT: (8, '>f4'), - FIFF.FIFFT_COMPLEX_DOUBLE: (16, '>f8'), + FIFF.FIFFT_INT: (4, ">i4"), + FIFF.FIFFT_JULIAN: (4, ">i4"), + FIFF.FIFFT_FLOAT: (4, ">f4"), + FIFF.FIFFT_DOUBLE: (8, ">f8"), + FIFF.FIFFT_COMPLEX_FLOAT: (8, ">f4"), + FIFF.FIFFT_COMPLEX_DOUBLE: (16, ">f8"), } def _read_matrix(fid, tag, shape, rlims, matrix_coding): """Read a matrix (dense or sparse) tag.""" from scipy import sparse + matrix_coding = matrix_coding >> 16 # This should be easy to implement (see _frombuffer_rows) # if we need it, but for now, it's not... if shape is not None: - raise ValueError('Row reading not implemented for matrices ' - 'yet') + raise ValueError("Row reading not implemented for matrices " "yet") # Matrices if matrix_coding == _matrix_coding_dense: # Find dimensions and return to the beginning of tag data pos = fid.tell() fid.seek(tag.size - 4, 1) - ndim = int(np.frombuffer(fid.read(4), dtype='>i4').item()) + ndim = int(np.frombuffer(fid.read(4), dtype=">i4").item()) fid.seek(-(ndim + 1) * 4, 1) - dims = np.frombuffer(fid.read(4 * ndim), dtype='>i4')[::-1] + dims = np.frombuffer(fid.read(4 * ndim), dtype=">i4")[::-1] # # Back to where the data start # fid.seek(pos, 0) if ndim > 3: - raise Exception('Only 2 or 3-dimensional matrices are ' - 'supported at this time') + raise Exception( + "Only 2 or 3-dimensional matrices are " "supported at this time" + ) matrix_type = _data_type & tag.type try: bit, dtype = _matrix_bit_dtype[matrix_type] except KeyError: - raise RuntimeError('Cannot handle matrix of type %d yet' - % matrix_type) + raise RuntimeError("Cannot handle matrix of type %d yet" % matrix_type) data = fid.read(int(bit * dims.prod())) data = np.frombuffer(data, dtype=dtype) # Note: we need the non-conjugate transpose here if matrix_type == FIFF.FIFFT_COMPLEX_FLOAT: - data = data.view('>c8') + data = data.view(">c8") elif matrix_type == FIFF.FIFFT_COMPLEX_DOUBLE: - data = data.view('>c16') + data = data.view(">c16") data.shape = dims elif matrix_coding in (_matrix_coding_CCS, _matrix_coding_RCS): # Find dimensions and return to the beginning of tag data pos = fid.tell() fid.seek(tag.size - 4, 1) - ndim = int(np.frombuffer(fid.read(4), dtype='>i4').item()) + ndim = int(np.frombuffer(fid.read(4), dtype=">i4").item()) fid.seek(-(ndim + 2) * 4, 1) - dims = np.frombuffer(fid.read(4 * (ndim + 1)), dtype='>i4') + dims = np.frombuffer(fid.read(4 * (ndim + 1)), dtype=">i4") if ndim != 2: - raise Exception('Only two-dimensional matrices are ' - 'supported at this time') + raise Exception( + "Only two-dimensional matrices are " "supported at this time" + ) # Back to where the data start fid.seek(pos, 0) nnz = int(dims[0]) nrow = int(dims[1]) ncol = int(dims[2]) - data = np.frombuffer(fid.read(4 * nnz), dtype='>f4') + data = np.frombuffer(fid.read(4 * nnz), dtype=">f4") shape = (dims[1], dims[2]) if matrix_coding == _matrix_coding_CCS: # CCS tmp_indices = fid.read(4 * nnz) - indices = np.frombuffer(tmp_indices, dtype='>i4') + indices = np.frombuffer(tmp_indices, dtype=">i4") tmp_ptr = fid.read(4 * (ncol + 1)) - indptr = np.frombuffer(tmp_ptr, dtype='>i4') + indptr = np.frombuffer(tmp_ptr, dtype=">i4") if indptr[-1] > len(indices) or 
np.any(indptr < 0): # There was a bug in MNE-C that caused some data to be # stored without byte swapping indices = np.concatenate( - (np.frombuffer(tmp_indices[:4 * (nrow + 1)], dtype='>i4'), - np.frombuffer(tmp_indices[4 * (nrow + 1):], dtype=' len(indices) or np.any(indptr < 0): # There was a bug in MNE-C that caused some data to be # stored without byte swapping indices = np.concatenate( - (np.frombuffer(tmp_indices[:4 * (ncol + 1)], dtype='>i4'), - np.frombuffer(tmp_indices[4 * (ncol + 1):], dtype='B', - FIFF.FIFFT_SHORT: '>i2', - FIFF.FIFFT_INT: '>i4', - FIFF.FIFFT_USHORT: '>u2', - FIFF.FIFFT_UINT: '>u4', - FIFF.FIFFT_FLOAT: '>f4', - FIFF.FIFFT_DOUBLE: '>f8', - FIFF.FIFFT_DAU_PACK16: '>i2', + FIFF.FIFFT_BYTE: ">B", + FIFF.FIFFT_SHORT: ">i2", + FIFF.FIFFT_INT: ">i4", + FIFF.FIFFT_USHORT: ">u2", + FIFF.FIFFT_UINT: ">u4", + FIFF.FIFFT_FLOAT: ">f4", + FIFF.FIFFT_DOUBLE: ">f8", + FIFF.FIFFT_DAU_PACK16: ">i2", } for key, dtype in _simple_dict.items(): _call_dict[key] = partial(_read_simple, dtype=dtype) @@ -472,7 +495,7 @@ def read_tag(fid, pos=None, shape=None, rlims=None): try: fun = _call_dict[tag.type] except KeyError: - raise Exception('Unimplemented tag data type %s' % tag.type) + raise Exception("Unimplemented tag data type %s" % tag.type) tag.data = fun(fid, tag, shape, rlims) if tag.next != FIFF.FIFFV_NEXT_SEQ: # f.seek(tag.next,0) @@ -498,8 +521,8 @@ def find_tag(fid, node, findkind): tag : instance of Tag The first tag found. """ - if node['directory'] is not None: - for subnode in node['directory']: + if node["directory"] is not None: + for subnode in node["directory"]: if subnode.kind == findkind: return read_tag(fid, subnode.pos) return None @@ -507,7 +530,7 @@ def find_tag(fid, node, findkind): def has_tag(node, kind): """Check if the node contains a Tag of a given kind.""" - for d in node['directory']: + for d in node["directory"]: if d.kind == kind: return True return False diff --git a/mne/io/tests/__init__.py b/mne/io/tests/__init__.py index aba6507665f..ca22217d57a 100644 --- a/mne/io/tests/__init__.py +++ b/mne/io/tests/__init__.py @@ -1,3 +1,3 @@ import os.path as op -data_dir = op.join(op.dirname(__file__), 'data') +data_dir = op.join(op.dirname(__file__), "data") diff --git a/mne/io/tests/test_apply_function.py b/mne/io/tests/test_apply_function.py index 887ba6a8eb3..920dd404dc6 100644 --- a/mne/io/tests/test_apply_function.py +++ b/mne/io/tests/test_apply_function.py @@ -27,7 +27,7 @@ def bad_3(x): def printer(x): """Print.""" - logger.info('exec') + logger.info("exec") return x @@ -37,23 +37,22 @@ def test_apply_function_verbose(): n_chan = 2 n_times = 3 ch_names = [str(ii) for ii in range(n_chan)] - raw = RawArray(np.zeros((n_chan, n_times)), - create_info(ch_names, 1., 'mag')) + raw = RawArray(np.zeros((n_chan, n_times)), create_info(ch_names, 1.0, "mag")) # test return types in both code paths (parallel / 1 job) - with pytest.raises(TypeError, match='Return value must be an ndarray'): + with pytest.raises(TypeError, match="Return value must be an ndarray"): raw.apply_function(bad_1) - with pytest.raises(ValueError, match='Return data must have shape'): + with pytest.raises(ValueError, match="Return data must have shape"): raw.apply_function(bad_2) - with pytest.raises(TypeError, match='Return value must be an ndarray'): + with pytest.raises(TypeError, match="Return value must be an ndarray"): raw.apply_function(bad_1, n_jobs=2) - with pytest.raises(ValueError, match='Return data must have shape'): + with pytest.raises(ValueError, match="Return data must have 
shape"): raw.apply_function(bad_2, n_jobs=2) # test return type when `channel_wise=False` raw.apply_function(printer, channel_wise=False) - with pytest.raises(TypeError, match='Return value must be an ndarray'): + with pytest.raises(TypeError, match="Return value must be an ndarray"): raw.apply_function(bad_1, channel_wise=False) - with pytest.raises(ValueError, match='Return data must have shape'): + with pytest.raises(ValueError, match="Return data must have shape"): raw.apply_function(bad_3, channel_wise=False) # check our arguments @@ -62,4 +61,4 @@ def test_apply_function_verbose(): assert len(sio.getvalue(close=False)) == 0 assert out is raw raw.apply_function(printer, verbose=True) - assert sio.getvalue().count('\n') == n_chan + assert sio.getvalue().count("\n") == n_chan diff --git a/mne/io/tests/test_compensator.py b/mne/io/tests/test_compensator.py index bb8b33bce32..2c2e2299f65 100644 --- a/mne/io/tests/test_compensator.py +++ b/mne/io/tests/test_compensator.py @@ -40,8 +40,8 @@ def test_compensation_identity(): assert_allclose(np.dot(comp2, comp1), desired, atol=1e-12) -@pytest.mark.parametrize('preload', (True, False)) -@pytest.mark.parametrize('pick', (False, True)) +@pytest.mark.parametrize("preload", (True, False)) +@pytest.mark.parametrize("pick", (False, True)) def test_compensation_apply(tmp_path, preload, pick): """Test applying compensation.""" # make sure that changing the comp doesn't modify the original data @@ -67,13 +67,14 @@ def test_compensation_apply(tmp_path, preload, pick): data2, _ = raw2[:, :] # channels have norm ~1e-12 assert_allclose(data, data2, rtol=1e-9, atol=1e-18) - for ch1, ch2 in zip(raw.info['chs'], raw2.info['chs']): - assert ch1['coil_type'] == ch2['coil_type'] + for ch1, ch2 in zip(raw.info["chs"], raw2.info["chs"]): + assert ch1["coil_type"] == ch2["coil_type"] @requires_mne def test_compensation_mne(tmp_path): """Test comensation by comparing with MNE.""" + def make_evoked(fname, comp): """Make evoked data.""" raw = read_raw_fif(fname) @@ -81,15 +82,21 @@ def make_evoked(fname, comp): raw.apply_gradient_compensation(comp) picks = pick_types(raw.info, meg=True, ref_meg=True) events = np.array([[0, 0, 1]], dtype=np.int64) - evoked = Epochs(raw, events, 1, 0, 20e-3, picks=picks, - baseline=None).average() + evoked = Epochs(raw, events, 1, 0, 20e-3, picks=picks, baseline=None).average() return evoked def compensate_mne(fname, comp): """Compensate using MNE-C.""" - tmp_fname = '%s-%d-ave.fif' % (fname.stem, comp) - cmd = ['mne_compensate_data', '--in', str(fname), - '--out', tmp_fname, '--grad', str(comp)] + tmp_fname = "%s-%d-ave.fif" % (fname.stem, comp) + cmd = [ + "mne_compensate_data", + "--in", + str(fname), + "--out", + tmp_fname, + "--grad", + str(comp), + ] run_subprocess(cmd) return read_evokeds(tmp_fname)[0] @@ -102,9 +109,10 @@ def compensate_mne(fname, comp): evoked_c = compensate_mne(fname_default, comp) picks_py = pick_types(evoked_py.info, meg=True, ref_meg=True) picks_c = pick_types(evoked_c.info, meg=True, ref_meg=True) - assert_allclose(evoked_py.data[picks_py], evoked_c.data[picks_c], - rtol=1e-3, atol=1e-17) - chs_py = [evoked_py.info['chs'][ii] for ii in picks_py] - chs_c = [evoked_c.info['chs'][ii] for ii in picks_c] + assert_allclose( + evoked_py.data[picks_py], evoked_c.data[picks_c], rtol=1e-3, atol=1e-17 + ) + chs_py = [evoked_py.info["chs"][ii] for ii in picks_py] + chs_c = [evoked_c.info["chs"][ii] for ii in picks_c] for ch_py, ch_c in zip(chs_py, chs_c): - assert ch_py['coil_type'] == ch_c['coil_type'] + assert 
ch_py["coil_type"] == ch_c["coil_type"] diff --git a/mne/io/tests/test_constants.py b/mne/io/tests/test_constants.py index 2f05d73b19a..5c84c0fb211 100644 --- a/mne/io/tests/test_constants.py +++ b/mne/io/tests/test_constants.py @@ -11,173 +11,219 @@ import pooch import pytest -from mne.io.constants import (FIFF, FWD, _coord_frame_named, _ch_kind_named, - _ch_unit_named, _ch_unit_mul_named, - _ch_coil_type_named, _dig_kind_named, - _dig_cardinal_named) +from mne.io.constants import ( + FIFF, + FWD, + _coord_frame_named, + _ch_kind_named, + _ch_unit_named, + _ch_unit_mul_named, + _ch_coil_type_named, + _dig_kind_named, + _dig_cardinal_named, +) from mne.forward._make_forward import _read_coil_defs from mne.utils import requires_good_network # https://github.com/mne-tools/fiff-constants/commits/master -REPO = 'mne-tools' -COMMIT = 'e27f68cbf74dbfc5193ad429cc77900a59475181' +REPO = "mne-tools" +COMMIT = "e27f68cbf74dbfc5193ad429cc77900a59475181" # These are oddities that we won't address: iod_dups = (355, 359) # these are in both MEGIN and MNE files tag_dups = (3501,) # in both MEGIN and MNE files -_dir_ignore_names = ('clear', 'copy', 'fromkeys', 'get', 'items', 'keys', - 'pop', 'popitem', 'setdefault', 'update', 'values', - 'has_key', 'iteritems', 'iterkeys', 'itervalues', # Py2 - 'viewitems', 'viewkeys', 'viewvalues', # Py2 - ) -_tag_ignore_names = ( -) # for fiff-constants pending updates +_dir_ignore_names = ( + "clear", + "copy", + "fromkeys", + "get", + "items", + "keys", + "pop", + "popitem", + "setdefault", + "update", + "values", + "has_key", + "iteritems", + "iterkeys", + "itervalues", # Py2 + "viewitems", + "viewkeys", + "viewvalues", # Py2 +) +_tag_ignore_names = () # for fiff-constants pending updates _ignore_incomplete_enums = ( # XXX eventually we could complete these - 'bem_surf_id', 'cardinal_point_cardiac', 'cond_model', 'coord', - 'dacq_system', 'diffusion_param', 'gantry_type', 'map_surf', - 'mne_lin_proj', 'mne_ori', 'mri_format', 'mri_pixel', 'proj_by', - 'tags', 'type', 'iod', 'volume_type', 'vol_type', + "bem_surf_id", + "cardinal_point_cardiac", + "cond_model", + "coord", + "dacq_system", + "diffusion_param", + "gantry_type", + "map_surf", + "mne_lin_proj", + "mne_ori", + "mri_format", + "mri_pixel", + "proj_by", + "tags", + "type", + "iod", + "volume_type", + "vol_type", ) # not in coil_def.dat but in DictionaryTypes:enum(coil) _missing_coil_def = ( - 0, # The location info contains no data - 1, # EEG electrode position in r0 - 3, # Old 24 channel system in HUT - 4, # The axial devices in the HUCS MCG system - 5, # Bipolar EEG electrode position - 6, # CSD-transformed EEG electrodes - 200, # Time-varying dipole definition - 300, # fNIRS oxyhemoglobin - 301, # fNIRS deoxyhemoglobin - 302, # fNIRS continuous wave - 303, # fNIRS optical density - 304, # fNIRS frequency domain AC amplitude - 305, # fNIRS frequency domain phase - 306, # fNIRS time domain gated amplitude - 307, # fNIRS time domain moments amplitude - 400, # Eye-tracking gaze position - 401, # Eye-tracking pupil size - 1000, # For testing the MCG software - 2001, # Generic axial gradiometer - 3011, # VV prototype wirewound planar sensor - 3014, # Vectorview SQ20950N planar gradiometer - 3021, # VV prototype wirewound magnetometer + 0, # The location info contains no data + 1, # EEG electrode position in r0 + 3, # Old 24 channel system in HUT + 4, # The axial devices in the HUCS MCG system + 5, # Bipolar EEG electrode position + 6, # CSD-transformed EEG electrodes + 200, # Time-varying dipole definition + 
300, # fNIRS oxyhemoglobin + 301, # fNIRS deoxyhemoglobin + 302, # fNIRS continuous wave + 303, # fNIRS optical density + 304, # fNIRS frequency domain AC amplitude + 305, # fNIRS frequency domain phase + 306, # fNIRS time domain gated amplitude + 307, # fNIRS time domain moments amplitude + 400, # Eye-tracking gaze position + 401, # Eye-tracking pupil size + 1000, # For testing the MCG software + 2001, # Generic axial gradiometer + 3011, # VV prototype wirewound planar sensor + 3014, # Vectorview SQ20950N planar gradiometer + 3021, # VV prototype wirewound magnetometer ) # explicit aliases in constants.py _aliases = dict( - FIFFV_COIL_MAGNES_R_MAG='FIFFV_COIL_MAGNES_REF_MAG', - FIFFV_COIL_MAGNES_R_GRAD='FIFFV_COIL_MAGNES_REF_GRAD', - FIFFV_COIL_MAGNES_R_GRAD_OFF='FIFFV_COIL_MAGNES_OFFDIAG_REF_GRAD', - FIFFV_COIL_FNIRS_RAW='FIFFV_COIL_FNIRS_CW_AMPLITUDE', - FIFFV_MNE_COORD_CTF_HEAD='FIFFV_MNE_COORD_4D_HEAD', - FIFFV_MNE_COORD_KIT_HEAD='FIFFV_MNE_COORD_4D_HEAD', - FIFFV_MNE_COORD_DIGITIZER='FIFFV_COORD_ISOTRAK', - FIFFV_MNE_COORD_SURFACE_RAS='FIFFV_COORD_MRI', - FIFFV_MNE_SENSOR_COV='FIFFV_MNE_NOISE_COV', - FIFFV_POINT_EEG='FIFFV_POINT_ECG', - FIFF_DESCRIPTION='FIFF_COMMENT', - FIFF_REF_PATH='FIFF_MRI_SOURCE_PATH', + FIFFV_COIL_MAGNES_R_MAG="FIFFV_COIL_MAGNES_REF_MAG", + FIFFV_COIL_MAGNES_R_GRAD="FIFFV_COIL_MAGNES_REF_GRAD", + FIFFV_COIL_MAGNES_R_GRAD_OFF="FIFFV_COIL_MAGNES_OFFDIAG_REF_GRAD", + FIFFV_COIL_FNIRS_RAW="FIFFV_COIL_FNIRS_CW_AMPLITUDE", + FIFFV_MNE_COORD_CTF_HEAD="FIFFV_MNE_COORD_4D_HEAD", + FIFFV_MNE_COORD_KIT_HEAD="FIFFV_MNE_COORD_4D_HEAD", + FIFFV_MNE_COORD_DIGITIZER="FIFFV_COORD_ISOTRAK", + FIFFV_MNE_COORD_SURFACE_RAS="FIFFV_COORD_MRI", + FIFFV_MNE_SENSOR_COV="FIFFV_MNE_NOISE_COV", + FIFFV_POINT_EEG="FIFFV_POINT_ECG", + FIFF_DESCRIPTION="FIFF_COMMENT", + FIFF_REF_PATH="FIFF_MRI_SOURCE_PATH", ) @requires_good_network def test_constants(tmp_path): """Test compensation.""" - fname = 'fiff.zip' + fname = "fiff.zip" dest = tmp_path / fname pooch.retrieve( - url='https://codeload.github.com/' - f'{REPO}/fiff-constants/zip/{COMMIT}', + url="https://codeload.github.com/" f"{REPO}/fiff-constants/zip/{COMMIT}", path=tmp_path, fname=fname, - known_hash=None + known_hash=None, ) names = list() - with zipfile.ZipFile(dest, 'r') as ff: + with zipfile.ZipFile(dest, "r") as ff: for name in ff.namelist(): - if 'Dictionary' in name: + if "Dictionary" in name: ff.extract(name, tmp_path) names.append(os.path.basename(name)) shutil.move(tmp_path / name, tmp_path / names[-1]) names = sorted(names) - assert names == ['DictionaryIOD.txt', 'DictionaryIOD_MNE.txt', - 'DictionaryStructures.txt', - 'DictionaryTags.txt', 'DictionaryTags_MNE.txt', - 'DictionaryTypes.txt', 'DictionaryTypes_MNE.txt'] + assert names == [ + "DictionaryIOD.txt", + "DictionaryIOD_MNE.txt", + "DictionaryStructures.txt", + "DictionaryTags.txt", + "DictionaryTags_MNE.txt", + "DictionaryTypes.txt", + "DictionaryTypes_MNE.txt", + ] # IOD (MEGIN and MNE) fif = dict(iod=dict(), tags=dict(), types=dict(), defines=dict()) con = dict(iod=dict(), tags=dict(), types=dict(), defines=dict()) fiff_version = None - for name in ['DictionaryIOD.txt', 'DictionaryIOD_MNE.txt']: + for name in ["DictionaryIOD.txt", "DictionaryIOD_MNE.txt"]: with open(tmp_path / name, "rb") as fid: for line in fid: - line = line.decode('latin1').strip() - if line.startswith('# Packing revision'): + line = line.decode("latin1").strip() + if line.startswith("# Packing revision"): assert fiff_version is None fiff_version = line.split()[-1] - if (line.startswith('#') or 
line.startswith('alias') or - len(line) == 0): + if line.startswith("#") or line.startswith("alias") or len(line) == 0: continue line = line.split('"') assert len(line) in (1, 2, 3) - desc = '' if len(line) == 1 else line[1] + desc = "" if len(line) == 1 else line[1] line = line[0].split() assert len(line) in (2, 3) if len(line) == 2: kind, id_ = line else: kind, id_, tagged = line - assert tagged in ('tagged',) + assert tagged in ("tagged",) id_ = int(id_) if id_ not in iod_dups: - assert id_ not in fif['iod'] - fif['iod'][id_] = [kind, desc] + assert id_ not in fif["iod"] + fif["iod"][id_] = [kind, desc] # Tags (MEGIN) with open(tmp_path / "DictionaryTags.txt", "rb") as fid: for line in fid: - line = line.decode('ISO-8859-1').strip() - if (line.startswith('#') or line.startswith('alias') or - line.startswith(':') or len(line) == 0): + line = line.decode("ISO-8859-1").strip() + if ( + line.startswith("#") + or line.startswith("alias") + or line.startswith(":") + or len(line) == 0 + ): continue line = line.split('"') assert len(line) in (1, 2, 3), line - desc = '' if len(line) == 1 else line[1] + desc = "" if len(line) == 1 else line[1] line = line[0].split() assert len(line) == 4, line kind, id_, dtype, unit = line id_ = int(id_) val = [kind, dtype, unit] - assert id_ not in fif['tags'], (fif['tags'].get(id_), val) - fif['tags'][id_] = val + assert id_ not in fif["tags"], (fif["tags"].get(id_), val) + fif["tags"][id_] = val # Tags (MNE) with open(tmp_path / "DictionaryTags_MNE.txt", "rb") as fid: for li, line in enumerate(fid): - line = line.decode('ISO-8859-1').strip() + line = line.decode("ISO-8859-1").strip() # ignore continuation lines (*) - if (line.startswith('#') or line.startswith('alias') or - line.startswith(':') or line.startswith('*') or - len(line) == 0): + if ( + line.startswith("#") + or line.startswith("alias") + or line.startswith(":") + or line.startswith("*") + or len(line) == 0 + ): continue # weird syntax around line 80: - if line in ('/*', '"'): + if line in ("/*", '"'): continue line = line.split('"') assert len(line) in (1, 2, 3), line if len(line) == 3 and len(line[2]) > 0: l2 = line[2].strip() - assert l2.startswith('/*') and l2.endswith('*/'), l2 - desc = '' if len(line) == 1 else line[1] + assert l2.startswith("/*") and l2.endswith("*/"), l2 + desc = "" if len(line) == 1 else line[1] line = line[0].split() assert len(line) == 3, (li + 1, line) kind, id_, dtype = line - unit = '-' + unit = "-" id_ = int(id_) val = [kind, dtype, unit] if id_ not in tag_dups: - assert id_ not in fif['tags'], (fif['tags'].get(id_), val) - fif['tags'][id_] = val + assert id_ not in fif["tags"], (fif["tags"].get(id_), val) + fif["tags"][id_] = val # Types and enums in_ = None @@ -186,10 +232,10 @@ def test_constants(tmp_path): re_enum_entry = re.compile(r'\s*(\S*)\s*(\S*)\s*"(.*)"$') re_defi = re.compile(r'#define\s*(\S*)\s*(\S*)\s*"(.*)"$') used_enums = list() - for extra in ('', '_MNE'): + for extra in ("", "_MNE"): with open(tmp_path / f"DictionaryTypes{extra}.txt", "rb") as fid: for li, line in enumerate(fid): - line = line.decode('ISO-8859-1').strip() + line = line.decode("ISO-8859-1").strip() if in_ is None: p = re_prim.match(line) e = re_enum.match(line) @@ -197,8 +243,8 @@ def test_constants(tmp_path): if p is not None: t, s, d = p.groups() s = int(s) - assert s not in fif['types'] - fif['types'][s] = [t, d] + assert s not in fif["types"] + fif["types"][s] = [t, d] elif e is not None: # entering an enum this_enum = e.group(1) @@ -210,18 +256,18 @@ def 
test_constants(tmp_path): elif d is not None: t, s, d = d.groups() s = int(s) - fif['defines'][t] = [s, d] + fif["defines"][t] = [s, d] else: - assert not line.startswith('enum(') + assert not line.startswith("enum(") else: # in an enum - if line == '{': + if line == "{": continue - elif line == '}': + elif line == "}": in_ = None continue t, s, d = re_enum_entry.match(line).groups() s = int(s) - if t != 'ecg' and s != 3: # ecg defined the same way + if t != "ecg" and s != 3: # ecg defined the same way assert s not in in_ in_[s] = [t, d] @@ -230,115 +276,135 @@ def test_constants(tmp_path): # # Version - mne_version = '%d.%d' % (FIFF.FIFFC_MAJOR_VERSION, - FIFF.FIFFC_MINOR_VERSION) + mne_version = "%d.%d" % (FIFF.FIFFC_MAJOR_VERSION, FIFF.FIFFC_MINOR_VERSION) assert fiff_version == mne_version unknowns = list() # Assert that all our constants are in the FIF def - assert 'FIFFV_SSS_JOB_NOTHING' in dir(FIFF) + assert "FIFFV_SSS_JOB_NOTHING" in dir(FIFF) for name in sorted(dir(FIFF)): - if name.startswith('_') or name in _dir_ignore_names: + if name.startswith("_") or name in _dir_ignore_names: continue check = None val = getattr(FIFF, name) - if name in fif['defines']: - assert fif['defines'][name][0] == val - elif name.startswith('FIFFC_'): + if name in fif["defines"]: + assert fif["defines"][name][0] == val + elif name.startswith("FIFFC_"): # Checked above - assert name in ('FIFFC_MAJOR_VERSION', 'FIFFC_MINOR_VERSION', - 'FIFFC_VERSION') - elif name.startswith('FIFFB_'): - check = 'iod' - elif name.startswith('FIFFT_'): - check = 'types' - elif name.startswith('FIFFV_'): - if name.startswith('FIFFV_MNE_') and name.endswith('_ORI'): - check = 'mne_ori' - elif name.startswith('FIFFV_MNE_') and name.endswith('_COV'): - check = 'covariance_type' - elif name.startswith('FIFFV_MNE_COORD'): - check = 'coord' # weird wrapper - elif name.endswith('_CH') or '_QUAT_' in name or name in \ - ('FIFFV_DIPOLE_WAVE', 'FIFFV_GOODNESS_FIT', - 'FIFFV_HPI_ERR', 'FIFFV_HPI_G', 'FIFFV_HPI_MOV'): - check = 'ch_type' - elif name.startswith('FIFFV_SUBJ_'): - check = name.split('_')[2].lower() - elif name in ('FIFFV_POINT_LPA', 'FIFFV_POINT_NASION', - 'FIFFV_POINT_RPA', 'FIFFV_POINT_INION'): - check = 'cardinal_point' + assert name in ( + "FIFFC_MAJOR_VERSION", + "FIFFC_MINOR_VERSION", + "FIFFC_VERSION", + ) + elif name.startswith("FIFFB_"): + check = "iod" + elif name.startswith("FIFFT_"): + check = "types" + elif name.startswith("FIFFV_"): + if name.startswith("FIFFV_MNE_") and name.endswith("_ORI"): + check = "mne_ori" + elif name.startswith("FIFFV_MNE_") and name.endswith("_COV"): + check = "covariance_type" + elif name.startswith("FIFFV_MNE_COORD"): + check = "coord" # weird wrapper + elif ( + name.endswith("_CH") + or "_QUAT_" in name + or name + in ( + "FIFFV_DIPOLE_WAVE", + "FIFFV_GOODNESS_FIT", + "FIFFV_HPI_ERR", + "FIFFV_HPI_G", + "FIFFV_HPI_MOV", + ) + ): + check = "ch_type" + elif name.startswith("FIFFV_SUBJ_"): + check = name.split("_")[2].lower() + elif name in ( + "FIFFV_POINT_LPA", + "FIFFV_POINT_NASION", + "FIFFV_POINT_RPA", + "FIFFV_POINT_INION", + ): + check = "cardinal_point" else: for check in used_enums: - if name.startswith('FIFFV_' + check.upper()): + if name.startswith("FIFFV_" + check.upper()): break else: if name not in _tag_ignore_names: - raise RuntimeError('Could not find %s' % (name,)) + raise RuntimeError("Could not find %s" % (name,)) assert check in used_enums, name - if 'SSS' in check: + if "SSS" in check: raise RuntimeError - elif name.startswith('FIFF_UNIT'): # units and 
multipliers - check = name.split('_')[1].lower() - elif name.startswith('FIFF_'): - check = 'tags' + elif name.startswith("FIFF_UNIT"): # units and multipliers + check = name.split("_")[1].lower() + elif name.startswith("FIFF_"): + check = "tags" else: unknowns.append((name, val)) if check is not None and name not in _tag_ignore_names: - assert val in fif[check], '%s: %s, %s' % (check, val, name) + assert val in fif[check], "%s: %s, %s" % (check, val, name) if val in con[check]: msg = "%s='%s' ?" % (name, con[check][val]) assert _aliases.get(name) == con[check][val], msg else: con[check][val] = name - unknowns = '\n\t'.join('%s (%s)' % u for u in unknowns) - assert len(unknowns) == 0, 'Unknown types\n\t%s' % unknowns + unknowns = "\n\t".join("%s (%s)" % u for u in unknowns) + assert len(unknowns) == 0, "Unknown types\n\t%s" % unknowns # Assert that all the FIF defs are in our constants assert set(fif.keys()) == set(con.keys()) - for key in sorted(set(fif.keys()) - {'defines'}): + for key in sorted(set(fif.keys()) - {"defines"}): this_fif, this_con = fif[key], con[key] assert len(set(this_fif.keys())) == len(this_fif) assert len(set(this_con.keys())) == len(this_con) missing_from_con = sorted(set(this_con.keys()) - set(this_fif.keys())) assert missing_from_con == [], key if key not in _ignore_incomplete_enums: - missing_from_fif = sorted(set(this_fif.keys()) - - set(this_con.keys())) + missing_from_fif = sorted(set(this_fif.keys()) - set(this_con.keys())) assert missing_from_fif == [], key # Assert that `coil_def.dat` has accurate descriptions of all enum(coil) coil_def = _read_coil_defs() - coil_desc = np.array([c['desc'] for c in coil_def]) - coil_def = np.array([(c['coil_type'], c['accuracy']) - for c in coil_def], int) - mask = (coil_def[:, 1] == FWD.COIL_ACCURACY_ACCURATE) + coil_desc = np.array([c["desc"] for c in coil_def]) + coil_def = np.array([(c["coil_type"], c["accuracy"]) for c in coil_def], int) + mask = coil_def[:, 1] == FWD.COIL_ACCURACY_ACCURATE coil_def = coil_def[mask, 0] coil_desc = coil_desc[mask] bad_list = [] - for key in fif['coil']: + for key in fif["coil"]: if key not in _missing_coil_def and key not in coil_def: - bad_list.append((' %s,' % key).ljust(10) + - ' # ' + fif['coil'][key][1]) - assert len(bad_list) == 0, \ - '\nIn fiff-constants, missing from coil_def:\n' + '\n'.join(bad_list) + bad_list.append((" %s," % key).ljust(10) + " # " + fif["coil"][key][1]) + assert ( + len(bad_list) == 0 + ), "\nIn fiff-constants, missing from coil_def:\n" + "\n".join(bad_list) # Assert that enum(coil) has all `coil_def.dat` entries for key, desc in zip(coil_def, coil_desc): - if key not in fif['coil']: - bad_list.append((' %s,' % key).ljust(10) + ' # ' + desc) - assert len(bad_list) == 0, \ - 'In coil_def, missing from fiff-constants:\n' + '\n'.join(bad_list) + if key not in fif["coil"]: + bad_list.append((" %s," % key).ljust(10) + " # " + desc) + assert ( + len(bad_list) == 0 + ), "In coil_def, missing from fiff-constants:\n" + "\n".join(bad_list) -@pytest.mark.parametrize('dict_, match, extras', [ - ({**_dig_kind_named, **_dig_cardinal_named}, 'FIFFV_POINT_', ()), - (_ch_kind_named, '^FIFFV_.*_CH$', - (FIFF.FIFFV_DIPOLE_WAVE, FIFF.FIFFV_GOODNESS_FIT)), - (_coord_frame_named, 'FIFFV_COORD_', ()), - (_ch_unit_named, 'FIFF_UNIT_', ()), - (_ch_unit_mul_named, 'FIFF_UNITM_', ()), - (_ch_coil_type_named, 'FIFFV_COIL_', ()), -]) +@pytest.mark.parametrize( + "dict_, match, extras", + [ + ({**_dig_kind_named, **_dig_cardinal_named}, "FIFFV_POINT_", ()), + ( + _ch_kind_named, + 
"^FIFFV_.*_CH$", + (FIFF.FIFFV_DIPOLE_WAVE, FIFF.FIFFV_GOODNESS_FIT), + ), + (_coord_frame_named, "FIFFV_COORD_", ()), + (_ch_unit_named, "FIFF_UNIT_", ()), + (_ch_unit_mul_named, "FIFF_UNITM_", ()), + (_ch_coil_type_named, "FIFFV_COIL_", ()), + ], +) def test_dict_completion(dict_, match, extras): """Test readable dict completions.""" regex = re.compile(match) diff --git a/mne/io/tests/test_meas_info.py b/mne/io/tests/test_meas_info.py index 4af6d0ebe78..ed7f9ed2616 100644 --- a/mne/io/tests/test_meas_info.py +++ b/mne/io/tests/test_meas_info.py @@ -14,27 +14,66 @@ from scipy import sparse import string -from mne import (Epochs, read_events, pick_info, pick_types, Annotations, - read_evokeds, make_forward_solution, make_sphere_model, - setup_volume_source_space, write_forward_solution, - read_forward_solution, write_cov, read_cov, read_epochs, - compute_covariance) -from mne.channels import (read_polhemus_fastscan, make_standard_montage, - equalize_channels) +from mne import ( + Epochs, + read_events, + pick_info, + pick_types, + Annotations, + read_evokeds, + make_forward_solution, + make_sphere_model, + setup_volume_source_space, + write_forward_solution, + read_forward_solution, + write_cov, + read_cov, + read_epochs, + compute_covariance, +) +from mne.channels import ( + read_polhemus_fastscan, + make_standard_montage, + equalize_channels, +) from mne.event import make_fixed_length_events from mne.datasets import testing -from mne.io import (read_fiducials, write_fiducials, _coil_trans_to_loc, - _loc_to_coil_trans, read_raw_fif, read_info, write_info, - meas_info, Projection, BaseRaw, read_raw_ctf, RawArray) +from mne.io import ( + read_fiducials, + write_fiducials, + _coil_trans_to_loc, + _loc_to_coil_trans, + read_raw_fif, + read_info, + write_info, + meas_info, + Projection, + BaseRaw, + read_raw_ctf, + RawArray, +) from mne.io.constants import FIFF from mne.io.write import _generate_meas_id, DATE_NONE -from mne.io.meas_info import (Info, create_info, _merge_info, - _force_update_info, RAW_INFO_FIELDS, - _bad_chans_comp, _get_valid_units, - anonymize_info, _stamp_to_dt, _dt_to_stamp, - _add_timedelta_to_stamp, _read_extended_ch_info) -from mne.minimum_norm import (make_inverse_operator, write_inverse_operator, - read_inverse_operator, apply_inverse) +from mne.io.meas_info import ( + Info, + create_info, + _merge_info, + _force_update_info, + RAW_INFO_FIELDS, + _bad_chans_comp, + _get_valid_units, + anonymize_info, + _stamp_to_dt, + _dt_to_stamp, + _add_timedelta_to_stamp, + _read_extended_ch_info, +) +from mne.minimum_norm import ( + make_inverse_operator, + write_inverse_operator, + read_inverse_operator, + apply_inverse, +) from mne.io._digitization import _write_dig_points, _make_dig_points, DigPoint from mne.transforms import Transform from mne.utils import catch_logging, assert_object_equal, _record_warnings @@ -61,25 +100,31 @@ raw_invalid_bday_fname = data_path / "misc" / "sample_invalid_birthday_raw.fif" -@pytest.mark.parametrize('kwargs, want', [ - (dict(meg=False, eeg=True), [0]), - (dict(meg=False, fnirs=True), [5]), - (dict(meg=False, fnirs='hbo'), [5]), - (dict(meg=False, fnirs='hbr'), []), - (dict(meg=False, misc=True), [1]), - (dict(meg=True), [2, 3, 4]), - (dict(meg='grad'), [2, 3]), - (dict(meg='planar1'), [2]), - (dict(meg='planar2'), [3]), - (dict(meg='mag'), [4]), -]) +@pytest.mark.parametrize( + "kwargs, want", + [ + (dict(meg=False, eeg=True), [0]), + (dict(meg=False, fnirs=True), [5]), + (dict(meg=False, fnirs="hbo"), [5]), + (dict(meg=False, fnirs="hbr"), 
[]), + (dict(meg=False, misc=True), [1]), + (dict(meg=True), [2, 3, 4]), + (dict(meg="grad"), [2, 3]), + (dict(meg="planar1"), [2]), + (dict(meg="planar2"), [3]), + (dict(meg="mag"), [4]), + ], +) def test_create_info_grad(kwargs, want): """Test create_info behavior with grad coils.""" info = create_info(6, 256, ["eeg", "misc", "grad", "grad", "mag", "hbo"]) # Put these in an order such that grads get named "2" and "3", since # they get picked based first on coil_type then ch_name... - assert [ch['ch_name'] for ch in info['chs'] - if ch['coil_type'] == FIFF.FIFFV_COIL_VV_PLANAR_T1] == ['2', '3'] + assert [ + ch["ch_name"] + for ch in info["chs"] + if ch["coil_type"] == FIFF.FIFFV_COIL_VV_PLANAR_T1 + ] == ["2", "3"] picks = pick_types(info, **kwargs) assert_array_equal(picks, want) @@ -105,75 +150,81 @@ def test_coil_trans(): def test_make_info(): """Test some create_info properties.""" n_ch = np.longlong(1) - info = create_info(n_ch, 1000., 'eeg') + info = create_info(n_ch, 1000.0, "eeg") assert set(info.keys()) == set(RAW_INFO_FIELDS) - coil_types = {ch['coil_type'] for ch in info['chs']} + coil_types = {ch["coil_type"] for ch in info["chs"]} assert FIFF.FIFFV_COIL_EEG in coil_types - pytest.raises(TypeError, create_info, ch_names='Test Ch', sfreq=1000) - pytest.raises(ValueError, create_info, ch_names=['Test Ch'], sfreq=-1000) - pytest.raises(ValueError, create_info, ch_names=['Test Ch'], sfreq=1000, - ch_types=['eeg', 'eeg']) - pytest.raises(TypeError, create_info, ch_names=[np.array([1])], - sfreq=1000) - pytest.raises(KeyError, create_info, ch_names=['Test Ch'], sfreq=1000, - ch_types=np.array([1])) - pytest.raises(KeyError, create_info, ch_names=['Test Ch'], sfreq=1000, - ch_types='awesome') - pytest.raises(TypeError, create_info, ['Test Ch'], sfreq=1000, - montage=np.array([1])) - m = make_standard_montage('biosemi32') - info = create_info(ch_names=m.ch_names, sfreq=1000., ch_types='eeg') + pytest.raises(TypeError, create_info, ch_names="Test Ch", sfreq=1000) + pytest.raises(ValueError, create_info, ch_names=["Test Ch"], sfreq=-1000) + pytest.raises( + ValueError, + create_info, + ch_names=["Test Ch"], + sfreq=1000, + ch_types=["eeg", "eeg"], + ) + pytest.raises(TypeError, create_info, ch_names=[np.array([1])], sfreq=1000) + pytest.raises( + KeyError, create_info, ch_names=["Test Ch"], sfreq=1000, ch_types=np.array([1]) + ) + pytest.raises( + KeyError, create_info, ch_names=["Test Ch"], sfreq=1000, ch_types="awesome" + ) + pytest.raises( + TypeError, create_info, ["Test Ch"], sfreq=1000, montage=np.array([1]) + ) + m = make_standard_montage("biosemi32") + info = create_info(ch_names=m.ch_names, sfreq=1000.0, ch_types="eeg") info.set_montage(m) - ch_pos = [ch['loc'][:3] for ch in info['chs']] + ch_pos = [ch["loc"][:3] for ch in info["chs"]] ch_pos_mon = m._get_ch_pos() - ch_pos_mon = np.array( - [ch_pos_mon[ch_name] for ch_name in info['ch_names']]) + ch_pos_mon = np.array([ch_pos_mon[ch_name] for ch_name in info["ch_names"]]) # transform to head - ch_pos_mon += (0., 0., 0.04014) + ch_pos_mon += (0.0, 0.0, 0.04014) assert_allclose(ch_pos, ch_pos_mon, atol=1e-5) def test_duplicate_name_correction(): """Test duplicate channel names with running number.""" # When running number is possible - info = create_info(['A', 'A', 'A'], 1000., verbose='error') - assert info['ch_names'] == ['A-0', 'A-1', 'A-2'] + info = create_info(["A", "A", "A"], 1000.0, verbose="error") + assert info["ch_names"] == ["A-0", "A-1", "A-2"] # When running number is not possible but alpha numeric is - info = 
create_info(['A', 'A', 'A-0'], 1000., verbose='error') - assert info['ch_names'] == ['A-a', 'A-1', 'A-0'] + info = create_info(["A", "A", "A-0"], 1000.0, verbose="error") + assert info["ch_names"] == ["A-a", "A-1", "A-0"] # When a single addition is not sufficient - with pytest.raises(ValueError, match='Adding a single alphanumeric'): - ch_n = ['A', 'A'] + with pytest.raises(ValueError, match="Adding a single alphanumeric"): + ch_n = ["A", "A"] # add all options for first duplicate channel (0) - ch_n.extend([f'{ch_n[0]}-{c}' for c in string.ascii_lowercase + '0']) - create_info(ch_n, 1000., verbose='error') + ch_n.extend([f"{ch_n[0]}-{c}" for c in string.ascii_lowercase + "0"]) + create_info(ch_n, 1000.0, verbose="error") def test_fiducials_io(tmp_path): """Test fiducials i/o.""" pts, coord_frame = read_fiducials(fiducials_fname) - assert pts[0]['coord_frame'] == FIFF.FIFFV_COORD_MRI - assert pts[0]['ident'] == FIFF.FIFFV_POINT_CARDINAL + assert pts[0]["coord_frame"] == FIFF.FIFFV_COORD_MRI + assert pts[0]["ident"] == FIFF.FIFFV_POINT_CARDINAL - temp_fname = tmp_path / 'test.fif' + temp_fname = tmp_path / "test.fif" write_fiducials(temp_fname, pts, coord_frame) pts_1, coord_frame_1 = read_fiducials(temp_fname) assert coord_frame == coord_frame_1 for pt, pt_1 in zip(pts, pts_1): - assert pt['kind'] == pt_1['kind'] - assert pt['ident'] == pt_1['ident'] - assert pt['coord_frame'] == pt_1['coord_frame'] - assert_array_equal(pt['r'], pt_1['r']) + assert pt["kind"] == pt_1["kind"] + assert pt["ident"] == pt_1["ident"] + assert pt["coord_frame"] == pt_1["coord_frame"] + assert_array_equal(pt["r"], pt_1["r"]) assert isinstance(pt, DigPoint) assert isinstance(pt_1, DigPoint) # test safeguards - pts[0]['coord_frame'] += 1 - with pytest.raises(ValueError, match='coord_frame entries that are incom'): + pts[0]["coord_frame"] += 1 + with pytest.raises(ValueError, match="coord_frame entries that are incom"): write_fiducials(temp_fname, pts, coord_frame, overwrite=True) @@ -188,45 +239,44 @@ def test_info(): evoked = epochs.average() # Test subclassing was successful. 
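# --- editorial aside (not part of the patch) --------------------------------
# A minimal sketch, assuming an installed mne, of the duplicate-name handling
# that test_duplicate_name_correction above exercises: repeated channel names
# get running numbers appended (verbose="error" silences the warning).
import mne

demo_info = mne.create_info(["A", "A", "A"], sfreq=1000.0, verbose="error")
assert demo_info["ch_names"] == ["A-0", "A-1", "A-2"]
# -----------------------------------------------------------------------------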
- info = Info(a=7, b='aaaaa') - assert ('a' in info) - assert ('b' in info) + info = Info(a=7, b="aaaaa") + assert "a" in info + assert "b" in info # Test info attribute in API objects for obj in [raw, epochs, evoked]: - assert (isinstance(obj.info, Info)) + assert isinstance(obj.info, Info) rep = repr(obj.info) - assert '2002-12-03 19:01:10 UTC' in rep, rep - assert '146 items (3 Cardinal, 4 HPI, 61 EEG, 78 Extra)' in rep - dig_rep = repr(obj.info['dig'][0]) - assert 'LPA' in dig_rep, dig_rep - assert '(-71.4, 0.0, 0.0) mm' in dig_rep, dig_rep - assert 'head frame' in dig_rep, dig_rep + assert "2002-12-03 19:01:10 UTC" in rep, rep + assert "146 items (3 Cardinal, 4 HPI, 61 EEG, 78 Extra)" in rep + dig_rep = repr(obj.info["dig"][0]) + assert "LPA" in dig_rep, dig_rep + assert "(-71.4, 0.0, 0.0) mm" in dig_rep, dig_rep + assert "head frame" in dig_rep, dig_rep # Test our BunchConstNamed support for func in (str, repr): - assert '4 (FIFFV_COORD_HEAD)' == \ - func(obj.info['dig'][0]['coord_frame']) + assert "4 (FIFFV_COORD_HEAD)" == func(obj.info["dig"][0]["coord_frame"]) # Test read-only fields info = raw.info.copy() - nchan = len(info['chs']) - ch_names = [ch['ch_name'] for ch in info['chs']] - assert info['nchan'] == nchan - assert list(info['ch_names']) == ch_names + nchan = len(info["chs"]) + ch_names = [ch["ch_name"] for ch in info["chs"]] + assert info["nchan"] == nchan + assert list(info["ch_names"]) == ch_names # Deleting of regular fields should work - info['experimenter'] = 'bar' - del info['experimenter'] + info["experimenter"] = "bar" + del info["experimenter"] # Test updating of fields - del info['chs'][-1] + del info["chs"][-1] info._update_redundant() - assert info['nchan'] == nchan - 1 - assert list(info['ch_names']) == ch_names[:-1] + assert info["nchan"] == nchan - 1 + assert list(info["ch_names"]) == ch_names[:-1] - info['chs'][0]['ch_name'] = 'foo' + info["chs"][0]["ch_name"] = "foo" info._update_redundant() - assert info['ch_names'][0] == 'foo' + assert info["ch_names"][0] == "foo" # Test casting to and from a dict info_dict = dict(info) @@ -237,348 +287,348 @@ def test_info(): def test_read_write_info(tmp_path): """Test IO of info.""" info = read_info(raw_fname) - temp_file = tmp_path / 'info.fif' + temp_file = tmp_path / "info.fif" # check for bug `#1198` - info['dev_head_t']['trans'] = np.eye(4) - t1 = info['dev_head_t']['trans'] + info["dev_head_t"]["trans"] = np.eye(4) + t1 = info["dev_head_t"]["trans"] write_info(temp_file, info) info2 = read_info(temp_file) - t2 = info2['dev_head_t']['trans'] - assert (len(info['chs']) == len(info2['chs'])) + t2 = info2["dev_head_t"]["trans"] + assert len(info["chs"]) == len(info2["chs"]) assert_array_equal(t1, t2) # proc_history (e.g., GH#1875) - creator = 'é' + creator = "é" info = read_info(chpi_fname) - info['proc_history'][0]['creator'] = creator - info['hpi_meas'][0]['creator'] = creator - info['subject_info']['his_id'] = creator - info['subject_info']['weight'] = 11.1 - info['subject_info']['height'] = 2.3 + info["proc_history"][0]["creator"] = creator + info["hpi_meas"][0]["creator"] = creator + info["subject_info"]["his_id"] = creator + info["subject_info"]["weight"] = 11.1 + info["subject_info"]["height"] = 2.3 with info._unlock(): - if info['gantry_angle'] is None: # future testing data may include it - info['gantry_angle'] = 0. 
# Elekta supine position - gantry_angle = info['gantry_angle'] + if info["gantry_angle"] is None: # future testing data may include it + info["gantry_angle"] = 0.0 # Elekta supine position + gantry_angle = info["gantry_angle"] - meas_id = info['meas_id'] + meas_id = info["meas_id"] write_info(temp_file, info) info = read_info(temp_file) - assert info['proc_history'][0]['creator'] == creator - assert info['hpi_meas'][0]['creator'] == creator - assert info['subject_info']['his_id'] == creator - assert info['gantry_angle'] == gantry_angle - assert info['subject_info']['height'] == 2.3 - assert info['subject_info']['weight'] == 11.1 - for key in ['secs', 'usecs', 'version']: - assert info['meas_id'][key] == meas_id[key] - assert_array_equal(info['meas_id']['machid'], meas_id['machid']) + assert info["proc_history"][0]["creator"] == creator + assert info["hpi_meas"][0]["creator"] == creator + assert info["subject_info"]["his_id"] == creator + assert info["gantry_angle"] == gantry_angle + assert info["subject_info"]["height"] == 2.3 + assert info["subject_info"]["weight"] == 11.1 + for key in ["secs", "usecs", "version"]: + assert info["meas_id"][key] == meas_id[key] + assert_array_equal(info["meas_id"]["machid"], meas_id["machid"]) # Test that writing twice produces the same file m1 = hashlib.md5() - with open(temp_file, 'rb') as fid: + with open(temp_file, "rb") as fid: m1.update(fid.read()) m1 = m1.hexdigest() - temp_file_2 = tmp_path / 'info2.fif' + temp_file_2 = tmp_path / "info2.fif" assert temp_file_2 != temp_file write_info(temp_file_2, info) m2 = hashlib.md5() - with open(str(temp_file_2), 'rb') as fid: + with open(str(temp_file_2), "rb") as fid: m2.update(fid.read()) m2 = m2.hexdigest() assert m1 == m2 info = read_info(raw_fname) with info._unlock(): - info['meas_date'] = None - anonymize_info(info, verbose='error') - assert info['meas_date'] is None - tmp_fname_3 = tmp_path / 'info3.fif' + info["meas_date"] = None + anonymize_info(info, verbose="error") + assert info["meas_date"] is None + tmp_fname_3 = tmp_path / "info3.fif" write_info(tmp_fname_3, info) - assert info['meas_date'] is None + assert info["meas_date"] is None info2 = read_info(tmp_fname_3) - assert info2['meas_date'] is None + assert info2["meas_date"] is None # Check that having a very old date in fine until you try to save it to fif with info._unlock(check_after=True): - info['meas_date'] = datetime(1800, 1, 1, 0, 0, 0, tzinfo=timezone.utc) - fname = tmp_path / 'test.fif' - with pytest.raises(RuntimeError, match='must be between '): + info["meas_date"] = datetime(1800, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + fname = tmp_path / "test.fif" + with pytest.raises(RuntimeError, match="must be between "): write_info(fname, info) def test_io_dig_points(tmp_path): """Test Writing for dig files.""" - points = read_polhemus_fastscan(hsp_fname, on_header_missing='ignore') + points = read_polhemus_fastscan(hsp_fname, on_header_missing="ignore") - dest = tmp_path / 'test.txt' - dest_bad = tmp_path / 'test.mne' - with pytest.raises(ValueError, match='must be of shape'): + dest = tmp_path / "test.txt" + dest_bad = tmp_path / "test.mne" + with pytest.raises(ValueError, match="must be of shape"): _write_dig_points(dest, points[:, :2]) - with pytest.raises(ValueError, match='extension'): + with pytest.raises(ValueError, match="extension"): _write_dig_points(dest_bad, points) _write_dig_points(dest, points) - points1 = read_polhemus_fastscan( - dest, unit='m', on_header_missing='ignore') + points1 = read_polhemus_fastscan(dest, unit="m", 
on_header_missing="ignore") err = "Dig points diverged after writing and reading." assert_array_equal(points, points1, err) points2 = np.array([[-106.93, 99.80], [99.80, 68.81]]) - np.savetxt(dest, points2, delimiter='\t', newline='\n') - with pytest.raises(ValueError, match='must be of shape'): - with pytest.warns(RuntimeWarning, match='FastSCAN header'): - read_polhemus_fastscan(dest, on_header_missing='warn') + np.savetxt(dest, points2, delimiter="\t", newline="\n") + with pytest.raises(ValueError, match="must be of shape"): + with pytest.warns(RuntimeWarning, match="FastSCAN header"): + read_polhemus_fastscan(dest, on_header_missing="warn") def test_io_coord_frame(tmp_path): """Test round trip for coordinate frame.""" - fname = tmp_path / 'test.fif' - for ch_type in ('eeg', 'seeg', 'ecog', 'dbs', 'hbo', 'hbr'): - info = create_info( - ch_names=['Test Ch'], sfreq=1000., ch_types=[ch_type]) - info['chs'][0]['loc'][:3] = [0.05, 0.01, -0.03] + fname = tmp_path / "test.fif" + for ch_type in ("eeg", "seeg", "ecog", "dbs", "hbo", "hbr"): + info = create_info(ch_names=["Test Ch"], sfreq=1000.0, ch_types=[ch_type]) + info["chs"][0]["loc"][:3] = [0.05, 0.01, -0.03] write_info(fname, info) info2 = read_info(fname) - assert info2['chs'][0]['coord_frame'] == FIFF.FIFFV_COORD_HEAD + assert info2["chs"][0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD def test_make_dig_points(): """Test application of Polhemus HSP to info.""" - extra_points = read_polhemus_fastscan( - hsp_fname, on_header_missing='ignore') - info = create_info(ch_names=['Test Ch'], sfreq=1000.) - assert info['dig'] is None + extra_points = read_polhemus_fastscan(hsp_fname, on_header_missing="ignore") + info = create_info(ch_names=["Test Ch"], sfreq=1000.0) + assert info["dig"] is None with info._unlock(): - info['dig'] = _make_dig_points(extra_points=extra_points) - assert (info['dig']) - assert_allclose(info['dig'][0]['r'], [-.10693, .09980, .06881]) + info["dig"] = _make_dig_points(extra_points=extra_points) + assert info["dig"] + assert_allclose(info["dig"][0]["r"], [-0.10693, 0.09980, 0.06881]) - elp_points = read_polhemus_fastscan(elp_fname, on_header_missing='ignore') + elp_points = read_polhemus_fastscan(elp_fname, on_header_missing="ignore") nasion, lpa, rpa = elp_points[:3] - info = create_info(ch_names=['Test Ch'], sfreq=1000.) 
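# --- editorial aside (not part of the patch) --------------------------------
# A minimal sketch of the FIF round trip that test_io_coord_frame above checks,
# using a hypothetical output path: channel locations written to disk come
# back tagged with the head coordinate frame.
from mne import create_info
from mne.io import read_info, write_info
from mne.io.constants import FIFF

demo = create_info(ch_names=["Test Ch"], sfreq=1000.0, ch_types=["seeg"])
demo["chs"][0]["loc"][:3] = [0.05, 0.01, -0.03]
write_info("demo-info.fif", demo)  # hypothetical path
assert read_info("demo-info.fif")["chs"][0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD
# -----------------------------------------------------------------------------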
- assert info['dig'] is None + info = create_info(ch_names=["Test Ch"], sfreq=1000.0) + assert info["dig"] is None with info._unlock(): - info['dig'] = _make_dig_points(nasion, lpa, rpa, elp_points[3:], None) - assert (info['dig']) - idx = [d['ident'] for d in info['dig']].index(FIFF.FIFFV_POINT_NASION) - assert_allclose(info['dig'][idx]['r'], [.0013930, .0131613, -.0046967]) + info["dig"] = _make_dig_points(nasion, lpa, rpa, elp_points[3:], None) + assert info["dig"] + idx = [d["ident"] for d in info["dig"]].index(FIFF.FIFFV_POINT_NASION) + assert_allclose(info["dig"][idx]["r"], [0.0013930, 0.0131613, -0.0046967]) pytest.raises(ValueError, _make_dig_points, nasion[:2]) pytest.raises(ValueError, _make_dig_points, None, lpa[:2]) pytest.raises(ValueError, _make_dig_points, None, None, rpa[:2]) - pytest.raises(ValueError, _make_dig_points, None, None, None, - elp_points[:, :2]) - pytest.raises(ValueError, _make_dig_points, None, None, None, None, - elp_points[:, :2]) + pytest.raises(ValueError, _make_dig_points, None, None, None, elp_points[:, :2]) + pytest.raises( + ValueError, _make_dig_points, None, None, None, None, elp_points[:, :2] + ) def test_redundant(): """Test some of the redundant properties of info.""" # Indexing - info = create_info(ch_names=['a', 'b', 'c'], sfreq=1000.) - assert info['ch_names'][0] == 'a' - assert info['ch_names'][1] == 'b' - assert info['ch_names'][2] == 'c' + info = create_info(ch_names=["a", "b", "c"], sfreq=1000.0) + assert info["ch_names"][0] == "a" + assert info["ch_names"][1] == "b" + assert info["ch_names"][2] == "c" # Equality - assert info['ch_names'] == info['ch_names'] - assert info['ch_names'] == ['a', 'b', 'c'] + assert info["ch_names"] == info["ch_names"] + assert info["ch_names"] == ["a", "b", "c"] # No channels in info - info = create_info(ch_names=[], sfreq=1000.) - assert info['ch_names'] == [] + info = create_info(ch_names=[], sfreq=1000.0) + assert info["ch_names"] == [] # List should be read-only - info = create_info(ch_names=['a', 'b', 'c'], sfreq=1000.) + info = create_info(ch_names=["a", "b", "c"], sfreq=1000.0) def test_merge_info(): """Test merging of multiple Info objects.""" - info_a = create_info(ch_names=['a', 'b', 'c'], sfreq=1000.) - info_b = create_info(ch_names=['d', 'e', 'f'], sfreq=1000.) + info_a = create_info(ch_names=["a", "b", "c"], sfreq=1000.0) + info_b = create_info(ch_names=["d", "e", "f"], sfreq=1000.0) info_merged = _merge_info([info_a, info_b]) - assert info_merged['nchan'], 6 - assert info_merged['ch_names'], ['a', 'b', 'c', 'd', 'e', 'f'] + assert info_merged["nchan"], 6 + assert info_merged["ch_names"], ["a", "b", "c", "d", "e", "f"] pytest.raises(ValueError, _merge_info, [info_a, info_a]) # Testing for force updates before merging - info_c = create_info(ch_names=['g', 'h', 'i'], sfreq=500.) 
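# --- editorial aside (not part of the patch) --------------------------------
# A minimal sketch of what the surrounding test_merge_info hunk verifies, using
# the same private helpers the test imports: merging requires equal sampling
# rates, which _force_update_info can reconcile beforehand.
from mne import create_info
from mne.io.meas_info import _force_update_info, _merge_info

info_a = create_info(["a", "b", "c"], sfreq=1000.0)
info_c = create_info(["g", "h", "i"], sfreq=500.0)
_force_update_info(info_a, info_c)  # copies info_a's metadata (incl. sfreq) onto info_c
assert _merge_info([info_a, info_c])["nchan"] == 6
# -----------------------------------------------------------------------------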
+ info_c = create_info(ch_names=["g", "h", "i"], sfreq=500.0) # This will break because sfreq is not equal pytest.raises(RuntimeError, _merge_info, [info_a, info_c]) _force_update_info(info_a, info_c) - assert (info_c['sfreq'] == info_a['sfreq']) - assert (info_c['ch_names'][0] != info_a['ch_names'][0]) + assert info_c["sfreq"] == info_a["sfreq"] + assert info_c["ch_names"][0] != info_a["ch_names"][0] # Make sure it works now _merge_info([info_a, info_c]) # Check that you must supply Info - pytest.raises(ValueError, _force_update_info, info_a, - dict([('sfreq', 1000.)])) + pytest.raises(ValueError, _force_update_info, info_a, dict([("sfreq", 1000.0)])) # KIT System-ID info_a._unlocked = info_b._unlocked = True - info_a['kit_system_id'] = 50 - assert _merge_info((info_a, info_b))['kit_system_id'] == 50 - info_b['kit_system_id'] = 50 - assert _merge_info((info_a, info_b))['kit_system_id'] == 50 - info_b['kit_system_id'] = 60 + info_a["kit_system_id"] = 50 + assert _merge_info((info_a, info_b))["kit_system_id"] == 50 + info_b["kit_system_id"] = 50 + assert _merge_info((info_a, info_b))["kit_system_id"] == 50 + info_b["kit_system_id"] = 60 pytest.raises(ValueError, _merge_info, (info_a, info_b)) # hpi infos - info_d = create_info(ch_names=['d', 'e', 'f'], sfreq=1000.) + info_d = create_info(ch_names=["d", "e", "f"], sfreq=1000.0) info_merged = _merge_info([info_a, info_d]) - assert not info_merged['hpi_meas'] - assert not info_merged['hpi_results'] - info_a['hpi_meas'] = [{'f1': 3, 'f2': 4}] - assert _merge_info([info_a, info_d])['hpi_meas'] == info_a['hpi_meas'] + assert not info_merged["hpi_meas"] + assert not info_merged["hpi_results"] + info_a["hpi_meas"] = [{"f1": 3, "f2": 4}] + assert _merge_info([info_a, info_d])["hpi_meas"] == info_a["hpi_meas"] info_d._unlocked = True - info_d['hpi_meas'] = [{'f1': 3, 'f2': 4}] - assert _merge_info([info_a, info_d])['hpi_meas'] == info_d['hpi_meas'] + info_d["hpi_meas"] = [{"f1": 3, "f2": 4}] + assert _merge_info([info_a, info_d])["hpi_meas"] == info_d["hpi_meas"] # This will break because of inconsistency - info_d['hpi_meas'] = [{'f1': 3, 'f2': 5}] + info_d["hpi_meas"] = [{"f1": 3, "f2": 5}] pytest.raises(ValueError, _merge_info, [info_a, info_d]) info_0 = read_info(raw_fname) - info_0['bads'] = ['MEG 2443', 'EEG 053'] - assert len(info_0['chs']) == 376 - assert len(info_0['dig']) == 146 - info_1 = create_info(["STI YYY"], info_0['sfreq'], ['stim']) - assert info_1['bads'] == [] + info_0["bads"] = ["MEG 2443", "EEG 053"] + assert len(info_0["chs"]) == 376 + assert len(info_0["dig"]) == 146 + info_1 = create_info(["STI YYY"], info_0["sfreq"], ["stim"]) + assert info_1["bads"] == [] info_out = _merge_info([info_0, info_1], force_update_to_first=True) - assert len(info_out['chs']) == 377 - assert len(info_out['bads']) == 2 - assert len(info_out['dig']) == 146 - assert len(info_0['chs']) == 376 - assert len(info_0['bads']) == 2 - assert len(info_0['dig']) == 146 + assert len(info_out["chs"]) == 377 + assert len(info_out["bads"]) == 2 + assert len(info_out["dig"]) == 146 + assert len(info_0["chs"]) == 376 + assert len(info_0["bads"]) == 2 + assert len(info_0["dig"]) == 146 def test_check_consistency(): """Test consistency check of Info objects.""" - info = create_info(ch_names=['a', 'b', 'c'], sfreq=1000.) + info = create_info(ch_names=["a", "b", "c"], sfreq=1000.0) # This should pass info._check_consistency() # Info without any channels - info_empty = create_info(ch_names=[], sfreq=1000.) 
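# --- editorial aside (not part of the patch) --------------------------------
# A minimal sketch of the sanity rule this test_check_consistency hunk covers:
# names listed in info["bads"] must be actual channels of the Info object.
import pytest
from mne import create_info

chk = create_info(["a", "b", "c"], sfreq=1000.0)
chk["bads"] = ["b", "foo", "bar"]  # 'foo' and 'bar' do not exist
with pytest.raises(RuntimeError):
    chk._check_consistency()
# -----------------------------------------------------------------------------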
+ info_empty = create_info(ch_names=[], sfreq=1000.0) info_empty._check_consistency() # Bad channels that are not in the info object info2 = info.copy() - info2['bads'] = ['b', 'foo', 'bar'] + info2["bads"] = ["b", "foo", "bar"] pytest.raises(RuntimeError, info2._check_consistency) # Bad data types info2 = info.copy() with info2._unlock(): - info2['sfreq'] = 'foo' + info2["sfreq"] = "foo" pytest.raises(ValueError, info2._check_consistency) info2 = info.copy() with info2._unlock(): - info2['highpass'] = 'foo' + info2["highpass"] = "foo" pytest.raises(ValueError, info2._check_consistency) info2 = info.copy() with info2._unlock(): - info2['lowpass'] = 'foo' + info2["lowpass"] = "foo" pytest.raises(ValueError, info2._check_consistency) # Silent type conversion to float info2 = info.copy() with info2._unlock(check_after=True): - info2['sfreq'] = 1 - info2['highpass'] = 2 - info2['lowpass'] = 2 - assert (isinstance(info2['sfreq'], float)) - assert (isinstance(info2['highpass'], float)) - assert (isinstance(info2['lowpass'], float)) + info2["sfreq"] = 1 + info2["highpass"] = 2 + info2["lowpass"] = 2 + assert isinstance(info2["sfreq"], float) + assert isinstance(info2["highpass"], float) + assert isinstance(info2["lowpass"], float) # Duplicate channel names info2 = info.copy() with info2._unlock(): - info2['chs'][2]['ch_name'] = 'b' + info2["chs"][2]["ch_name"] = "b" pytest.raises(RuntimeError, info2._check_consistency) # Duplicates appended with running numbers - with pytest.warns(RuntimeWarning, match='Channel names are not'): - info3 = create_info(ch_names=['a', 'b', 'b', 'c', 'b'], sfreq=1000.) - assert_array_equal(info3['ch_names'], ['a', 'b-0', 'b-1', 'c', 'b-2']) + with pytest.warns(RuntimeWarning, match="Channel names are not"): + info3 = create_info(ch_names=["a", "b", "b", "c", "b"], sfreq=1000.0) + assert_array_equal(info3["ch_names"], ["a", "b-0", "b-1", "c", "b-2"]) # a few bad ones idx = 0 - ch = info['chs'][idx] - for key, bad, match in (('ch_name', 1., 'not a string'), - ('loc', np.zeros(15), '12 elements'), - ('cal', np.ones(1), 'float or int')): + ch = info["chs"][idx] + for key, bad, match in ( + ("ch_name", 1.0, "not a string"), + ("loc", np.zeros(15), "12 elements"), + ("cal", np.ones(1), "float or int"), + ): info._check_consistency() # okay old = ch[key] ch[key] = bad - if key == 'ch_name': - info['ch_names'][idx] = bad + if key == "ch_name": + info["ch_names"][idx] = bad with pytest.raises(TypeError, match=match): info._check_consistency() ch[key] = old - if key == 'ch_name': - info['ch_names'][idx] = old + if key == "ch_name": + info["ch_names"][idx] = old # bad channel entries info2 = info.copy() - info2['chs'][0]['foo'] = 'bar' - with pytest.raises(KeyError, match='key errantly present'): + info2["chs"][0]["foo"] = "bar" + with pytest.raises(KeyError, match="key errantly present"): info2._check_consistency() info2 = info.copy() - del info2['chs'][0]['loc'] - with pytest.raises(KeyError, match='key missing'): + del info2["chs"][0]["loc"] + with pytest.raises(KeyError, match="key missing"): info2._check_consistency() def _test_anonymize_info(base_info): """Test that sensitive information can be anonymized.""" - pytest.raises(TypeError, anonymize_info, 'foo') + pytest.raises(TypeError, anonymize_info, "foo") default_anon_dos = datetime(2000, 1, 1, 0, 0, 0, tzinfo=timezone.utc) default_str = "mne_anonymize" default_subject_id = 0 - default_desc = ("Anonymized using a time shift" + - " to preserve age at acquisition") + default_desc = "Anonymized using a time shift" + " to 
preserve age at acquisition" # Test no error for incomplete info info = base_info.copy() - info.pop('file_id') + info.pop("file_id") anonymize_info(info) # Fake some subject data meas_date = datetime(2010, 1, 1, 0, 0, 0, tzinfo=timezone.utc) with base_info._unlock(): - base_info['meas_date'] = meas_date - base_info['subject_info'] = dict(id=1, - his_id='foobar', - last_name='bar', - first_name='bar', - birthday=(1987, 4, 8), - sex=0, hand=1) + base_info["meas_date"] = meas_date + base_info["subject_info"] = dict( + id=1, + his_id="foobar", + last_name="bar", + first_name="bar", + birthday=(1987, 4, 8), + sex=0, + hand=1, + ) # generate expected info... # first expected result with no options. # will move DOS from 2010/1/1 to 2000/1/1 which is 3653 days. exp_info = base_info.copy() exp_info._unlocked = True - exp_info['description'] = default_desc - exp_info['experimenter'] = default_str - exp_info['proj_name'] = default_str - exp_info['proj_id'] = np.array([0]) - exp_info['subject_info']['first_name'] = default_str - exp_info['subject_info']['last_name'] = default_str - exp_info['subject_info']['id'] = default_subject_id - exp_info['subject_info']['his_id'] = str(default_subject_id) - exp_info['subject_info']['sex'] = 0 - del exp_info['subject_info']['hand'] # there's no "unknown" setting + exp_info["description"] = default_desc + exp_info["experimenter"] = default_str + exp_info["proj_name"] = default_str + exp_info["proj_id"] = np.array([0]) + exp_info["subject_info"]["first_name"] = default_str + exp_info["subject_info"]["last_name"] = default_str + exp_info["subject_info"]["id"] = default_subject_id + exp_info["subject_info"]["his_id"] = str(default_subject_id) + exp_info["subject_info"]["sex"] = 0 + del exp_info["subject_info"]["hand"] # there's no "unknown" setting # this bday is 3653 days different. the change in day is due to a # different number of leap days between 1987 and 1977 than between # 2010 and 2000. 
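# --- editorial aside (not part of the patch) --------------------------------
# The date arithmetic behind the expected values above: shifting the 2010-01-01
# acquisition date to the default anonymization date of 2000-01-01 spans 3653
# days -- ten 365-day years plus leap days in 2000, 2004, and 2008.
from datetime import datetime, timezone

meas = datetime(2010, 1, 1, tzinfo=timezone.utc)
anon = datetime(2000, 1, 1, tzinfo=timezone.utc)
assert (meas - anon).days == 3653
# -----------------------------------------------------------------------------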
- exp_info['subject_info']['birthday'] = (1977, 4, 7) - exp_info['meas_date'] = default_anon_dos + exp_info["subject_info"]["birthday"] = (1977, 4, 7) + exp_info["meas_date"] = default_anon_dos exp_info._unlocked = False # make copies @@ -586,37 +636,35 @@ def _test_anonymize_info(base_info): # adjust each expected outcome delta_t = timedelta(days=3653) - for key in ('file_id', 'meas_id'): + for key in ("file_id", "meas_id"): value = exp_info.get(key) if value is not None: - assert 'msecs' not in value - tmp = _add_timedelta_to_stamp( - (value['secs'], value['usecs']), -delta_t) - value['secs'] = tmp[0] - value['usecs'] = tmp[1] - value['machid'][:] = 0 + assert "msecs" not in value + tmp = _add_timedelta_to_stamp((value["secs"], value["usecs"]), -delta_t) + value["secs"] = tmp[0] + value["usecs"] = tmp[1] + value["machid"][:] = 0 # exp 2 tests the keep_his option exp_info_2 = exp_info.copy() with exp_info_2._unlock(): - exp_info_2['subject_info']['his_id'] = 'foobar' - exp_info_2['subject_info']['sex'] = 0 - exp_info_2['subject_info']['hand'] = 1 + exp_info_2["subject_info"]["his_id"] = "foobar" + exp_info_2["subject_info"]["sex"] = 0 + exp_info_2["subject_info"]["hand"] = 1 # exp 3 tests is a supplied daysback delta_t_2 = timedelta(days=43) with exp_info_3._unlock(): - exp_info_3['subject_info']['birthday'] = (1987, 2, 24) - exp_info_3['meas_date'] = meas_date - delta_t_2 - for key in ('file_id', 'meas_id'): + exp_info_3["subject_info"]["birthday"] = (1987, 2, 24) + exp_info_3["meas_date"] = meas_date - delta_t_2 + for key in ("file_id", "meas_id"): value = exp_info_3.get(key) if value is not None: - assert 'msecs' not in value - tmp = _add_timedelta_to_stamp( - (value['secs'], value['usecs']), -delta_t_2) - value['secs'] = tmp[0] - value['usecs'] = tmp[1] - value['machid'][:] = 0 + assert "msecs" not in value + tmp = _add_timedelta_to_stamp((value["secs"], value["usecs"]), -delta_t_2) + value["secs"] = tmp[0] + value["usecs"] = tmp[1] + value["machid"][:] = 0 # exp 4 tests is a supplied daysback delta_t_3 = timedelta(days=223 + 364 * 500) @@ -630,26 +678,25 @@ def _test_anonymize_info(base_info): new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) assert_object_equal(new_info, exp_info_3) - with pytest.raises(RuntimeError, match='anonymize_info generated'): + with pytest.raises(RuntimeError, match="anonymize_info generated"): anonymize_info(base_info.copy(), daysback=delta_t_3.days) # assert_object_equal(new_info, exp_info_4) # test with meas_date = None with base_info._unlock(): - base_info['meas_date'] = None + base_info["meas_date"] = None exp_info_3._unlocked = True - exp_info_3['meas_date'] = None - exp_info_3['file_id']['secs'] = DATE_NONE[0] - exp_info_3['file_id']['usecs'] = DATE_NONE[1] - exp_info_3['meas_id']['secs'] = DATE_NONE[0] - exp_info_3['meas_id']['usecs'] = DATE_NONE[1] - exp_info_3['subject_info'].pop('birthday', None) + exp_info_3["meas_date"] = None + exp_info_3["file_id"]["secs"] = DATE_NONE[0] + exp_info_3["file_id"]["usecs"] = DATE_NONE[1] + exp_info_3["meas_id"]["secs"] = DATE_NONE[0] + exp_info_3["meas_id"]["usecs"] = DATE_NONE[1] + exp_info_3["subject_info"].pop("birthday", None) exp_info_3._unlocked = False - if base_info['meas_date'] is None: - with pytest.warns(RuntimeWarning, match='all information'): - new_info = anonymize_info(base_info.copy(), - daysback=delta_t_2.days) + if base_info["meas_date"] is None: + with pytest.warns(RuntimeWarning, match="all information"): + new_info = anonymize_info(base_info.copy(), 
daysback=delta_t_2.days) else: new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) assert_object_equal(new_info, exp_info_3) @@ -659,13 +706,16 @@ def _test_anonymize_info(base_info): assert_object_equal(new_info, exp_info_3) -@pytest.mark.parametrize('stamp, dt', [ - [(1346981585, 835782), (2012, 9, 7, 1, 33, 5, 835782)], - # test old dates for BIDS anonymization - [(-1533443343, 24382), (1921, 5, 29, 19, 30, 57, 24382)], - # gh-7116 - [(-908196946, 988669), (1941, 3, 22, 11, 4, 14, 988669)], -]) +@pytest.mark.parametrize( + "stamp, dt", + [ + [(1346981585, 835782), (2012, 9, 7, 1, 33, 5, 835782)], + # test old dates for BIDS anonymization + [(-1533443343, 24382), (1921, 5, 29, 19, 30, 57, 24382)], + # gh-7116 + [(-908196946, 988669), (1941, 3, 22, 11, 4, 14, 988669)], + ], +) def test_meas_date_convert(stamp, dt): """Test conversions of meas_date to datetime objects.""" meas_datetime = _stamp_to_dt(stamp) @@ -673,22 +723,21 @@ def test_meas_date_convert(stamp, dt): assert stamp == stamp2 assert meas_datetime == datetime(*dt, tzinfo=timezone.utc) # smoke test for info __repr__ - info = create_info(1, 1000., 'eeg') + info = create_info(1, 1000.0, "eeg") with info._unlock(): - info['meas_date'] = meas_datetime + info["meas_date"] = meas_datetime assert str(dt[0]) in repr(info) def test_anonymize(tmp_path): """Test that sensitive information can be anonymized.""" - pytest.raises(TypeError, anonymize_info, 'foo') + pytest.raises(TypeError, anonymize_info, "foo") # Fake some subject data raw = read_raw_fif(raw_fname) - raw.set_annotations(Annotations(onset=[0, 1], - duration=[1, 1], - description='dummy', - orig_time=None)) + raw.set_annotations( + Annotations(onset=[0, 1], duration=[1, 1], description="dummy", orig_time=None) + ) first_samp = raw.first_samp expected_onset = np.arange(2) + raw._first_time assert raw.first_samp == first_samp @@ -696,7 +745,7 @@ def test_anonymize(tmp_path): # test mne.anonymize_info() events = read_events(event_name) - epochs = Epochs(raw, events[:1], 2, 0., 0.1, baseline=None) + epochs = Epochs(raw, events[:1], 2, 0.0, 0.1, baseline=None) _test_anonymize_info(raw.info.copy()) _test_anonymize_info(epochs.info.copy()) @@ -704,24 +753,24 @@ def test_anonymize(tmp_path): for inst, keep_his in zip((raw, epochs), (True, False)): inst = inst.copy() - subject_info = dict(his_id='Volunteer', sex=2, hand=1) - inst.info['subject_info'] = subject_info + subject_info = dict(his_id="Volunteer", sex=2, hand=1) + inst.info["subject_info"] = subject_info inst.anonymize(keep_his=keep_his) - si = inst.info['subject_info'] + si = inst.info["subject_info"] if keep_his: assert si == subject_info else: - assert si['his_id'] == '0' - assert si['sex'] == 0 - assert 'hand' not in si + assert si["his_id"] == "0" + assert si["sex"] == 0 + assert "hand" not in si # write to disk & read back - inst_type = 'raw' if isinstance(inst, BaseRaw) else 'epo' - fname = 'tmp_raw.fif' if inst_type == 'raw' else 'tmp_epo.fif' + inst_type = "raw" if isinstance(inst, BaseRaw) else "epo" + fname = "tmp_raw.fif" if inst_type == "raw" else "tmp_epo.fif" out_path = tmp_path / fname inst.save(out_path, overwrite=True) - if inst_type == 'raw': + if inst_type == "raw": read_raw_fif(out_path) else: read_epochs(out_path) @@ -730,14 +779,14 @@ def test_anonymize(tmp_path): raw.anonymize() assert raw.first_samp == first_samp assert_allclose(raw.annotations.onset, expected_onset) - assert raw.annotations.orig_time == raw.info['meas_date'] - stamp = _dt_to_stamp(raw.info['meas_date']) + assert 
raw.annotations.orig_time == raw.info["meas_date"] + stamp = _dt_to_stamp(raw.info["meas_date"]) assert raw.annotations.orig_time == _stamp_to_dt(stamp) with raw.info._unlock(): - raw.info['meas_date'] = None + raw.info["meas_date"] = None raw.anonymize(daysback=None) - with pytest.warns(RuntimeWarning, match='None'): + with pytest.warns(RuntimeWarning, match="None"): raw.anonymize(daysback=123) assert raw.annotations.orig_time is None assert raw.first_samp == first_samp @@ -748,12 +797,12 @@ def test_anonymize_with_io(tmp_path): """Test that IO does not break anonymization.""" raw = read_raw_fif(raw_fname) - temp_path = tmp_path / 'tmp_raw.fif' + temp_path = tmp_path / "tmp_raw.fif" raw.save(temp_path) raw2 = read_raw_fif(temp_path) - daysback = (raw2.info['meas_date'].date() - date(1924, 1, 1)).days + daysback = (raw2.info["meas_date"].date() - date(1924, 1, 1)).days raw2.anonymize(daysback=daysback) @@ -762,25 +811,25 @@ def test_csr_csc(tmp_path): """Test CSR and CSC.""" info = read_info(sss_ctc_fname) info = pick_info(info, pick_types(info, meg=True, exclude=[])) - sss_ctc = info['proc_history'][0]['max_info']['sss_ctc'] - ct = sss_ctc['decoupler'].copy() + sss_ctc = info["proc_history"][0]["max_info"]["sss_ctc"] + ct = sss_ctc["decoupler"].copy() # CSC assert isinstance(ct, sparse.csc_matrix) - fname = tmp_path / 'test.fif' + fname = tmp_path / "test.fif" write_info(fname, info) info_read = read_info(fname) - ct_read = info_read['proc_history'][0]['max_info']['sss_ctc']['decoupler'] + ct_read = info_read["proc_history"][0]["max_info"]["sss_ctc"]["decoupler"] assert isinstance(ct_read, sparse.csc_matrix) assert_array_equal(ct_read.toarray(), ct.toarray()) # Now CSR csr = ct.tocsr() assert isinstance(csr, sparse.csr_matrix) assert_array_equal(csr.toarray(), ct.toarray()) - info['proc_history'][0]['max_info']['sss_ctc']['decoupler'] = csr - fname = tmp_path / 'test1.fif' + info["proc_history"][0]["max_info"]["sss_ctc"]["decoupler"] = csr + fname = tmp_path / "test1.fif" write_info(fname, info) info_read = read_info(fname) - ct_read = info_read['proc_history'][0]['max_info']['sss_ctc']['decoupler'] + ct_read = info_read["proc_history"][0]["max_info"]["sss_ctc"]["decoupler"] assert isinstance(ct_read, sparse.csc_matrix) # this gets cast to CSC assert_array_equal(ct_read.toarray(), ct.toarray()) @@ -791,8 +840,8 @@ def test_check_compensation_consistency(): raw = read_raw_ctf(ctf_fname, preload=False) events = make_fixed_length_events(raw, 99999) picks = pick_types(raw.info, meg=True, exclude=[], ref_meg=True) - pick_ch_names = [raw.info['ch_names'][idx] for idx in picks] - for (comp, expected_result) in zip([0, 1], [False, False]): + pick_ch_names = [raw.info["ch_names"][idx] for idx in picks] + for comp, expected_result in zip([0, 1], [False, False]): raw.apply_gradient_compensation(comp) ret, missing = _bad_chans_comp(raw.info, pick_ch_names) assert ret == expected_result @@ -800,31 +849,31 @@ def test_check_compensation_consistency(): Epochs(raw, events, None, -0.2, 0.2, preload=False, picks=picks) picks = pick_types(raw.info, meg=True, exclude=[], ref_meg=False) - pick_ch_names = [raw.info['ch_names'][idx] for idx in picks] + pick_ch_names = [raw.info["ch_names"][idx] for idx in picks] - for (comp, expected_result) in zip([0, 1], [False, True]): + for comp, expected_result in zip([0, 1], [False, True]): raw.apply_gradient_compensation(comp) ret, missing = _bad_chans_comp(raw.info, pick_ch_names) assert ret == expected_result assert len(missing) == 17 with catch_logging() as log: - 
Epochs(raw, events, None, -0.2, 0.2, preload=False, - picks=picks, verbose=True) - assert 'Removing 5 compensators' in log.getvalue() + Epochs( + raw, events, None, -0.2, 0.2, preload=False, picks=picks, verbose=True + ) + assert "Removing 5 compensators" in log.getvalue() def test_field_round_trip(tmp_path): """Test round-trip for new fields.""" - info = create_info(1, 1000., 'eeg') + info = create_info(1, 1000.0, "eeg") with info._unlock(): - for key in ('file_id', 'meas_id'): + for key in ("file_id", "meas_id"): info[key] = _generate_meas_id() - info['device_info'] = dict( - type='a', model='b', serial='c', site='d') - info['helium_info'] = dict( - he_level_raw=1., helium_level=2., - orig_file_guid='e', meas_date=(1, 2)) - fname = tmp_path / 'temp-info.fif' + info["device_info"] = dict(type="a", model="b", serial="c", site="d") + info["helium_info"] = dict( + he_level_raw=1.0, helium_level=2.0, orig_file_guid="e", meas_date=(1, 2) + ) + fname = tmp_path / "temp-info.fif" write_info(fname, info) info_read = read_info(fname) assert_object_equal(info, info_read) @@ -832,55 +881,63 @@ def test_field_round_trip(tmp_path): def test_equalize_channels(): """Test equalization of channels for instances of Info.""" - info1 = create_info(['CH1', 'CH2', 'CH3'], sfreq=1.) - info2 = create_info(['CH4', 'CH2', 'CH1'], sfreq=1.) + info1 = create_info(["CH1", "CH2", "CH3"], sfreq=1.0) + info2 = create_info(["CH4", "CH2", "CH1"], sfreq=1.0) info1, info2 = equalize_channels([info1, info2]) - assert info1.ch_names == ['CH1', 'CH2'] - assert info2.ch_names == ['CH1', 'CH2'] + assert info1.ch_names == ["CH1", "CH2"] + assert info2.ch_names == ["CH1", "CH2"] def test_repr(): """Test Info repr.""" - info = create_info(1, 1000, 'eeg') - assert '7 non-empty values' in repr(info) + info = create_info(1, 1000, "eeg") + assert "7 non-empty values" in repr(info) - t = Transform('meg', 'head', np.ones((4, 4))) - info['dev_head_t'] = t - assert 'dev_head_t: MEG device -> head transform' in repr(info) + t = Transform("meg", "head", np.ones((4, 4))) + info["dev_head_t"] = t + assert "dev_head_t: MEG device -> head transform" in repr(info) def test_repr_html(): """Test Info HTML repr.""" info = read_info(raw_fname) - assert 'Projections' in info._repr_html_() + assert "Projections" in info._repr_html_() with info._unlock(): - info['projs'] = [] - assert 'Projections' not in info._repr_html_() - info['bads'] = [] - assert 'None' in info._repr_html_() - info['bads'] = ['MEG 2443', 'EEG 053'] - assert 'MEG 2443' in info._repr_html_() - assert 'EEG 053' in info._repr_html_() + info["projs"] = [] + assert "Projections" not in info._repr_html_() + info["bads"] = [] + assert "None" in info._repr_html_() + info["bads"] = ["MEG 2443", "EEG 053"] + assert "MEG 2443" in info._repr_html_() + assert "EEG 053" in info._repr_html_() html = info._repr_html_() - for ch in ['204 Gradiometers', '102 Magnetometers', '9 Stimulus', - '60 EEG', '1 EOG']: + for ch in [ + "204 Gradiometers", + "102 Magnetometers", + "9 Stimulus", + "60 EEG", + "1 EOG", + ]: assert ch in html @testing.requires_testing_data def test_invalid_subject_birthday(): """Test handling of an invalid birthday in the raw file.""" - with pytest.warns(RuntimeWarning, match='No birthday will be set'): + with pytest.warns(RuntimeWarning, match="No birthday will be set"): raw = read_raw_fif(raw_invalid_bday_fname) - assert 'birthday' not in raw.info['subject_info'] + assert "birthday" not in raw.info["subject_info"] -@pytest.mark.parametrize('fname', [ - pytest.param(ctf_fname, 
marks=testing._pytest_mark()),
-    raw_fname,
-])
+@pytest.mark.parametrize(
+    "fname",
+    [
+        pytest.param(ctf_fname, marks=testing._pytest_mark()),
+        raw_fname,
+    ],
+)
 def test_channel_name_limit(tmp_path, monkeypatch, fname):
     """Test that our remapping works properly."""
     #
@@ -894,142 +951,145 @@ def test_channel_name_limit(tmp_path, monkeypatch, fname):
     else:
         assert fname.suffix == ".ds"
         raw = read_raw_ctf(fname)
-    ref_names = [raw.ch_names[pick]
-                 for pick in pick_types(raw.info, meg=False, ref_meg=True)]
+    ref_names = [
+        raw.ch_names[pick] for pick in pick_types(raw.info, meg=False, ref_meg=True)
+    ]
     data_names = raw.ch_names[32:35]
-    proj = dict(data=np.ones((1, len(data_names))),
-                col_names=data_names[:2].copy(), row_names=None, nrow=1)
-    proj = Projection(
-        data=proj, active=False, desc='test', kind=0, explained_var=0.)
+    proj = dict(
+        data=np.ones((1, len(data_names))),
+        col_names=data_names[:2].copy(),
+        row_names=None,
+        nrow=1,
+    )
+    proj = Projection(data=proj, active=False, desc="test", kind=0, explained_var=0.0)
     raw.add_proj(proj, remove_existing=True)
     raw.info.normalize_proj()
     raw.pick_channels(data_names + ref_names, ordered=False).crop(0, 2)
-    long_names = ['123456789abcdefg' + name for name in raw.ch_names]
-    fname = tmp_path / 'test-raw.fif'
+    long_names = ["123456789abcdefg" + name for name in raw.ch_names]
+    fname = tmp_path / "test-raw.fif"
     with catch_logging() as log:
         raw.save(fname)
     log = log.getvalue()
-    assert 'truncated' not in log
+    assert "truncated" not in log
     rename = dict(zip(raw.ch_names, long_names))
     long_data_names = [rename[name] for name in data_names]
     long_proj_names = long_data_names[:2]
     raw.rename_channels(rename)
-    for comp in raw.info['comps']:
-        for key in ('row_names', 'col_names'):
-            for name in comp['data'][key]:
+    for comp in raw.info["comps"]:
+        for key in ("row_names", "col_names"):
+            for name in comp["data"][key]:
                 assert name in raw.ch_names
-    if raw.info['comps']:
+    if raw.info["comps"]:
         assert raw.compensation_grade == 0
         raw.apply_gradient_compensation(3)
         assert raw.compensation_grade == 3
-    assert len(raw.info['projs']) == 1
-    assert raw.info['projs'][0]['data']['col_names'] == long_proj_names
-    raw.info['bads'] = bads = long_data_names[2:3]
-    good_long_data_names = [
-        name for name in long_data_names if name not in bads]
+    assert len(raw.info["projs"]) == 1
+    assert raw.info["projs"][0]["data"]["col_names"] == long_proj_names
+    raw.info["bads"] = bads = long_data_names[2:3]
+    good_long_data_names = [name for name in long_data_names if name not in bads]
    with catch_logging() as log:
         raw.save(fname, overwrite=True, verbose=True)
     log = log.getvalue()
-    assert 'truncated to 15' in log
+    assert "truncated to 15" in log
     for name in raw.ch_names:
         assert len(name) > 15
     # first read the full way
     with catch_logging() as log:
         raw_read = read_raw_fif(fname, verbose=True)
     log = log.getvalue()
-    assert 'Reading extended channel information' in log
+    assert "Reading extended channel information" in log
     for ra in (raw, raw_read):
         assert ra.ch_names == long_names
-    assert raw_read.info['projs'][0]['data']['col_names'] == long_proj_names
+    assert raw_read.info["projs"][0]["data"]["col_names"] == long_proj_names
     del raw_read
     # next read as if no longer names could be read
-    monkeypatch.setattr(
-        meas_info, '_read_extended_ch_info', lambda x, y, z: None)
+    monkeypatch.setattr(meas_info, "_read_extended_ch_info", lambda x, y, z: None)
     with catch_logging() as log:
         raw_read = read_raw_fif(fname, verbose=True)
     log = log.getvalue()
-    assert 
'extended' not in log - if raw.info['comps']: + assert "extended" not in log + if raw.info["comps"]: assert raw_read.compensation_grade == 3 raw_read.apply_gradient_compensation(0) assert raw_read.compensation_grade == 0 monkeypatch.setattr( # restore - meas_info, '_read_extended_ch_info', _read_extended_ch_info) + meas_info, "_read_extended_ch_info", _read_extended_ch_info + ) short_proj_names = [ - f'{name[:13 - bool(len(ref_names))]}-{len(ref_names) + ni}' - for ni, name in enumerate(long_data_names[:2])] - assert raw_read.info['projs'][0]['data']['col_names'] == short_proj_names + f"{name[:13 - bool(len(ref_names))]}-{len(ref_names) + ni}" + for ni, name in enumerate(long_data_names[:2]) + ] + assert raw_read.info["projs"][0]["data"]["col_names"] == short_proj_names # # epochs # epochs = Epochs(raw, make_fixed_length_events(raw)) - fname = tmp_path / 'test-epo.fif' + fname = tmp_path / "test-epo.fif" epochs.save(fname) epochs_read = read_epochs(fname) for ep in (epochs, epochs_read): - assert ep.info['ch_names'] == long_names + assert ep.info["ch_names"] == long_names assert ep.ch_names == long_names del raw, epochs_read # cov - epochs.info['bads'] = [] - cov = compute_covariance(epochs, verbose='error') - fname = tmp_path / 'test-cov.fif' + epochs.info["bads"] = [] + cov = compute_covariance(epochs, verbose="error") + fname = tmp_path / "test-cov.fif" write_cov(fname, cov) cov_read = read_cov(fname) for co in (cov, cov_read): - assert co['names'] == long_data_names - assert co['bads'] == [] + assert co["names"] == long_data_names + assert co["bads"] == [] del cov_read # # evoked # evoked = epochs.average() - evoked.info['bads'] = bads + evoked.info["bads"] = bads assert evoked.nave == 1 - fname = tmp_path / 'test-ave.fif' + fname = tmp_path / "test-ave.fif" evoked.save(fname) evoked_read = read_evokeds(fname)[0] for ev in (evoked, evoked_read): assert ev.ch_names == long_names - assert ev.info['bads'] == bads + assert ev.info["bads"] == bads del evoked_read, epochs # # forward # with _record_warnings(): # not enough points for CTF - sphere = make_sphere_model('auto', 'auto', evoked.info) - src = setup_volume_source_space( - pos=dict(rr=[[0, 0, 0.04]], nn=[[0, 1., 0.]])) + sphere = make_sphere_model("auto", "auto", evoked.info) + src = setup_volume_source_space(pos=dict(rr=[[0, 0, 0.04]], nn=[[0, 1.0, 0.0]])) fwd = make_forward_solution(evoked.info, None, src, sphere) - fname = tmp_path / 'temp-fwd.fif' + fname = tmp_path / "temp-fwd.fif" write_forward_solution(fname, fwd) fwd_read = read_forward_solution(fname) for fw in (fwd, fwd_read): - assert fw['sol']['row_names'] == long_data_names - assert fw['info']['ch_names'] == long_data_names - assert fw['info']['bads'] == bads + assert fw["sol"]["row_names"] == long_data_names + assert fw["info"]["ch_names"] == long_data_names + assert fw["info"]["bads"] == bads del fwd_read # # inv # inv = make_inverse_operator(evoked.info, fwd, cov) - fname = tmp_path / 'test-inv.fif' + fname = tmp_path / "test-inv.fif" write_inverse_operator(fname, inv) inv_read = read_inverse_operator(fname) for iv in (inv, inv_read): - assert iv['info']['ch_names'] == good_long_data_names + assert iv["info"]["ch_names"] == good_long_data_names apply_inverse(evoked, inv) # smoke test -@pytest.mark.parametrize('fname_info', (raw_fname, 'create_info')) -@pytest.mark.parametrize('unlocked', (True, False)) +@pytest.mark.parametrize("fname_info", (raw_fname, "create_info")) +@pytest.mark.parametrize("unlocked", (True, False)) def test_pickle(fname_info, unlocked): """Test 
that Info can be (un)pickled.""" - if fname_info == 'create_info': - info = create_info(3, 1000., 'eeg') + if fname_info == "create_info": + info = create_info(3, 1000.0, "eeg") else: info = read_info(fname_info) assert not info._unlocked @@ -1043,41 +1103,39 @@ def test_pickle(fname_info, unlocked): def test_info_bad(): """Test our info sanity checkers.""" - info = create_info(2, 1000., 'eeg') - info['description'] = 'foo' - info['experimenter'] = 'bar' - info['line_freq'] = 50. - info['bads'] = info['ch_names'][:1] - info['temp'] = ('whatever', 1.) + info = create_info(2, 1000.0, "eeg") + info["description"] = "foo" + info["experimenter"] = "bar" + info["line_freq"] = 50.0 + info["bads"] = info["ch_names"][:1] + info["temp"] = ("whatever", 1.0) # After 0.24 these should be pytest.raises calls check, klass = pytest.raises, RuntimeError with check(klass, match=r"info\['temp'\]"): - info['bad_key'] = 1. - for (key, match) in ([ - ('sfreq', r'inst\.resample'), - ('chs', r'inst\.add_channels')]): + info["bad_key"] = 1.0 + for key, match in [("sfreq", r"inst\.resample"), ("chs", r"inst\.add_channels")]: with check(klass, match=match): info[key] = info[key] - with pytest.raises(ValueError, match='between meg<->head'): - info['dev_head_t'] = Transform('mri', 'head', np.eye(4)) + with pytest.raises(ValueError, match="between meg<->head"): + info["dev_head_t"] = Transform("mri", "head", np.eye(4)) def test_get_montage(): """Test ContainsMixin.get_montage().""" - ch_names = make_standard_montage('standard_1020').ch_names + ch_names = make_standard_montage("standard_1020").ch_names sfreq = 512 data = np.zeros((len(ch_names), sfreq * 2)) - raw = RawArray(data, create_info(ch_names, sfreq, 'eeg')) - raw.set_montage('standard_1020') + raw = RawArray(data, create_info(ch_names, sfreq, "eeg")) + raw.set_montage("standard_1020") assert len(raw.get_montage().ch_names) == len(ch_names) - raw.info['bads'] = [ch_names[0]] + raw.info["bads"] = [ch_names[0]] assert len(raw.get_montage().ch_names) == len(ch_names) # test info - raw = RawArray(data, create_info(ch_names, sfreq, 'eeg')) - raw.set_montage('standard_1020') + raw = RawArray(data, create_info(ch_names, sfreq, "eeg")) + raw.set_montage("standard_1020") assert len(raw.info.get_montage().ch_names) == len(ch_names) - raw.info['bads'] = [ch_names[0]] + raw.info["bads"] = [ch_names[0]] assert len(raw.info.get_montage().ch_names) == len(ch_names) diff --git a/mne/io/tests/test_pick.py b/mne/io/tests/test_pick.py index 1632455b50e..da1a9c97b2c 100644 --- a/mne/io/tests/test_pick.py +++ b/mne/io/tests/test_pick.py @@ -5,26 +5,38 @@ import numpy as np from numpy.testing import assert_array_equal -from mne import (pick_channels_regexp, pick_types, Epochs, - read_forward_solution, rename_channels, - pick_info, pick_channels, create_info, make_ad_hoc_cov) -from mne.io import (read_raw_fif, RawArray, read_raw_bti, read_raw_kit, - read_info) +from mne import ( + pick_channels_regexp, + pick_types, + Epochs, + read_forward_solution, + rename_channels, + pick_info, + pick_channels, + create_info, + make_ad_hoc_cov, +) +from mne.io import read_raw_fif, RawArray, read_raw_bti, read_raw_kit, read_info from mne.channels import make_standard_montage from mne.preprocessing import compute_current_source_density -from mne.io.pick import (channel_indices_by_type, channel_type, - pick_types_forward, _picks_by_type, _picks_to_idx, - _contains_ch_type, pick_channels_cov, - _get_channel_types, get_channel_type_constants, - _DATA_CH_TYPES_SPLIT) +from mne.io.pick import ( + 
channel_indices_by_type, + channel_type, + pick_types_forward, + _picks_by_type, + _picks_to_idx, + _contains_ch_type, + pick_channels_cov, + _get_channel_types, + get_channel_type_constants, + _DATA_CH_TYPES_SPLIT, +) from mne.io.constants import FIFF from mne.datasets import testing from mne.utils import catch_logging, assert_object_equal data_path = testing.data_path(download=False) -fname_meeg = ( - data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" -) +fname_meeg = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" fname_mc = data_path / "SSS" / "test_move_anon_movecomp_raw_sss.fif" io_dir = Path(__file__).parent.parent @@ -32,47 +44,82 @@ fif_fname = io_dir / "tests" / "data" / "test_raw.fif" -def _picks_by_type_old(info, meg_combined=False, ref_meg=False, - exclude='bads'): +def _picks_by_type_old(info, meg_combined=False, ref_meg=False, exclude="bads"): """Use the old, slower _picks_by_type code.""" picks_list = [] has = [_contains_ch_type(info, k) for k in _DATA_CH_TYPES_SPLIT] has = dict(zip(_DATA_CH_TYPES_SPLIT, has)) - if has['mag'] and (meg_combined is not True or not has['grad']): + if has["mag"] and (meg_combined is not True or not has["grad"]): picks_list.append( - ('mag', pick_types(info, meg='mag', eeg=False, stim=False, - ref_meg=ref_meg, exclude=exclude)) + ( + "mag", + pick_types( + info, + meg="mag", + eeg=False, + stim=False, + ref_meg=ref_meg, + exclude=exclude, + ), + ) ) - if has['grad'] and (meg_combined is not True or not has['mag']): + if has["grad"] and (meg_combined is not True or not has["mag"]): picks_list.append( - ('grad', pick_types(info, meg='grad', eeg=False, stim=False, - ref_meg=ref_meg, exclude=exclude)) + ( + "grad", + pick_types( + info, + meg="grad", + eeg=False, + stim=False, + ref_meg=ref_meg, + exclude=exclude, + ), + ) ) - if has['mag'] and has['grad'] and meg_combined is True: + if has["mag"] and has["grad"] and meg_combined is True: picks_list.append( - ('meg', pick_types(info, meg=True, eeg=False, stim=False, - ref_meg=ref_meg, exclude=exclude)) + ( + "meg", + pick_types( + info, + meg=True, + eeg=False, + stim=False, + ref_meg=ref_meg, + exclude=exclude, + ), + ) ) for ch_type in _DATA_CH_TYPES_SPLIT: - if ch_type in ['grad', 'mag']: # exclude just MEG channels + if ch_type in ["grad", "mag"]: # exclude just MEG channels continue if has[ch_type]: picks_list.append( - (ch_type, pick_types(info, meg=False, stim=False, - ref_meg=ref_meg, exclude=exclude, - **{ch_type: True})) + ( + ch_type, + pick_types( + info, + meg=False, + stim=False, + ref_meg=ref_meg, + exclude=exclude, + **{ch_type: True}, + ), + ) ) return picks_list def _channel_type_old(info, idx): """Get channel type using old, slower scheme.""" - ch = info['chs'][idx] + ch = info["chs"][idx] # iterate through all defined channel types until we find a match with ch # go in order from most specific (most rules entries) to least specific - channel_types = sorted(get_channel_type_constants().items(), - key=lambda x: len(x[1]), reverse=True) + channel_types = sorted( + get_channel_type_constants().items(), key=lambda x: len(x[1]), reverse=True + ) for t, rules in channel_types: for key, vals in rules.items(): # all keys must match the values if ch.get(key, None) not in np.array(vals): @@ -84,7 +131,7 @@ def _channel_type_old(info, idx): def _assert_channel_types(info): - for k in range(info['nchan']): + for k in range(info["nchan"]): a, b = channel_type(info, k), _channel_type_old(info, k) assert a == b @@ -98,9 +145,7 @@ def 
test_pick_refs(): mrk_path = kit_dir / "test_mrk.sqd" elp_path = kit_dir / "test_elp.txt" hsp_path = kit_dir / "test_hsp.txt" - raw_kit = read_raw_kit( - sqd_path, str(mrk_path), str(elp_path), str(hsp_path) - ) + raw_kit = read_raw_kit(sqd_path, str(mrk_path), str(elp_path), str(hsp_path)) infos.append(raw_kit.info) # BTi bti_dir = io_dir / "bti" / "tests" / "data" @@ -114,52 +159,64 @@ def test_pick_refs(): raw_ctf = read_raw_fif(fname_ctf_raw) raw_ctf.apply_gradient_compensation(2) for info in infos: - info['bads'] = [] + info["bads"] = [] _assert_channel_types(info) with pytest.raises(ValueError, match="'planar2'] or bool, not foo"): - pick_types(info, meg='foo') + pick_types(info, meg="foo") with pytest.raises(ValueError, match="'planar2', 'auto'] or bool,"): - pick_types(info, ref_meg='foo') + pick_types(info, ref_meg="foo") picks_meg_ref = pick_types(info, meg=True, ref_meg=True) picks_meg = pick_types(info, meg=True, ref_meg=False) picks_ref = pick_types(info, meg=False, ref_meg=True) - assert_array_equal(picks_meg_ref, - np.sort(np.concatenate([picks_meg, picks_ref]))) - picks_grad = pick_types(info, meg='grad', ref_meg=False) - picks_ref_grad = pick_types(info, meg=False, ref_meg='grad') - picks_meg_ref_grad = pick_types(info, meg='grad', ref_meg='grad') - assert_array_equal(picks_meg_ref_grad, - np.sort(np.concatenate([picks_grad, - picks_ref_grad]))) - picks_mag = pick_types(info, meg='mag', ref_meg=False) - picks_ref_mag = pick_types(info, meg=False, ref_meg='mag') - picks_meg_ref_mag = pick_types(info, meg='mag', ref_meg='mag') - assert_array_equal(picks_meg_ref_mag, - np.sort(np.concatenate([picks_mag, - picks_ref_mag]))) - assert_array_equal(picks_meg, - np.sort(np.concatenate([picks_mag, picks_grad]))) - assert_array_equal(picks_ref, - np.sort(np.concatenate([picks_ref_mag, - picks_ref_grad]))) - assert_array_equal(picks_meg_ref, np.sort(np.concatenate( - [picks_grad, picks_mag, picks_ref_grad, picks_ref_mag]))) - - for pick in (picks_meg_ref, picks_meg, picks_ref, - picks_grad, picks_ref_grad, picks_meg_ref_grad, - picks_mag, picks_ref_mag, picks_meg_ref_mag): + assert_array_equal( + picks_meg_ref, np.sort(np.concatenate([picks_meg, picks_ref])) + ) + picks_grad = pick_types(info, meg="grad", ref_meg=False) + picks_ref_grad = pick_types(info, meg=False, ref_meg="grad") + picks_meg_ref_grad = pick_types(info, meg="grad", ref_meg="grad") + assert_array_equal( + picks_meg_ref_grad, np.sort(np.concatenate([picks_grad, picks_ref_grad])) + ) + picks_mag = pick_types(info, meg="mag", ref_meg=False) + picks_ref_mag = pick_types(info, meg=False, ref_meg="mag") + picks_meg_ref_mag = pick_types(info, meg="mag", ref_meg="mag") + assert_array_equal( + picks_meg_ref_mag, np.sort(np.concatenate([picks_mag, picks_ref_mag])) + ) + assert_array_equal(picks_meg, np.sort(np.concatenate([picks_mag, picks_grad]))) + assert_array_equal( + picks_ref, np.sort(np.concatenate([picks_ref_mag, picks_ref_grad])) + ) + assert_array_equal( + picks_meg_ref, + np.sort( + np.concatenate([picks_grad, picks_mag, picks_ref_grad, picks_ref_mag]) + ), + ) + + for pick in ( + picks_meg_ref, + picks_meg, + picks_ref, + picks_grad, + picks_ref_grad, + picks_meg_ref_grad, + picks_mag, + picks_ref_mag, + picks_meg_ref_mag, + ): if len(pick) > 0: pick_info(info, pick) # test CTF expected failures directly info = raw_ctf.info - info['bads'] = [] + info["bads"] = [] picks_meg_ref = pick_types(info, meg=True, ref_meg=True) picks_meg = pick_types(info, meg=True, ref_meg=False) picks_ref = pick_types(info, meg=False, 
ref_meg=True) - picks_mag = pick_types(info, meg='mag', ref_meg=False) - picks_ref_mag = pick_types(info, meg=False, ref_meg='mag') - picks_meg_ref_mag = pick_types(info, meg='mag', ref_meg='mag') + picks_mag = pick_types(info, meg="mag", ref_meg=False) + picks_ref_mag = pick_types(info, meg=False, ref_meg="mag") + picks_meg_ref_mag = pick_types(info, meg="mag", ref_meg="mag") for pick in (picks_meg_ref, picks_ref, picks_ref_mag, picks_meg_ref_mag): if len(pick) > 0: pick_info(info, pick) @@ -168,9 +225,10 @@ def test_pick_refs(): if len(pick) > 0: with catch_logging() as log: pick_info(info, pick, verbose=True) - assert ('Removing {} compensators'.format(len(info['comps'])) - in log.getvalue()) - picks_ref_grad = pick_types(info, meg=False, ref_meg='grad') + assert ( + "Removing {} compensators".format(len(info["comps"])) in log.getvalue() + ) + picks_ref_grad = pick_types(info, meg=False, ref_meg="grad") assert set(picks_ref_mag) == set(picks_ref) assert len(picks_ref_grad) == 0 all_meg = np.arange(3, 306) @@ -180,10 +238,10 @@ def test_pick_refs(): def test_pick_channels_regexp(): """Test pick with regular expression.""" - ch_names = ['MEG 2331', 'MEG 2332', 'MEG 2333'] - assert_array_equal(pick_channels_regexp(ch_names, 'MEG ...1'), [0]) - assert_array_equal(pick_channels_regexp(ch_names, 'MEG ...[2-3]'), [1, 2]) - assert_array_equal(pick_channels_regexp(ch_names, 'MEG *'), [0, 1, 2]) + ch_names = ["MEG 2331", "MEG 2332", "MEG 2333"] + assert_array_equal(pick_channels_regexp(ch_names, "MEG ...1"), [0]) + assert_array_equal(pick_channels_regexp(ch_names, "MEG ...[2-3]"), [1, 2]) + assert_array_equal(pick_channels_regexp(ch_names, "MEG *"), [0, 1, 2]) def assert_indexing(info, picks_by_type, ref_meg=False, all_data=True): @@ -202,20 +260,19 @@ def assert_indexing(info, picks_by_type, ref_meg=False, all_data=True): assert len(idx[key]) == 0 # Finally, picks_by_type (if relevant) if not all_data: - picks_by_type = [p for p in picks_by_type - if p[0] in _DATA_CH_TYPES_SPLIT] + picks_by_type = [p for p in picks_by_type if p[0] in _DATA_CH_TYPES_SPLIT] picks_by_type = [(p[0], np.array(p[1], int)) for p in picks_by_type] actual = _picks_by_type(info, ref_meg=ref_meg) assert_object_equal(actual, picks_by_type) - if not ref_meg and idx['hbo']: # our old code had a bug - with pytest.raises(TypeError, match='unexpected keyword argument'): + if not ref_meg and idx["hbo"]: # our old code had a bug + with pytest.raises(TypeError, match="unexpected keyword argument"): _picks_by_type_old(info, ref_meg=ref_meg) else: old = _picks_by_type_old(info, ref_meg=ref_meg) assert_object_equal(old, picks_by_type) # test bads info = info.copy() - info['bads'] = [info['chs'][picks_by_type[0][1][0]]['ch_name']] + info["bads"] = [info["chs"][picks_by_type[0][1][0]]["ch_name"]] picks_by_type = deepcopy(picks_by_type) picks_by_type[0] = (picks_by_type[0][0], picks_by_type[0][1][1:]) actual = _picks_by_type(info, ref_meg=ref_meg) @@ -224,20 +281,29 @@ def assert_indexing(info, picks_by_type, ref_meg=False, all_data=True): def test_pick_seeg_ecog(): """Test picking with sEEG and ECoG.""" - names = 'A1 A2 Fz O OTp1 OTp2 E1 OTp3 E2 E3'.split() - types = 'mag mag eeg eeg seeg seeg ecog seeg ecog ecog'.split() - info = create_info(names, 1024., types) - picks_by_type = [('mag', [0, 1]), ('eeg', [2, 3]), - ('seeg', [4, 5, 7]), ('ecog', [6, 8, 9])] + names = "A1 A2 Fz O OTp1 OTp2 E1 OTp3 E2 E3".split() + types = "mag mag eeg eeg seeg seeg ecog seeg ecog ecog".split() + info = create_info(names, 1024.0, types) + picks_by_type = 
[ + ("mag", [0, 1]), + ("eeg", [2, 3]), + ("seeg", [4, 5, 7]), + ("ecog", [6, 8, 9]), + ] assert_indexing(info, picks_by_type) assert_array_equal(pick_types(info, meg=False, seeg=True), [4, 5, 7]) for i, t in enumerate(types): assert channel_type(info, i) == types[i] raw = RawArray(np.zeros((len(names), 10)), info) events = np.array([[1, 0, 0], [2, 0, 0]]) - epochs = Epochs(raw, events=events, event_id={'event': 0}, - tmin=-1e-5, tmax=1e-5, - baseline=(0, 0)) # only one sample + epochs = Epochs( + raw, + events=events, + event_id={"event": 0}, + tmin=-1e-5, + tmax=1e-5, + baseline=(0, 0), + ) # only one sample evoked = epochs.average(pick_types(epochs.info, meg=True, seeg=True)) e_seeg = evoked.copy().pick_types(meg=False, seeg=True) for lt, rt in zip(e_seeg.ch_names, [names[4], names[5], names[7]]): @@ -250,19 +316,24 @@ def test_pick_seeg_ecog(): def test_pick_dbs(): """Test picking with DBS.""" # gh-8739 - names = 'A1 A2 Fz O OTp1 OTp2 OTp3'.split() - types = 'mag mag eeg eeg dbs dbs dbs'.split() - info = create_info(names, 1024., types) - picks_by_type = [('mag', [0, 1]), ('eeg', [2, 3]), ('dbs', [4, 5, 6])] + names = "A1 A2 Fz O OTp1 OTp2 OTp3".split() + types = "mag mag eeg eeg dbs dbs dbs".split() + info = create_info(names, 1024.0, types) + picks_by_type = [("mag", [0, 1]), ("eeg", [2, 3]), ("dbs", [4, 5, 6])] assert_indexing(info, picks_by_type) assert_array_equal(pick_types(info, meg=False, dbs=True), [4, 5, 6]) for i, t in enumerate(types): assert channel_type(info, i) == types[i] raw = RawArray(np.zeros((len(names), 7)), info) events = np.array([[1, 0, 0], [2, 0, 0]]) - epochs = Epochs(raw, events=events, event_id={'event': 0}, - tmin=-1e-5, tmax=1e-5, - baseline=(0, 0)) # only one sample + epochs = Epochs( + raw, + events=events, + event_id={"event": 0}, + tmin=-1e-5, + tmax=1e-5, + baseline=(0, 0), + ) # only one sample evoked = epochs.average(pick_types(epochs.info, meg=True, dbs=True)) e_dbs = evoked.copy().pick_types(meg=False, dbs=True) for lt, rt in zip(e_dbs.ch_names, [names[4], names[5], names[6]]): @@ -277,58 +348,70 @@ def test_pick_chpi(): info = read_info(io_dir / "tests" / "data" / "test_chpi_raw_sss.fif") _assert_channel_types(info) channel_types = _get_channel_types(info) - assert 'chpi' in channel_types - assert 'seeg' not in channel_types - assert 'ecog' not in channel_types + assert "chpi" in channel_types + assert "seeg" not in channel_types + assert "ecog" not in channel_types def test_pick_csd(): """Test picking current source density channels.""" # Make sure we don't mis-classify cHPI channels - names = ['MEG 2331', 'MEG 2332', 'MEG 2333', 'A1', 'A2', 'Fz'] - types = 'mag mag grad csd csd csd'.split() - info = create_info(names, 1024., types) - picks_by_type = [('mag', [0, 1]), ('grad', [2]), ('csd', [3, 4, 5])] + names = ["MEG 2331", "MEG 2332", "MEG 2333", "A1", "A2", "Fz"] + types = "mag mag grad csd csd csd".split() + info = create_info(names, 1024.0, types) + picks_by_type = [("mag", [0, 1]), ("grad", [2]), ("csd", [3, 4, 5])] assert_indexing(info, picks_by_type, all_data=False) def test_pick_bio(): """Test picking BIO channels.""" - names = 'A1 A2 Fz O BIO1 BIO2 BIO3'.split() - types = 'mag mag eeg eeg bio bio bio'.split() - info = create_info(names, 1024., types) - picks_by_type = [('mag', [0, 1]), ('eeg', [2, 3]), ('bio', [4, 5, 6])] + names = "A1 A2 Fz O BIO1 BIO2 BIO3".split() + types = "mag mag eeg eeg bio bio bio".split() + info = create_info(names, 1024.0, types) + picks_by_type = [("mag", [0, 1]), ("eeg", [2, 3]), ("bio", [4, 5, 6])] 
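For context on the test_pick_fnirs hunk just below: black replaces the backslash continuation over the channel-type string with a parenthesized expression. This is behavior-preserving because adjacent string literals are concatenated at parse time, before the trailing .split() call applies to the combined literal. A minimal standalone sketch of that rule (illustrative only, not part of the patch):

types = (
    "mag mag eeg eeg hbo hbo hbr fnirs_cw_"
    "amplitude fnirs_cw_amplitude fnirs_od".split()
)
# The two literals fuse into one string first, so .split() yields all
# ten channel types, including the joined "fnirs_cw_amplitude":
assert len(types) == 10
assert types[7] == "fnirs_cw_amplitude"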
assert_indexing(info, picks_by_type, all_data=False) def test_pick_fnirs(): """Test picking fNIRS channels.""" - names = 'A1 A2 Fz O hbo1 hbo2 hbr1 fnirsRaw1 fnirsRaw2 fnirsOD1'.split() - types = 'mag mag eeg eeg hbo hbo hbr fnirs_cw_' \ - 'amplitude fnirs_cw_amplitude fnirs_od'.split() - info = create_info(names, 1024., types) - picks_by_type = [('mag', [0, 1]), ('eeg', [2, 3]), - ('hbo', [4, 5]), ('hbr', [6]), - ('fnirs_cw_amplitude', [7, 8]), ('fnirs_od', [9])] + names = "A1 A2 Fz O hbo1 hbo2 hbr1 fnirsRaw1 fnirsRaw2 fnirsOD1".split() + types = ( + "mag mag eeg eeg hbo hbo hbr fnirs_cw_" + "amplitude fnirs_cw_amplitude fnirs_od".split() + ) + info = create_info(names, 1024.0, types) + picks_by_type = [ + ("mag", [0, 1]), + ("eeg", [2, 3]), + ("hbo", [4, 5]), + ("hbr", [6]), + ("fnirs_cw_amplitude", [7, 8]), + ("fnirs_od", [9]), + ] assert_indexing(info, picks_by_type) def test_pick_ref(): """Test picking ref_meg channels.""" info = read_info(ctf_fname) - picks_by_type = [('stim', [0]), ('eog', [306, 307]), ('ecg', [308]), - ('misc', [1]), - ('mag', np.arange(31, 306)), - ('ref_meg', np.arange(2, 31))] + picks_by_type = [ + ("stim", [0]), + ("eog", [306, 307]), + ("ecg", [308]), + ("misc", [1]), + ("mag", np.arange(31, 306)), + ("ref_meg", np.arange(2, 31)), + ] assert_indexing(info, picks_by_type, all_data=False) - picks_by_type.append(('mag', np.concatenate([picks_by_type.pop(-1)[1], - picks_by_type.pop(-1)[1]]))) + picks_by_type.append( + ("mag", np.concatenate([picks_by_type.pop(-1)[1], picks_by_type.pop(-1)[1]])) + ) assert_indexing(info, picks_by_type, ref_meg=True, all_data=False) def _check_fwd_n_chan_consistent(fwd, n_expected): - n_ok = len(fwd['info']['ch_names']) - n_sol = fwd['sol']['data'].shape[0] + n_ok = len(fwd["info"]["ch_names"]) + n_sol = fwd["sol"]["data"].shape[0] assert n_expected == n_sol assert n_expected == n_ok @@ -337,47 +420,47 @@ def _check_fwd_n_chan_consistent(fwd, n_expected): def test_pick_forward_seeg_ecog(): """Test picking forward with SEEG and ECoG.""" fwd = read_forward_solution(fname_meeg) - counts = channel_indices_by_type(fwd['info']) + counts = channel_indices_by_type(fwd["info"]) for key in counts.keys(): counts[key] = len(counts[key]) - counts['meg'] = counts['mag'] + counts['grad'] + counts["meg"] = counts["mag"] + counts["grad"] fwd_ = pick_types_forward(fwd, meg=True) - _check_fwd_n_chan_consistent(fwd_, counts['meg']) + _check_fwd_n_chan_consistent(fwd_, counts["meg"]) fwd_ = pick_types_forward(fwd, meg=False, eeg=True) - _check_fwd_n_chan_consistent(fwd_, counts['eeg']) + _check_fwd_n_chan_consistent(fwd_, counts["eeg"]) # should raise exception related to emptiness pytest.raises(ValueError, pick_types_forward, fwd, meg=False, seeg=True) pytest.raises(ValueError, pick_types_forward, fwd, meg=False, ecog=True) # change last chan from EEG to sEEG, second-to-last to ECoG - ecog_name = 'E1' - seeg_name = 'OTp1' - rename_channels(fwd['info'], {'EEG 059': ecog_name}) - rename_channels(fwd['info'], {'EEG 060': seeg_name}) - for ch in fwd['info']['chs']: - if ch['ch_name'] == seeg_name: - ch['kind'] = FIFF.FIFFV_SEEG_CH - ch['coil_type'] = FIFF.FIFFV_COIL_EEG - elif ch['ch_name'] == ecog_name: - ch['kind'] = FIFF.FIFFV_ECOG_CH - ch['coil_type'] = FIFF.FIFFV_COIL_EEG - fwd['sol']['row_names'][-1] = fwd['info']['chs'][-1]['ch_name'] - fwd['sol']['row_names'][-2] = fwd['info']['chs'][-2]['ch_name'] - counts['eeg'] -= 2 - counts['seeg'] += 1 - counts['ecog'] += 1 + ecog_name = "E1" + seeg_name = "OTp1" + rename_channels(fwd["info"], {"EEG 059": 
ecog_name}) + rename_channels(fwd["info"], {"EEG 060": seeg_name}) + for ch in fwd["info"]["chs"]: + if ch["ch_name"] == seeg_name: + ch["kind"] = FIFF.FIFFV_SEEG_CH + ch["coil_type"] = FIFF.FIFFV_COIL_EEG + elif ch["ch_name"] == ecog_name: + ch["kind"] = FIFF.FIFFV_ECOG_CH + ch["coil_type"] = FIFF.FIFFV_COIL_EEG + fwd["sol"]["row_names"][-1] = fwd["info"]["chs"][-1]["ch_name"] + fwd["sol"]["row_names"][-2] = fwd["info"]["chs"][-2]["ch_name"] + counts["eeg"] -= 2 + counts["seeg"] += 1 + counts["ecog"] += 1 # repick & check fwd_seeg = pick_types_forward(fwd, meg=False, seeg=True) - assert fwd_seeg['sol']['row_names'] == [seeg_name] - assert fwd_seeg['info']['ch_names'] == [seeg_name] + assert fwd_seeg["sol"]["row_names"] == [seeg_name] + assert fwd_seeg["info"]["ch_names"] == [seeg_name] # should work fine fwd_ = pick_types_forward(fwd, meg=True) - _check_fwd_n_chan_consistent(fwd_, counts['meg']) + _check_fwd_n_chan_consistent(fwd_, counts["meg"]) fwd_ = pick_types_forward(fwd, meg=False, eeg=True) - _check_fwd_n_chan_consistent(fwd_, counts['eeg']) + _check_fwd_n_chan_consistent(fwd_, counts["eeg"]) fwd_ = pick_types_forward(fwd, meg=False, seeg=True) - _check_fwd_n_chan_consistent(fwd_, counts['seeg']) + _check_fwd_n_chan_consistent(fwd_, counts["seeg"]) fwd_ = pick_types_forward(fwd, meg=False, ecog=True) - _check_fwd_n_chan_consistent(fwd_, counts['ecog']) + _check_fwd_n_chan_consistent(fwd_, counts["ecog"]) def test_picks_by_channels(): @@ -385,8 +468,8 @@ def test_picks_by_channels(): rng = np.random.RandomState(909) test_data = rng.random_sample((4, 2000)) - ch_names = ['MEG %03d' % i for i in [1, 2, 3, 4]] - ch_types = ['grad', 'mag', 'mag', 'eeg'] + ch_names = ["MEG %03d" % i for i in [1, 2, 3, 4]] + ch_types = ["grad", "mag", "mag", "eeg"] sfreq = 250.0 info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) _assert_channel_types(info) @@ -394,54 +477,54 @@ def test_picks_by_channels(): pick_list = _picks_by_type(raw.info) assert len(pick_list) == 3 - assert pick_list[0][0] == 'mag' + assert pick_list[0][0] == "mag" pick_list2 = _picks_by_type(raw.info, meg_combined=False) assert len(pick_list) == len(pick_list2) - assert pick_list2[0][0] == 'mag' + assert pick_list2[0][0] == "mag" pick_list2 = _picks_by_type(raw.info, meg_combined=True) assert len(pick_list) == len(pick_list2) + 1 - assert pick_list2[0][0] == 'meg' + assert pick_list2[0][0] == "meg" test_data = rng.random_sample((4, 2000)) - ch_names = ['MEG %03d' % i for i in [1, 2, 3, 4]] - ch_types = ['mag', 'mag', 'mag', 'mag'] + ch_names = ["MEG %03d" % i for i in [1, 2, 3, 4]] + ch_types = ["mag", "mag", "mag", "mag"] sfreq = 250.0 info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) raw = RawArray(test_data, info) # This acts as a set, not an order assert_array_equal( - pick_channels(info['ch_names'], ['MEG 002', 'MEG 001'], ordered=False), - [0, 1]) + pick_channels(info["ch_names"], ["MEG 002", "MEG 001"], ordered=False), [0, 1] + ) # Make sure checks for list input work. 
- pytest.raises(ValueError, pick_channels, ch_names, 'MEG 001') - pytest.raises(ValueError, pick_channels, ch_names, ['MEG 001'], 'hi') + pytest.raises(ValueError, pick_channels, ch_names, "MEG 001") + pytest.raises(ValueError, pick_channels, ch_names, ["MEG 001"], "hi") pick_list = _picks_by_type(raw.info) assert len(pick_list) == 1 - assert pick_list[0][0] == 'mag' + assert pick_list[0][0] == "mag" pick_list2 = _picks_by_type(raw.info, meg_combined=True) assert len(pick_list) == len(pick_list2) - assert pick_list2[0][0] == 'mag' + assert pick_list2[0][0] == "mag" # pick_types type check - with pytest.raises(ValueError, match='must be of type'): - raw.pick_types(eeg='string') + with pytest.raises(ValueError, match="must be of type"): + raw.pick_types(eeg="string") # duplicate check - names = ['MEG 002', 'MEG 002'] - assert len(pick_channels(raw.info['ch_names'], names, ordered=False)) == 1 - with pytest.warns(FutureWarning, match='ordered=False'): + names = ["MEG 002", "MEG 002"] + assert len(pick_channels(raw.info["ch_names"], names, ordered=False)) == 1 + with pytest.warns(FutureWarning, match="ordered=False"): assert len(raw.copy().pick_channels(names)[0][0]) == 1 # missing ch_name - bad_names = names + ['BAD'] - with pytest.raises(ValueError, match='Missing channels'): - pick_channels(raw.info['ch_names'], bad_names, ordered=True) - with pytest.raises(ValueError, match='Missing channels'): + bad_names = names + ["BAD"] + with pytest.raises(ValueError, match="Missing channels"): + pick_channels(raw.info["ch_names"], bad_names, ordered=True) + with pytest.raises(ValueError, match="Missing channels"): raw.copy().pick_channels(bad_names, ordered=True) - with pytest.raises(ValueError, match='could not be picked'): + with pytest.raises(ValueError, match="could not be picked"): raw.copy().pick(bad_names) @@ -456,17 +539,17 @@ def test_clean_info_bads(): # select 3 eeg channels as bads idx_eeg_bad_ch = picks_eeg[[1, 5, 14]] - eeg_bad_ch = [raw.info['ch_names'][k] for k in idx_eeg_bad_ch] + eeg_bad_ch = [raw.info["ch_names"][k] for k in idx_eeg_bad_ch] # select meg channels picks_meg = pick_types(raw.info, meg=True, eeg=False) # select randomly 3 meg channels as bads idx_meg_bad_ch = picks_meg[[0, 15, 34]] - meg_bad_ch = [raw.info['ch_names'][k] for k in idx_meg_bad_ch] + meg_bad_ch = [raw.info["ch_names"][k] for k in idx_meg_bad_ch] # simulate the bad channels - raw.info['bads'] = eeg_bad_ch + meg_bad_ch + raw.info["bads"] = eeg_bad_ch + meg_bad_ch # simulate the call to pick_info excluding the bad eeg channels info_eeg = pick_info(raw.info, picks_eeg) @@ -474,93 +557,88 @@ def test_clean_info_bads(): # simulate the call to pick_info excluding the bad meg channels info_meg = pick_info(raw.info, picks_meg) - assert info_eeg['bads'] == eeg_bad_ch - assert info_meg['bads'] == meg_bad_ch + assert info_eeg["bads"] == eeg_bad_ch + assert info_meg["bads"] == meg_bad_ch info = pick_info(raw.info, picks_meg) info._check_consistency() - info['bads'] += ['EEG 053'] + info["bads"] += ["EEG 053"] pytest.raises(RuntimeError, info._check_consistency) - with pytest.raises(ValueError, match='unique'): + with pytest.raises(ValueError, match="unique"): pick_info(raw.info, [0, 0]) @testing.requires_testing_data def test_picks_to_idx(): """Test checking type integrity checks of picks.""" - info = create_info(12, 1000., 'eeg') + info = create_info(12, 1000.0, "eeg") _assert_channel_types(info) - picks = np.arange(info['nchan']) + picks = np.arange(info["nchan"]) # Array and list assert_array_equal(picks, 
_picks_to_idx(info, picks)) assert_array_equal(picks, _picks_to_idx(info, list(picks))) - with pytest.raises(TypeError, match='data type float64 is invalid'): - _picks_to_idx(info, 1.) + with pytest.raises(TypeError, match="data type float64 is invalid"): + _picks_to_idx(info, 1.0) # None assert_array_equal(picks, _picks_to_idx(info, None)) # Type indexing - assert_array_equal(picks, _picks_to_idx(info, 'eeg')) - assert_array_equal(picks, _picks_to_idx(info, ['eeg'])) + assert_array_equal(picks, _picks_to_idx(info, "eeg")) + assert_array_equal(picks, _picks_to_idx(info, ["eeg"])) # Negative indexing assert_array_equal([len(picks) - 1], _picks_to_idx(info, len(picks) - 1)) assert_array_equal([len(picks) - 1], _picks_to_idx(info, -1)) assert_array_equal([len(picks) - 1], _picks_to_idx(info, [-1])) # Name indexing - assert_array_equal([2], _picks_to_idx(info, info['ch_names'][2])) - assert_array_equal(np.arange(5, 9), - _picks_to_idx(info, info['ch_names'][5:9])) - with pytest.raises(ValueError, match='must be >= '): + assert_array_equal([2], _picks_to_idx(info, info["ch_names"][2])) + assert_array_equal(np.arange(5, 9), _picks_to_idx(info, info["ch_names"][5:9])) + with pytest.raises(ValueError, match="must be >= "): _picks_to_idx(info, -len(picks) - 1) - with pytest.raises(ValueError, match='must be < '): + with pytest.raises(ValueError, match="must be < "): _picks_to_idx(info, len(picks)) - with pytest.raises(ValueError, match='could not be interpreted'): - _picks_to_idx(info, ['a', 'b']) - with pytest.raises(ValueError, match='could not be interpreted'): - _picks_to_idx(info, 'b') + with pytest.raises(ValueError, match="could not be interpreted"): + _picks_to_idx(info, ["a", "b"]) + with pytest.raises(ValueError, match="could not be interpreted"): + _picks_to_idx(info, "b") # bads behavior - info['bads'] = info['ch_names'][1:2] + info["bads"] = info["ch_names"][1:2] picks_good = np.array([0] + list(range(2, 12))) assert_array_equal(picks_good, _picks_to_idx(info, None)) - assert_array_equal(picks_good, _picks_to_idx(info, None, - exclude=info['bads'])) + assert_array_equal(picks_good, _picks_to_idx(info, None, exclude=info["bads"])) assert_array_equal(picks, _picks_to_idx(info, None, exclude=())) - with pytest.raises(ValueError, match=' 1D, got'): + with pytest.raises(ValueError, match=" 1D, got"): _picks_to_idx(info, [[1]]) # MEG types info = read_info(fname_mc) meg_picks = np.arange(306) mag_picks = np.arange(2, 306, 3) grad_picks = np.setdiff1d(meg_picks, mag_picks) - assert_array_equal(meg_picks, _picks_to_idx(info, 'meg')) - assert_array_equal(meg_picks, _picks_to_idx(info, ('mag', 'grad'))) - assert_array_equal(mag_picks, _picks_to_idx(info, 'mag')) - assert_array_equal(grad_picks, _picks_to_idx(info, 'grad')) - - info = create_info(['eeg', 'foo'], 1000., 'eeg') - with pytest.raises(RuntimeError, match='equivalent to channel types'): - _picks_to_idx(info, 'eeg') - with pytest.raises(ValueError, match='same length'): - create_info(['a', 'b'], 1000., dict(hbo=['a'], hbr=['b'])) - info = create_info(['a', 'b'], 1000., ['hbo', 'hbr']) - assert_array_equal(np.arange(2), _picks_to_idx(info, 'fnirs')) - assert_array_equal([0], _picks_to_idx(info, 'hbo')) - assert_array_equal([1], _picks_to_idx(info, 'hbr')) - info = create_info(['a', 'b'], 1000., ['hbo', 'misc']) - assert_array_equal(np.arange(len(info['ch_names'])), - _picks_to_idx(info, 'all')) - assert_array_equal([0], _picks_to_idx(info, 'data')) - info = create_info(['a', 'b'], 1000., ['fnirs_cw_amplitude', 'fnirs_od']) - 
assert_array_equal(np.arange(2), _picks_to_idx(info, 'fnirs')) - assert_array_equal([0], _picks_to_idx(info, 'fnirs_cw_amplitude')) - assert_array_equal([1], _picks_to_idx(info, 'fnirs_od')) - info = create_info(['a', 'b'], 1000., ['fnirs_cw_amplitude', 'misc']) - assert_array_equal(np.arange(len(info['ch_names'])), - _picks_to_idx(info, 'all')) - assert_array_equal([0], _picks_to_idx(info, 'data')) - info = create_info(['a', 'b'], 1000., ['fnirs_od', 'misc']) - assert_array_equal(np.arange(len(info['ch_names'])), - _picks_to_idx(info, 'all')) - assert_array_equal([0], _picks_to_idx(info, 'data')) + assert_array_equal(meg_picks, _picks_to_idx(info, "meg")) + assert_array_equal(meg_picks, _picks_to_idx(info, ("mag", "grad"))) + assert_array_equal(mag_picks, _picks_to_idx(info, "mag")) + assert_array_equal(grad_picks, _picks_to_idx(info, "grad")) + + info = create_info(["eeg", "foo"], 1000.0, "eeg") + with pytest.raises(RuntimeError, match="equivalent to channel types"): + _picks_to_idx(info, "eeg") + with pytest.raises(ValueError, match="same length"): + create_info(["a", "b"], 1000.0, dict(hbo=["a"], hbr=["b"])) + info = create_info(["a", "b"], 1000.0, ["hbo", "hbr"]) + assert_array_equal(np.arange(2), _picks_to_idx(info, "fnirs")) + assert_array_equal([0], _picks_to_idx(info, "hbo")) + assert_array_equal([1], _picks_to_idx(info, "hbr")) + info = create_info(["a", "b"], 1000.0, ["hbo", "misc"]) + assert_array_equal(np.arange(len(info["ch_names"])), _picks_to_idx(info, "all")) + assert_array_equal([0], _picks_to_idx(info, "data")) + info = create_info(["a", "b"], 1000.0, ["fnirs_cw_amplitude", "fnirs_od"]) + assert_array_equal(np.arange(2), _picks_to_idx(info, "fnirs")) + assert_array_equal([0], _picks_to_idx(info, "fnirs_cw_amplitude")) + assert_array_equal([1], _picks_to_idx(info, "fnirs_od")) + info = create_info(["a", "b"], 1000.0, ["fnirs_cw_amplitude", "misc"]) + assert_array_equal(np.arange(len(info["ch_names"])), _picks_to_idx(info, "all")) + assert_array_equal([0], _picks_to_idx(info, "data")) + info = create_info(["a", "b"], 1000.0, ["fnirs_od", "misc"]) + assert_array_equal(np.arange(len(info["ch_names"])), _picks_to_idx(info, "all")) + assert_array_equal([0], _picks_to_idx(info, "data")) # MEG reference sensors info_ref = read_info(ctf_fname) picks_meg = pick_types(info_ref, meg=True, ref_meg=False) @@ -569,42 +647,40 @@ def test_picks_to_idx(): assert len(picks_ref) == 29 picks_meg_ref = np.sort(np.concatenate([picks_meg, picks_ref])) assert len(picks_meg_ref) == 275 + 29 - assert_array_equal( - picks_meg_ref, pick_types(info_ref, meg=True, ref_meg=True)) - assert_array_equal( - picks_meg, _picks_to_idx(info_ref, 'meg', with_ref_meg=False)) + assert_array_equal(picks_meg_ref, pick_types(info_ref, meg=True, ref_meg=True)) + assert_array_equal(picks_meg, _picks_to_idx(info_ref, "meg", with_ref_meg=False)) assert_array_equal( # explicit trumps implicit - picks_ref, _picks_to_idx(info_ref, 'ref_meg', with_ref_meg=False)) - assert_array_equal( - picks_meg_ref, _picks_to_idx(info_ref, 'meg', with_ref_meg=True)) + picks_ref, _picks_to_idx(info_ref, "ref_meg", with_ref_meg=False) + ) + assert_array_equal(picks_meg_ref, _picks_to_idx(info_ref, "meg", with_ref_meg=True)) def test_pick_channels_cov(): """Test picking channels from a Covariance object.""" - info = create_info(['CH1', 'CH2', 'CH3'], 1., ch_types='eeg') + info = create_info(["CH1", "CH2", "CH3"], 1.0, ch_types="eeg") cov = make_ad_hoc_cov(info) - cov['data'] = np.array([1., 2., 3.]) + cov["data"] = np.array([1.0, 2.0, 3.0]) 
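The pick_channels_cov assertions below turn on the ordered flag: with ordered=False the requested names behave like a set and come back in the covariance's own channel order, while ordered=True returns them in the requested order. A rough plain-NumPy sketch of that indexing semantics (not MNE's implementation, just the rule the asserts encode):

import numpy as np

ch_names = ["CH1", "CH2", "CH3"]
data = np.array([1.0, 2.0, 3.0])
sel = ["CH2", "CH1"]
idx_unordered = sorted(ch_names.index(name) for name in sel)  # set-like -> [0, 1]
idx_ordered = [ch_names.index(name) for name in sel]          # as requested -> [1, 0]
assert data[idx_unordered].tolist() == [1.0, 2.0]
assert data[idx_ordered].tolist() == [2.0, 1.0]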
- cov_copy = pick_channels_cov(cov, ['CH2', 'CH1'], ordered=False, copy=True) - assert cov_copy.ch_names == ['CH1', 'CH2'] - assert_array_equal(cov_copy['data'], [1., 2.]) + cov_copy = pick_channels_cov(cov, ["CH2", "CH1"], ordered=False, copy=True) + assert cov_copy.ch_names == ["CH1", "CH2"] + assert_array_equal(cov_copy["data"], [1.0, 2.0]) # Test re-ordering channels - cov_copy = pick_channels_cov(cov, ['CH2', 'CH1'], ordered=True, copy=True) - assert cov_copy.ch_names == ['CH2', 'CH1'] - assert_array_equal(cov_copy['data'], [2., 1.]) + cov_copy = pick_channels_cov(cov, ["CH2", "CH1"], ordered=True, copy=True) + assert cov_copy.ch_names == ["CH2", "CH1"] + assert_array_equal(cov_copy["data"], [2.0, 1.0]) # Test picking in-place - pick_channels_cov(cov, ['CH2', 'CH1'], copy=False, ordered=False) - assert cov.ch_names == ['CH1', 'CH2'] - assert_array_equal(cov['data'], [1., 2.]) + pick_channels_cov(cov, ["CH2", "CH1"], copy=False, ordered=False) + assert cov.ch_names == ["CH1", "CH2"] + assert_array_equal(cov["data"], [1.0, 2.0]) # Test whether `method` and `loglik` are dropped when None - cov['method'] = None - cov['loglik'] = None - cov_copy = pick_channels_cov(cov, ['CH1', 'CH2'], copy=True) - assert 'method' not in cov_copy - assert 'loglik' not in cov_copy + cov["method"] = None + cov["loglik"] = None + cov_copy = pick_channels_cov(cov, ["CH1", "CH2"], copy=True) + assert "method" not in cov_copy + assert "loglik" not in cov_copy def test_pick_types_meg(): @@ -617,8 +693,8 @@ def test_pick_types_meg(): assert list(pick_types(info1, meg=True)) == [1, 2, 4] assert not list(pick_types(info1, meg=False)) # empty - assert list(pick_types(info1, meg='planar1')) == [2] - assert not list(pick_types(info1, meg='planar2')) # empty + assert list(pick_types(info1, meg="planar1")) == [2] + assert not list(pick_types(info1, meg="planar2")) # empty # info without any MEG channels info2 = create_info(6, 256, ["eeg", "eeg", "eog", "misc", "stim", "hbo"]) @@ -630,23 +706,29 @@ def test_pick_types_meg(): def test_pick_types_csd(): """Test pick_types(csd=True).""" # info with laplacian/CSD channels at indices 1, 2 - names = ['F1', 'F2', 'C1', 'C2', 'A1', 'A2', 'misc1', 'CSD1'] - info1 = create_info(names, 256, ["eeg", "eeg", "eeg", "eeg", "mag", - "mag", 'misc', 'csd']) + names = ["F1", "F2", "C1", "C2", "A1", "A2", "misc1", "CSD1"] + info1 = create_info( + names, 256, ["eeg", "eeg", "eeg", "eeg", "mag", "mag", "misc", "csd"] + ) raw = RawArray(np.zeros((8, 512)), info1) - raw.set_montage(make_standard_montage('standard_1020'), verbose='error') - raw_csd = compute_current_source_density(raw, verbose='error') + raw.set_montage(make_standard_montage("standard_1020"), verbose="error") + raw_csd = compute_current_source_density(raw, verbose="error") assert_array_equal(pick_types(info1, csd=True), [7]) # pick from the raw object assert raw_csd.copy().pick_types(csd=True).ch_names == [ - 'F1', 'F2', 'C1', 'C2', 'CSD1'] + "F1", + "F2", + "C1", + "C2", + "CSD1", + ] -@pytest.mark.parametrize('meg', [True, False, 'grad', 'mag']) -@pytest.mark.parametrize('eeg', [True, False]) -@pytest.mark.parametrize('ordered', [True, False]) +@pytest.mark.parametrize("meg", [True, False, "grad", "mag"]) +@pytest.mark.parametrize("eeg", [True, False]) +@pytest.mark.parametrize("ordered", [True, False]) def test_get_channel_types_equiv(meg, eeg, ordered): """Test equivalence of get_channel_types.""" raw = read_raw_fif(fif_fname) @@ -655,7 +737,7 @@ def test_get_channel_types_equiv(meg, eeg, ordered): if not ordered: picks = 
np.random.RandomState(0).permutation(picks) if not meg and not eeg: - with pytest.raises(ValueError, match='No appropriate channels'): + with pytest.raises(ValueError, match="No appropriate channels"): raw.get_channel_types(picks=picks) return types = np.array(raw.get_channel_types(picks=picks)) diff --git a/mne/io/tests/test_proc_history.py b/mne/io/tests/test_proc_history.py index 964464522cf..c82b7c63094 100644 --- a/mne/io/tests/test_proc_history.py +++ b/mne/io/tests/test_proc_history.py @@ -17,22 +17,26 @@ def test_maxfilter_io(): """Test maxfilter io.""" info = read_info(raw_fname) - mf = info['proc_history'][1]['max_info'] + mf = info["proc_history"][1]["max_info"] - assert mf['sss_info']['frame'] == FIFF.FIFFV_COORD_HEAD + assert mf["sss_info"]["frame"] == FIFF.FIFFV_COORD_HEAD # based on manual 2.0, rev. 5.0 page 23 - assert 5 <= mf['sss_info']['in_order'] <= 11 - assert mf['sss_info']['out_order'] <= 5 - assert mf['sss_info']['nchan'] > len(mf['sss_info']['components']) - - assert (info['ch_names'][:mf['sss_info']['nchan']] == - mf['sss_ctc']['proj_items_chs']) - assert (mf['sss_ctc']['decoupler'].shape == - (mf['sss_info']['nchan'], mf['sss_info']['nchan'])) + assert 5 <= mf["sss_info"]["in_order"] <= 11 + assert mf["sss_info"]["out_order"] <= 5 + assert mf["sss_info"]["nchan"] > len(mf["sss_info"]["components"]) + + assert ( + info["ch_names"][: mf["sss_info"]["nchan"]] == mf["sss_ctc"]["proj_items_chs"] + ) + assert mf["sss_ctc"]["decoupler"].shape == ( + mf["sss_info"]["nchan"], + mf["sss_info"]["nchan"], + ) assert_array_equal( - np.unique(np.diag(mf['sss_ctc']['decoupler'].toarray())), - np.array([1.], dtype=np.float32)) - assert mf['sss_cal']['cal_corrs'].shape == (306, 14) - assert mf['sss_cal']['cal_chans'].shape == (306, 2) - vv_coils = [v for k, v in FIFF.items() if 'FIFFV_COIL_VV' in k] - assert all(k in vv_coils for k in set(mf['sss_cal']['cal_chans'][:, 1])) + np.unique(np.diag(mf["sss_ctc"]["decoupler"].toarray())), + np.array([1.0], dtype=np.float32), + ) + assert mf["sss_cal"]["cal_corrs"].shape == (306, 14) + assert mf["sss_cal"]["cal_chans"].shape == (306, 2) + vv_coils = [v for k, v in FIFF.items() if "FIFFV_COIL_VV" in k] + assert all(k in vv_coils for k in set(mf["sss_cal"]["cal_chans"][:, 1])) diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index 4c728df90ef..27ac8001b5b 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -14,8 +14,12 @@ import pytest import numpy as np -from numpy.testing import (assert_allclose, assert_array_almost_equal, - assert_array_equal, assert_array_less) +from numpy.testing import ( + assert_allclose, + assert_array_almost_equal, + assert_array_equal, + assert_array_less, +) import mne from mne import concatenate_raws, create_info, Annotations, pick_types @@ -25,50 +29,68 @@ from mne.io._digitization import _dig_kind_dict from mne.io.base import _get_scaling from mne.io.pick import _ELECTRODE_CH_TYPES, _FNIRS_CH_TYPES_SPLIT -from mne.utils import (_TempDir, catch_logging, _raw_annot, _stamp_to_dt, - object_diff, check_version, requires_pandas, - _import_h5io_funcs) +from mne.utils import ( + _TempDir, + catch_logging, + _raw_annot, + _stamp_to_dt, + object_diff, + check_version, + requires_pandas, + _import_h5io_funcs, +) from mne.io.meas_info import _get_valid_units from mne.io._digitization import DigPoint from mne.io.proj import Projection from mne.io.utils import _mult_cal_one from mne.io.constants import FIFF -raw_fname = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', - 'data', 
'test_raw.fif') +raw_fname = op.join( + op.dirname(__file__), "..", "..", "io", "tests", "data", "test_raw.fif" +) def assert_named_constants(info): """Assert that info['chs'] has named constants.""" # for now we just check one __tracebackhide__ = True - r = repr(info['chs'][0]) - for check in ('.*FIFFV_COORD_.*', '.*FIFFV_COIL_.*', '.*FIFF_UNIT_.*', - '.*FIFF_UNITM_.*',): + r = repr(info["chs"][0]) + for check in ( + ".*FIFFV_COORD_.*", + ".*FIFFV_COIL_.*", + ".*FIFF_UNIT_.*", + ".*FIFF_UNITM_.*", + ): assert re.match(check, r, re.DOTALL) is not None, (check, r) def test_orig_units(): """Test the error handling for original units.""" # Should work fine - info = create_info(ch_names=['Cz'], sfreq=100, ch_types='eeg') - BaseRaw(info, last_samps=[1], orig_units={'Cz': 'nV'}) + info = create_info(ch_names=["Cz"], sfreq=100, ch_types="eeg") + BaseRaw(info, last_samps=[1], orig_units={"Cz": "nV"}) # Should complain that channel Cz does not have a corresponding original # unit. - with pytest.raises(ValueError, match='has no associated original unit.'): - info = create_info(ch_names=['Cz'], sfreq=100, ch_types='eeg') - BaseRaw(info, last_samps=[1], orig_units={'not_Cz': 'nV'}) + with pytest.raises(ValueError, match="has no associated original unit."): + info = create_info(ch_names=["Cz"], sfreq=100, ch_types="eeg") + BaseRaw(info, last_samps=[1], orig_units={"not_Cz": "nV"}) # Test that a non-dict orig_units argument raises a ValueError - with pytest.raises(ValueError, match='orig_units must be of type dict'): - info = create_info(ch_names=['Cz'], sfreq=100, ch_types='eeg') + with pytest.raises(ValueError, match="orig_units must be of type dict"): + info = create_info(ch_names=["Cz"], sfreq=100, ch_types="eeg") BaseRaw(info, last_samps=[1], orig_units=True) -def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, - boundary_decimal=2, test_scaling=True, test_rank=True, - **kwargs): +def _test_raw_reader( + reader, + test_preloading=True, + test_kwargs=True, + boundary_decimal=2, + test_scaling=True, + test_rank=True, + **kwargs, +): """Test reading, writing and slicing of raw classes. 
Parameters @@ -95,86 +117,91 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, rng = np.random.RandomState(0) montage = None if "montage" in kwargs: - montage = kwargs['montage'] - del kwargs['montage'] + montage = kwargs["montage"] + del kwargs["montage"] if test_preloading: raw = reader(preload=True, **kwargs) rep = repr(raw) - assert rep.count('<') == 1 - assert rep.count('>') == 1 + assert rep.count("<") == 1 + assert rep.count(">") == 1 if montage is not None: raw.set_montage(montage) # don't assume the first is preloaded - buffer_fname = op.join(tempdir, 'buffer') + buffer_fname = op.join(tempdir, "buffer") picks = rng.permutation(np.arange(len(raw.ch_names) - 1))[:10] picks = np.append(picks, len(raw.ch_names) - 1) # test trigger channel - bnd = min(int(round(raw.buffer_size_sec * - raw.info['sfreq'])), raw.n_times) - slices = [slice(0, bnd), slice(bnd - 1, bnd), slice(3, bnd), - slice(3, 300), slice(None), slice(1, bnd)] + bnd = min(int(round(raw.buffer_size_sec * raw.info["sfreq"])), raw.n_times) + slices = [ + slice(0, bnd), + slice(bnd - 1, bnd), + slice(3, bnd), + slice(3, 300), + slice(None), + slice(1, bnd), + ] if raw.n_times >= 2 * bnd: # at least two complete blocks - slices += [slice(bnd, 2 * bnd), slice(bnd, bnd + 1), - slice(0, bnd + 100)] - other_raws = [reader(preload=buffer_fname, **kwargs), - reader(preload=False, **kwargs)] + slices += [slice(bnd, 2 * bnd), slice(bnd, bnd + 1), slice(0, bnd + 100)] + other_raws = [ + reader(preload=buffer_fname, **kwargs), + reader(preload=False, **kwargs), + ] for sl_time in slices: data1, times1 = raw[picks, sl_time] for other_raw in other_raws: data2, times2 = other_raw[picks, sl_time] - assert_allclose( - data1, data2, err_msg='Data mismatch with preload') + assert_allclose(data1, data2, err_msg="Data mismatch with preload") assert_allclose(times1, times2) # test projection vs cals and data units other_raw = reader(preload=False, **kwargs) other_raw.del_proj() eeg = meg = fnirs = False - if 'eeg' in raw: + if "eeg" in raw: eeg, atol = True, 1e-18 - elif 'grad' in raw: - meg, atol = 'grad', 1e-24 - elif 'mag' in raw: - meg, atol = 'mag', 1e-24 - elif 'hbo' in raw: - fnirs, atol = 'hbo', 1e-10 - elif 'hbr' in raw: - fnirs, atol = 'hbr', 1e-10 + elif "grad" in raw: + meg, atol = "grad", 1e-24 + elif "mag" in raw: + meg, atol = "mag", 1e-24 + elif "hbo" in raw: + fnirs, atol = "hbo", 1e-10 + elif "hbr" in raw: + fnirs, atol = "hbr", 1e-10 else: - assert 'fnirs_cw_amplitude' in raw, 'New channel type necessary?' - fnirs, atol = 'fnirs_cw_amplitude', 1e-10 - picks = pick_types( - other_raw.info, meg=meg, eeg=eeg, fnirs=fnirs) + assert "fnirs_cw_amplitude" in raw, "New channel type necessary?" + fnirs, atol = "fnirs_cw_amplitude", 1e-10 + picks = pick_types(other_raw.info, meg=meg, eeg=eeg, fnirs=fnirs) col_names = [other_raw.ch_names[pick] for pick in picks] proj = np.ones((1, len(picks))) proj /= np.sqrt(proj.shape[1]) proj = Projection( - data=dict(data=proj, nrow=1, row_names=None, - col_names=col_names, ncol=len(picks)), - active=False) - assert len(other_raw.info['projs']) == 0 + data=dict( + data=proj, nrow=1, row_names=None, col_names=col_names, ncol=len(picks) + ), + active=False, + ) + assert len(other_raw.info["projs"]) == 0 other_raw.add_proj(proj) - assert len(other_raw.info['projs']) == 1 + assert len(other_raw.info["projs"]) == 1 # Orders of projector application, data loading, and reordering # equivalent: # 1. 
load->apply->get - data_load_apply_get = \ - other_raw.copy().load_data().apply_proj().get_data(picks) + data_load_apply_get = other_raw.copy().load_data().apply_proj().get_data(picks) # 2. apply->get (and don't allow apply->pick) apply = other_raw.copy().apply_proj() data_apply_get = apply.get_data(picks) data_apply_get_0 = apply.get_data(picks[0])[0] - with pytest.raises(RuntimeError, match='loaded'): + with pytest.raises(RuntimeError, match="loaded"): apply.copy().pick(picks[0]).get_data() # 3. apply->load->get data_apply_load_get = apply.copy().load_data().get_data(picks) - data_apply_load_get_0, data_apply_load_get_1 = \ + data_apply_load_get_0, data_apply_load_get_1 = ( apply.copy().load_data().pick(picks[:2]).get_data() + ) # 4. reorder->apply->load->get all_picks = np.arange(len(other_raw.ch_names)) - reord = np.concatenate(( - picks[1::2], - picks[0::2], - np.setdiff1d(all_picks, picks))) + reord = np.concatenate( + (picks[1::2], picks[0::2], np.setdiff1d(all_picks, picks)) + ) rev = np.argsort(reord) assert_array_equal(reord[rev], all_picks) assert_array_equal(rev[reord], all_picks) @@ -185,17 +212,22 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, reorder_apply = reorder.copy().apply_proj() assert reorder_apply.ch_names == reorder.ch_names assert reorder_apply.ch_names[0] == apply.ch_names[picks[1]] - assert_allclose(reorder_apply.get_data([0]), apply.get_data(picks[1]), - atol=1e-18) - data_reorder_apply_load_get = \ - reorder_apply.load_data().get_data(rev[:len(picks)]) - data_reorder_apply_load_get_1 = \ + assert_allclose( + reorder_apply.get_data([0]), apply.get_data(picks[1]), atol=1e-18 + ) + data_reorder_apply_load_get = reorder_apply.load_data().get_data( + rev[: len(picks)] + ) + data_reorder_apply_load_get_1 = ( reorder_apply.copy().load_data().pick([0]).get_data()[0] + ) assert reorder_apply.ch_names[0] == apply.ch_names[picks[1]] - assert (data_load_apply_get.shape == - data_apply_get.shape == - data_apply_load_get.shape == - data_reorder_apply_load_get.shape) + assert ( + data_load_apply_get.shape + == data_apply_get.shape + == data_apply_load_get.shape + == data_reorder_apply_load_get.shape + ) del apply # first check that our data are (probably) in the right units data = data_load_apply_get.copy() @@ -209,12 +241,12 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, else: atol = 1e-7 * np.median(data) # 1e-7 * MAD # ranks should all be reduced by 1 - if test_rank == 'less': + if test_rank == "less": cmp = np.less elif test_rank is False: cmp = None else: # anything else is like True or 'equal' - assert test_rank is True or test_rank == 'equal', test_rank + assert test_rank is True or test_rank == "equal", test_rank cmp = np.equal rank_load_apply_get = np.linalg.matrix_rank(data_load_apply_get) rank_apply_get = np.linalg.matrix_rank(data_apply_get) @@ -224,59 +256,54 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, assert cmp(rank_apply_get, len(col_names) - 1) assert cmp(rank_apply_load_get, len(col_names) - 1) # and they should all match - t_kw = dict( - atol=atol, err_msg='before != after, likely _mult_cal_one prob') + t_kw = dict(atol=atol, err_msg="before != after, likely _mult_cal_one prob") assert_allclose(data_apply_get[0], data_apply_get_0, **t_kw) - assert_allclose(data_apply_load_get_1, - data_reorder_apply_load_get_1, **t_kw) + assert_allclose(data_apply_load_get_1, data_reorder_apply_load_get_1, **t_kw) assert_allclose(data_load_apply_get[0], data_apply_load_get_0, **t_kw) 
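The rank checks in this section (rank equal to len(col_names) - 1) follow from the linear algebra of SSP: a projector built from one normalized vector v is P = I - v v^T, which removes exactly one dimension from channel space no matter when it is applied relative to loading or reordering. A small self-contained illustration (generic NumPy, independent of any reader under test):

import numpy as np

rng = np.random.RandomState(0)
n_ch = 10
v = np.ones((n_ch, 1)) / np.sqrt(n_ch)  # normalized, like proj /= np.sqrt(...) above
P = np.eye(n_ch) - v @ v.T              # SSP operator projecting out span(v)
data = rng.randn(n_ch, 1000)            # full-rank random "recording"
assert np.linalg.matrix_rank(data) == n_ch
assert np.linalg.matrix_rank(P @ data) == n_ch - 1  # exactly one dimension removed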
assert_allclose(data_load_apply_get, data_apply_get, **t_kw) assert_allclose(data_load_apply_get, data_apply_load_get, **t_kw) - if 'eeg' in raw: + if "eeg" in raw: other_raw.del_proj() - direct = \ - other_raw.copy().load_data().set_eeg_reference().get_data() + direct = other_raw.copy().load_data().set_eeg_reference().get_data() other_raw.set_eeg_reference(projection=True) - assert len(other_raw.info['projs']) == 1 - this_proj = other_raw.info['projs'][0]['data'] - assert this_proj['col_names'] == col_names - assert this_proj['data'].shape == proj['data']['data'].shape - assert_allclose( - np.linalg.norm(proj['data']['data']), 1., atol=1e-6) - assert_allclose( - np.linalg.norm(this_proj['data']), 1., atol=1e-6) - assert_allclose(this_proj['data'], proj['data']['data']) + assert len(other_raw.info["projs"]) == 1 + this_proj = other_raw.info["projs"][0]["data"] + assert this_proj["col_names"] == col_names + assert this_proj["data"].shape == proj["data"]["data"].shape + assert_allclose(np.linalg.norm(proj["data"]["data"]), 1.0, atol=1e-6) + assert_allclose(np.linalg.norm(this_proj["data"]), 1.0, atol=1e-6) + assert_allclose(this_proj["data"], proj["data"]["data"]) proj = other_raw.apply_proj().get_data() assert_allclose(proj[picks], data_load_apply_get, atol=1e-10) - assert_allclose(proj, direct, atol=1e-10, err_msg=t_kw['err_msg']) + assert_allclose(proj, direct, atol=1e-10, err_msg=t_kw["err_msg"]) else: raw = reader(**kwargs) n_samp = len(raw.times) assert_named_constants(raw.info) # smoke test for gh #9743 - ids = [id(ch['loc']) for ch in raw.info['chs']] + ids = [id(ch["loc"]) for ch in raw.info["chs"]] assert len(set(ids)) == len(ids) full_data = raw._data assert raw.__class__.__name__ in repr(raw) # to test repr assert raw.info.__class__.__name__ in repr(raw.info) - assert isinstance(raw.info['dig'], (type(None), list)) + assert isinstance(raw.info["dig"], (type(None), list)) data_max = full_data.max() data_min = full_data.min() # these limits could be relaxed if we actually find data with # huge values (in SI units) assert data_max < 1e5 assert data_min > -1e5 - if isinstance(raw.info['dig'], list): - for di, d in enumerate(raw.info['dig']): + if isinstance(raw.info["dig"], list): + for di, d in enumerate(raw.info["dig"]): assert isinstance(d, DigPoint), (di, d) # gh-5604 - meas_date = raw.info['meas_date'] + meas_date = raw.info["meas_date"] assert meas_date is None or meas_date >= _stamp_to_dt((0, 0)) # test repr_html - assert 'Good channels' in raw.info._repr_html_() + assert "Good channels" in raw.info._repr_html_() # test resetting raw if test_kwargs: @@ -285,28 +312,29 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, assert_array_equal(raw.times, raw2.times) # Test saving and reading - out_fname = op.join(tempdir, 'test_raw.fif') + out_fname = op.join(tempdir, "test_raw.fif") raw = concatenate_raws([raw]) raw.save(out_fname, tmax=raw.times[-1], overwrite=True, buffer_size_sec=1) # Test saving with not correct extension - out_fname_h5 = op.join(tempdir, 'test_raw.h5') - with pytest.raises(OSError, match='raw must end with .fif or .fif.gz'): + out_fname_h5 = op.join(tempdir, "test_raw.h5") + with pytest.raises(OSError, match="raw must end with .fif or .fif.gz"): raw.save(out_fname_h5) raw3 = read_raw_fif(out_fname) assert_named_constants(raw3.info) assert set(raw.info.keys()) == set(raw3.info.keys()) - assert_allclose(raw3[0:20][0], full_data[0:20], rtol=1e-6, - atol=1e-20) # atol is very small but > 0 + assert_allclose( + raw3[0:20][0], full_data[0:20], 
rtol=1e-6, atol=1e-20 + ) # atol is very small but > 0 assert_allclose(raw.times, raw3.times, atol=1e-6, rtol=1e-6) - assert not math.isnan(raw3.info['highpass']) - assert not math.isnan(raw3.info['lowpass']) - assert not math.isnan(raw.info['highpass']) - assert not math.isnan(raw.info['lowpass']) + assert not math.isnan(raw3.info["highpass"]) + assert not math.isnan(raw3.info["lowpass"]) + assert not math.isnan(raw.info["highpass"]) + assert not math.isnan(raw.info["lowpass"]) - assert raw3.info['kit_system_id'] == raw.info['kit_system_id'] + assert raw3.info["kit_system_id"] == raw.info["kit_system_id"] # Make sure concatenation works first_samp = raw.first_samp @@ -315,19 +343,22 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, assert concat_raw.n_times == 2 * raw.n_times assert concat_raw.first_samp == first_samp assert concat_raw.last_samp - last_samp + first_samp == last_samp + 1 - idx = np.where(concat_raw.annotations.description == 'BAD boundary')[0] + idx = np.where(concat_raw.annotations.description == "BAD boundary")[0] expected_bad_boundary_onset = raw._last_time - assert_array_almost_equal(concat_raw.annotations.onset[idx], - expected_bad_boundary_onset, - decimal=boundary_decimal) + assert_array_almost_equal( + concat_raw.annotations.onset[idx], + expected_bad_boundary_onset, + decimal=boundary_decimal, + ) - if raw.info['meas_id'] is not None: - for key in ['secs', 'usecs', 'version']: - assert raw.info['meas_id'][key] == raw3.info['meas_id'][key] - assert_array_equal(raw.info['meas_id']['machid'], - raw3.info['meas_id']['machid']) + if raw.info["meas_id"] is not None: + for key in ["secs", "usecs", "version"]: + assert raw.info["meas_id"][key] == raw3.info["meas_id"][key] + assert_array_equal( + raw.info["meas_id"]["machid"], raw3.info["meas_id"]["machid"] + ) assert isinstance(raw.annotations, Annotations) @@ -351,14 +382,15 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, these_kwargs = kwargs.copy() these_kwargs.update(preload_kwarg) # don't use the same filename or it could create problems - if isinstance(these_kwargs.get('preload', None), str) and \ - op.isfile(these_kwargs['preload']): - these_kwargs['preload'] += '-1' + if isinstance(these_kwargs.get("preload", None), str) and op.isfile( + these_kwargs["preload"] + ): + these_kwargs["preload"] += "-1" whole_raw = reader(**these_kwargs) print(whole_raw) # __repr__ assert n_ch >= 2 - picks_1 = picks[:n_ch // 2] - picks_2 = picks[n_ch // 2:] + picks_1 = picks[: n_ch // 2] + picks_2 = picks[n_ch // 2 :] raw_1 = whole_raw.copy().pick(picks_1) raw_2 = whole_raw.copy().pick(picks_2) data, times = whole_raw[:] @@ -366,30 +398,34 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, data_2, times_2 = raw_2[:] assert_array_equal(times, times_1) assert_array_equal(data[picks_1], data_1) - assert_array_equal(times, times_2,) + assert_array_equal( + times, + times_2, + ) assert_array_equal(data[picks_2], data_2) # Make sure that writing info to h5 format # (all fields should be compatible) - if check_version('h5io'): + if check_version("h5io"): read_hdf5, write_hdf5 = _import_h5io_funcs() - fname_h5 = op.join(tempdir, 'info.h5') + fname_h5 = op.join(tempdir, "info.h5") with _writing_info_hdf5(raw.info), _numpy_h5py_dep(): write_hdf5(fname_h5, raw.info) new_info = Info(read_hdf5(fname_h5)) - assert object_diff(new_info, raw.info) == '' + assert object_diff(new_info, raw.info) == "" # Make sure that changing directory does not break anything if test_preloading: 
these_kwargs = kwargs.copy() key = None - for key in ('fname', - 'input_fname', # artemis123 - 'vhdr_fname', # BV - 'pdf_fname', # BTi - 'directory', # CTF - 'filename', # nedf - ): + for key in ( + "fname", + "input_fname", # artemis123 + "vhdr_fname", # BV + "pdf_fname", # BTi + "directory", # CTF + "filename", # nedf + ): try: fname = kwargs[key] except KeyError: @@ -402,7 +438,7 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, this_fname = fname[0] if isinstance(fname, list) else fname dirname = op.dirname(this_fname) these_kwargs[key] = op.basename(this_fname) - these_kwargs['preload'] = False + these_kwargs["preload"] = False orig_dir = os.getcwd() try: os.chdir(dirname) @@ -413,40 +449,42 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, # make sure that cropping works (with first_samp shift) if n_samp >= 50: # we crop to this number of samples below - for t_prop in (0., 0.5): + for t_prop in (0.0, 0.5): _test_raw_crop(reader, t_prop, kwargs) if test_preloading: use_kwargs = kwargs.copy() - use_kwargs['preload'] = True + use_kwargs["preload"] = True _test_raw_crop(reader, t_prop, use_kwargs) # make sure electrode-like sensor locations show up as dig points - eeg_dig = [d for d in (raw.info['dig'] or []) - if d['kind'] == _dig_kind_dict['eeg']] + eeg_dig = [d for d in (raw.info["dig"] or []) if d["kind"] == _dig_kind_dict["eeg"]] pick_kwargs = dict() - for t in _ELECTRODE_CH_TYPES + ('fnirs',): + for t in _ELECTRODE_CH_TYPES + ("fnirs",): pick_kwargs[t] = True dig_picks = pick_types(raw.info, exclude=(), **pick_kwargs) dig_types = _ELECTRODE_CH_TYPES + _FNIRS_CH_TYPES_SPLIT assert (len(dig_picks) > 0) == any(t in raw for t in dig_types) if len(dig_picks): - eeg_loc = np.array([ # eeg_loc a bit of a misnomer to match eeg_dig - raw.info['chs'][pick]['loc'][:3] for pick in dig_picks]) + eeg_loc = np.array( + [ # eeg_loc a bit of a misnomer to match eeg_dig + raw.info["chs"][pick]["loc"][:3] for pick in dig_picks + ] + ) eeg_loc = eeg_loc[np.isfinite(eeg_loc).all(axis=1)] if len(eeg_loc): - if 'fnirs_cw_amplitude' in raw: + if "fnirs_cw_amplitude" in raw: assert 2 * len(eeg_dig) >= len(eeg_loc) else: assert len(eeg_dig) >= len(eeg_loc) # could have some excluded # make sure that dig points in head coords implies that fiducials are # present - if len(raw.info['dig'] or []) > 0: - card_pts = [d for d in raw.info['dig'] - if d['kind'] == _dig_kind_dict['cardinal']] - eeg_dig_head = [ - d for d in eeg_dig if d['coord_frame'] == FIFF.FIFFV_COORD_HEAD] + if len(raw.info["dig"] or []) > 0: + card_pts = [ + d for d in raw.info["dig"] if d["kind"] == _dig_kind_dict["cardinal"] + ] + eeg_dig_head = [d for d in eeg_dig if d["coord_frame"] == FIFF.FIFFV_COORD_HEAD] if len(eeg_dig_head): - assert len(card_pts) == 3, 'Cardinal points missing' + assert len(card_pts) == 3, "Cardinal points missing" if len(card_pts) == 3: # they should all be in head coords then assert len(eeg_dig_head) == len(eeg_dig) @@ -456,17 +494,17 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, def _test_raw_crop(reader, t_prop, kwargs): raw_1 = reader(**kwargs) n_samp = 50 # crop to this number of samples (per instance) - crop_t = n_samp / raw_1.info['sfreq'] + crop_t = n_samp / raw_1.info["sfreq"] t_start = t_prop * crop_t # also crop to some fraction into the first inst extra = f' t_start={t_start}, preload={kwargs.get("preload", False)}' - stop = (n_samp - 1) / raw_1.info['sfreq'] + stop = (n_samp - 1) / raw_1.info["sfreq"] raw_1.crop(0, stop) assert 
len(raw_1.times) == 50 first_time = raw_1.first_time - atol = 0.5 / raw_1.info['sfreq'] + atol = 0.5 / raw_1.info["sfreq"] assert_allclose(raw_1.times[-1], stop, atol=atol) raw_2, raw_3 = raw_1.copy(), raw_1.copy() - t_tot = raw_1.times[-1] * 3 + 2. / raw_1.info['sfreq'] + t_tot = raw_1.times[-1] * 3 + 2.0 / raw_1.info["sfreq"] raw_concat = concatenate_raws([raw_1, raw_2, raw_3]) assert len(raw_concat._filenames) == 3 assert_allclose(raw_concat.times[-1], t_tot) @@ -476,24 +514,31 @@ def _test_raw_crop(reader, t_prop, kwargs): assert len(raw_concat._filenames) == 3 assert_allclose(raw_concat.times[-1], t_tot - t_start, atol=atol) assert_allclose( - raw_concat.first_time, first_time + t_start, atol=atol, - err_msg=f'Base concat, {extra}') + raw_concat.first_time, + first_time + t_start, + atol=atol, + err_msg=f"Base concat, {extra}", + ) # drop the first instance raw_concat.crop(crop_t, None) assert len(raw_concat._filenames) == 2 + assert_allclose(raw_concat.times[-1], t_tot - t_start - crop_t, atol=atol) assert_allclose( - raw_concat.times[-1], t_tot - t_start - crop_t, atol=atol) - assert_allclose( - raw_concat.first_time, first_time + t_start + crop_t, - atol=atol, err_msg=f'Dropping one, {extra}') + raw_concat.first_time, + first_time + t_start + crop_t, + atol=atol, + err_msg=f"Dropping one, {extra}", + ) # drop the second instance, leaving just one raw_concat.crop(crop_t, None) assert len(raw_concat._filenames) == 1 + assert_allclose(raw_concat.times[-1], t_tot - t_start - 2 * crop_t, atol=atol) assert_allclose( - raw_concat.times[-1], t_tot - t_start - 2 * crop_t, atol=atol) - assert_allclose( - raw_concat.first_time, first_time + t_start + 2 * crop_t, - atol=atol, err_msg=f'Dropping two, {extra}') + raw_concat.first_time, + first_time + t_start + 2 * crop_t, + atol=atol, + err_msg=f"Dropping two, {extra}", + ) def _test_concat(reader, *args): @@ -514,8 +559,7 @@ def _test_concat(reader, *args): data = raw[:, :][0] for preloads in ((True, True), (True, False), (False, False)): for last_preload in (True, False): - t_crops = raw.times[np.argmin(np.abs(raw.times - 0.5)) + - [0, 1]] + t_crops = raw.times[np.argmin(np.abs(raw.times - 0.5)) + [0, 1]] raw1 = raw.copy().crop(0, t_crops[0]) if preloads[0]: raw1.load_data() @@ -542,8 +586,8 @@ def test_time_as_index(): assert_array_equal(new_inds, np.arange(len(raw.times))) -@pytest.mark.parametrize('meas_date', [None, "orig"]) -@pytest.mark.parametrize('first_samp', [0, 10000]) +@pytest.mark.parametrize("meas_date", [None, "orig"]) +@pytest.mark.parametrize("first_samp", [0, 10000]) def test_crop_by_annotations(meas_date, first_samp): """Test crop by annotations of raw.""" raw = read_raw_fif(raw_fname) @@ -560,7 +604,8 @@ def test_crop_by_annotations(meas_date, first_samp): onset=onset, duration=[1, 0.5], description=["a", "b"], - orig_time=raw.info['meas_date']) + orig_time=raw.info["meas_date"], + ) raw.set_annotations(annot) raws = raw.crop_by_annotations() @@ -573,21 +618,23 @@ def test_crop_by_annotations(meas_date, first_samp): assert raws[1].annotations.description[0] == annot.description[1] -@pytest.mark.parametrize('offset, origin', [ - pytest.param(0, None, id='times in s. relative to first_samp (default)'), - pytest.param(0, 2.0, id='times in s. relative to first_samp'), - pytest.param(1, 1.0, id='times in s. relative to meas_date'), - pytest.param(2, 0.0, id='absolute times in s. relative to 0')]) +@pytest.mark.parametrize( + "offset, origin", + [ + pytest.param(0, None, id="times in s. 
relative to first_samp (default)"), + pytest.param(0, 2.0, id="times in s. relative to first_samp"), + pytest.param(1, 1.0, id="times in s. relative to meas_date"), + pytest.param(2, 0.0, id="absolute times in s. relative to 0"), + ], +) def test_time_as_index_ref(offset, origin): """Test indexing of raw times.""" - info = create_info(ch_names=10, sfreq=10.) + info = create_info(ch_names=10, sfreq=10.0) raw = RawArray(data=np.empty((10, 10)), info=info, first_samp=10) raw.set_meas_date(1) relative_times = raw.times - inds = raw.time_as_index(relative_times + offset, - use_rounding=True, - origin=origin) + inds = raw.time_as_index(relative_times + offset, use_rounding=True, origin=origin) assert_array_equal(inds, np.arange(raw.n_times)) @@ -611,7 +658,7 @@ def test_meas_date_orig_time(): # Raise error, it makes no sense to have an annotations object that we know # when was acquired and set it to a raw object that does not know when was # it acquired. - with pytest.raises(RuntimeError, match='Ambiguous operation'): + with pytest.raises(RuntimeError, match="Ambiguous operation"): _raw_annot(None, 1.5) # meas_time is None and orig_time is None: @@ -628,19 +675,22 @@ def test_get_data_reject(): ch_names = ["C3", "Cz", "C4"] info = create_info(ch_names, sfreq=fs) raw = RawArray(np.zeros((len(ch_names), 10 * fs)), info) - raw.set_annotations(Annotations(onset=[2, 4], duration=[3, 2], - description="bad")) + raw.set_annotations(Annotations(onset=[2, 4], duration=[3, 2], description="bad")) with catch_logging() as log: data = raw.get_data(reject_by_annotation="omit", verbose=True) - msg = ('Omitting 1024 of 2560 (40.00%) samples, retaining 1536' + - ' (60.00%) samples.') + msg = ( + "Omitting 1024 of 2560 (40.00%) samples, retaining 1536" + + " (60.00%) samples." + ) assert log.getvalue().strip() == msg assert data.shape == (len(ch_names), 1536) with catch_logging() as log: data = raw.get_data(reject_by_annotation="nan", verbose=True) - msg = ('Setting 1024 of 2560 (40.00%) samples to NaN, retaining 1536' + - ' (60.00%) samples.') + msg = ( + "Setting 1024 of 2560 (40.00%) samples to NaN, retaining 1536" + + " (60.00%) samples." + ) assert log.getvalue().strip() == msg assert data.shape == (len(ch_names), 2560) # shape doesn't change assert np.isnan(data).sum() == 3072 # but NaNs are introduced instead @@ -665,19 +715,22 @@ def test_5839(): # latency . 0 0 1 1 2 2 3 # . 
5 0 5 0 5 0 # - EXPECTED_ONSET = [1.5, 2., 2., 2.5] - EXPECTED_DURATION = [0.2, 0., 0., 0.2] - EXPECTED_DESCRIPTION = ['dummy', 'BAD boundary', 'EDGE boundary', 'dummy'] + EXPECTED_ONSET = [1.5, 2.0, 2.0, 2.5] + EXPECTED_DURATION = [0.2, 0.0, 0.0, 0.2] + EXPECTED_DESCRIPTION = ["dummy", "BAD boundary", "EDGE boundary", "dummy"] def raw_factory(meas_date): - raw = RawArray(data=np.empty((10, 10)), - info=create_info(ch_names=10, sfreq=10.), - first_samp=10) + raw = RawArray( + data=np.empty((10, 10)), + info=create_info(ch_names=10, sfreq=10.0), + first_samp=10, + ) raw.set_meas_date(meas_date) - raw.set_annotations(annotations=Annotations(onset=[.5], - duration=[.2], - description='dummy', - orig_time=None)) + raw.set_annotations( + annotations=Annotations( + onset=[0.5], duration=[0.2], description="dummy", orig_time=None + ) + ) return raw raw_A, raw_B = [raw_factory((x, 0)) for x in [0, 2]] @@ -695,16 +748,16 @@ def test_repr(): info = create_info(3, sfreq) raw = RawArray(np.zeros((3, 10 * sfreq)), info) r = repr(raw) - assert re.search('', - r) is not None, r + assert ( + re.search("", r) is not None + ), r assert raw._repr_html_() # A class that sets channel data to np.arange, for testing _test_raw_reader class _RawArange(BaseRaw): - def __init__(self, preload=False, verbose=None): - info = create_info(list(str(x) for x in range(1, 9)), 1000., 'eeg') + info = create_info(list(str(x) for x in range(1, 9)), 1000.0, "eeg") super().__init__(info, preload, last_samps=(999,), verbose=verbose) assert len(self.times) == 1000 @@ -720,7 +773,7 @@ def _read_raw_arange(preload=False, verbose=None): def test_test_raw_reader(): """Test _test_raw_reader.""" - _test_raw_reader(_read_raw_arange, test_scaling=False, test_rank='less') + _test_raw_reader(_read_raw_arange, test_scaling=False, test_rank="less") @pytest.mark.slowtest @@ -736,17 +789,25 @@ def test_describe_print(): s = f.getvalue().strip().split("\n") assert len(s) == 378 # Can be 3.1, 3.3, etc. 
- assert re.match( - r'', s[0]) is not None, s[0] assert ( - s[1] == " ch name type unit min Q1 median Q3 max" # noqa: E501 + re.match( + r"", + s[0], + ) + is not None + ), s[0] + assert ( + s[1] + == " ch name type unit min Q1 median Q3 max" # noqa: E501 ) assert ( - s[2] == " 0 MEG 0113 GRAD fT/cm -221.80 -38.57 -9.64 19.29 414.67" # noqa: E501 + s[2] + == " 0 MEG 0113 GRAD fT/cm -221.80 -38.57 -9.64 19.29 414.67" # noqa: E501 ) assert ( - s[-1] == "375 EOG 061 EOG µV -231.41 271.28 277.16 285.66 334.69" # noqa: E501 + s[-1] + == "375 EOG 061 EOG µV -231.41 271.28 277.16 285.66 334.69" # noqa: E501 ) @@ -759,36 +820,50 @@ def test_describe_df(): df = raw.describe(data_frame=True) assert df.shape == (376, 8) - assert (df.columns.tolist() == ["name", "type", "unit", "min", "Q1", - "median", "Q3", "max"]) + assert df.columns.tolist() == [ + "name", + "type", + "unit", + "min", + "Q1", + "median", + "Q3", + "max", + ] assert df.index.name == "ch" - assert_allclose(df.iloc[0, 3:].astype(float), - np.array([-2.218017605790535e-11, - -3.857421923113974e-12, - -9.643554807784935e-13, - 1.928710961556987e-12, - 4.146728567347522e-11])) + assert_allclose( + df.iloc[0, 3:].astype(float), + np.array( + [ + -2.218017605790535e-11, + -3.857421923113974e-12, + -9.643554807784935e-13, + 1.928710961556987e-12, + 4.146728567347522e-11, + ] + ), + ) def test_get_data_units(): """Test the "units" argument of get_data method.""" # Test the unit conversion function - assert _get_scaling('eeg', 'uV') == 1e6 - assert _get_scaling('eeg', 'dV') == 1e1 - assert _get_scaling('eeg', 'pV') == 1e12 - assert _get_scaling('mag', 'fT') == 1e15 - assert _get_scaling('grad', 'T/m') == 1 - assert _get_scaling('grad', 'T/mm') == 1e-3 - assert _get_scaling('grad', 'fT/m') == 1e15 - assert _get_scaling('grad', 'fT/cm') == 1e13 - assert _get_scaling('csd', 'uV/cm²') == 1e2 + assert _get_scaling("eeg", "uV") == 1e6 + assert _get_scaling("eeg", "dV") == 1e1 + assert _get_scaling("eeg", "pV") == 1e12 + assert _get_scaling("mag", "fT") == 1e15 + assert _get_scaling("grad", "T/m") == 1 + assert _get_scaling("grad", "T/mm") == 1e-3 + assert _get_scaling("grad", "fT/m") == 1e15 + assert _get_scaling("grad", "fT/cm") == 1e13 + assert _get_scaling("csd", "uV/cm²") == 1e2 fname = Path(__file__).parent / "data" / "test_raw.fif" raw = read_raw_fif(fname) last = np.array([4.63803098e-05, 7.66563736e-05, 2.71933595e-04]) last_eeg = np.array([7.12207023e-05, 4.63803098e-05, 7.66563736e-05]) - last_grad = np.array([-3.85742192e-12, 9.64355481e-13, -1.06079103e-11]) + last_grad = np.array([-3.85742192e-12, 9.64355481e-13, -1.06079103e-11]) # None data_none = raw.get_data() @@ -796,62 +871,68 @@ def test_get_data_units(): assert_array_almost_equal(data_none[-3:, -1], last) # str: unit no conversion - data_str_noconv = raw.get_data(picks=['eeg'], units='V') + data_str_noconv = raw.get_data(picks=["eeg"], units="V") assert data_str_noconv.shape == (60, 14400) assert_array_almost_equal(data_str_noconv[-3:, -1], last_eeg) # str: simple unit - data_str_simple = raw.get_data(picks=['eeg'], units='uV') + data_str_simple = raw.get_data(picks=["eeg"], units="uV") assert data_str_simple.shape == (60, 14400) assert_array_almost_equal(data_str_simple[-3:, -1], last_eeg * 1e6) # str: fraction unit - data_str_fraction = raw.get_data(picks=['grad'], units='fT/cm') + data_str_fraction = raw.get_data(picks=["grad"], units="fT/cm") assert data_str_fraction.shape == (204, 14400) - assert_array_almost_equal(data_str_fraction[-3:, -1], - last_grad * (1e15 / 1e2)) + 
assert_array_almost_equal(data_str_fraction[-3:, -1], last_grad * (1e15 / 1e2)) # str: more than one channel type but one with unit - data_str_simplestim = raw.get_data(picks=['eeg', 'stim'], units='V') + data_str_simplestim = raw.get_data(picks=["eeg", "stim"], units="V") assert data_str_simplestim.shape == (69, 14400) assert_array_almost_equal(data_str_simplestim[-3:, -1], last_eeg) # str: too many channels - with pytest.raises(ValueError, match='more than one channel'): - raw.get_data(units='uV') + with pytest.raises(ValueError, match="more than one channel"): + raw.get_data(units="uV") # str: invalid unit - with pytest.raises(ValueError, match='is not a valid unit'): - raw.get_data(picks=['eeg'], units='fV/cm') + with pytest.raises(ValueError, match="is not a valid unit"): + raw.get_data(picks=["eeg"], units="fV/cm") # dict: combination of simple and fraction units - data_dict = raw.get_data(units=dict(grad='fT/cm', mag='fT', eeg='uV')) + data_dict = raw.get_data(units=dict(grad="fT/cm", mag="fT", eeg="uV")) assert data_dict.shape == (376, 14400) - assert_array_almost_equal(data_dict[0, -1], - -3.857421923113974e-12 * (1e15 / 1e2)) + assert_array_almost_equal(data_dict[0, -1], -3.857421923113974e-12 * (1e15 / 1e2)) assert_array_almost_equal(data_dict[2, -1], -2.1478272253525944e-13 * 1e15) assert_array_almost_equal(data_dict[-2, -1], 7.665637356879529e-05 * 1e6) # dict: channel type not in instance - data_dict_notin = raw.get_data(units=dict(hbo='uM')) + data_dict_notin = raw.get_data(units=dict(hbo="uM")) assert data_dict_notin.shape == (376, 14400) assert_array_almost_equal(data_dict_notin[-3:, -1], last) # dict: one invalid unit - with pytest.raises(ValueError, match='is not a valid unit'): - raw.get_data(units=dict(grad='fT/cV', mag='fT', eeg='uV')) + with pytest.raises(ValueError, match="is not a valid unit"): + raw.get_data(units=dict(grad="fT/cV", mag="fT", eeg="uV")) # dict: one invalid channel type - with pytest.raises(KeyError, match='is not a channel type'): - raw.get_data(units=dict(bad_type='fT/cV', mag='fT', eeg='uV')) + with pytest.raises(KeyError, match="is not a channel type"): + raw.get_data(units=dict(bad_type="fT/cV", mag="fT", eeg="uV")) # not the good type - with pytest.raises(TypeError, match='instance of None, str, or dict'): - raw.get_data(units=['fT/cm', 'fT', 'uV']) + with pytest.raises(TypeError, match="instance of None, str, or dict"): + raw.get_data(units=["fT/cm", "fT", "uV"]) def test_repr_dig_point(): """Test printing of DigPoint.""" - dp = DigPoint(r=np.arange(3), coord_frame=FIFF.FIFFV_COORD_HEAD, - kind=FIFF.FIFFV_POINT_EEG, ident=0) - assert 'mm' in repr(dp) + dp = DigPoint( + r=np.arange(3), + coord_frame=FIFF.FIFFV_COORD_HEAD, + kind=FIFF.FIFFV_POINT_EEG, + ident=0, + ) + assert "mm" in repr(dp) - dp = DigPoint(r=np.arange(3), coord_frame=FIFF.FIFFV_MNE_COORD_MRI_VOXEL, - kind=FIFF.FIFFV_POINT_CARDINAL, ident=0) - assert 'mm' not in repr(dp) - assert 'voxel' in repr(dp) + dp = DigPoint( + r=np.arange(3), + coord_frame=FIFF.FIFFV_MNE_COORD_MRI_VOXEL, + kind=FIFF.FIFFV_POINT_CARDINAL, + ident=0, + ) + assert "mm" not in repr(dp) + assert "voxel" in repr(dp) def test_get_data_tmin_tmax(): @@ -865,7 +946,7 @@ def test_get_data_tmin_tmax(): d2 = raw.get_data(tmin=tmin, tmax=tmax) idxs = raw.time_as_index([tmin, tmax]) - assert_allclose(d1[:, idxs[0]:idxs[1]], d2) + assert_allclose(d1[:, idxs[0] : idxs[1]], d2) # specifying a too low tmin truncates to idx 0 d3 = raw.get_data(tmin=-5) @@ -880,14 +961,14 @@ def test_get_data_tmin_tmax(): assert 
d5.shape[1] == 1 # validate inputs are properly raised - with pytest.raises(TypeError, match='start must be .* int'): + with pytest.raises(TypeError, match="start must be .* int"): raw.get_data(start=None) - with pytest.raises(TypeError, match='stop must be .* int'): + with pytest.raises(TypeError, match="stop must be .* int"): raw.get_data(stop=2.3) - with pytest.raises(TypeError, match='tmin must be .* float'): + with pytest.raises(TypeError, match="tmin must be .* float"): raw.get_data(tmin=[1, 2]) - with pytest.raises(TypeError, match='tmax must be .* float'): + with pytest.raises(TypeError, match="tmax must be .* float"): raw.get_data(tmax=[1, 2]) diff --git a/mne/io/tests/test_read_raw.py b/mne/io/tests/test_read_raw.py index 13c696f0f17..1a901fd211c 100644 --- a/mne/io/tests/test_read_raw.py +++ b/mne/io/tests/test_read_raw.py @@ -18,52 +18,57 @@ test_base = Path(testing.data_path(download=False)) -@pytest.mark.parametrize('fname', ['x.xxx', 'x']) +@pytest.mark.parametrize("fname", ["x.xxx", "x"]) def test_read_raw_unsupported_single(fname): """Test handling of unsupported file types.""" - with pytest.raises(ValueError, match='Unsupported file type'): + with pytest.raises(ValueError, match="Unsupported file type"): read_raw(fname) -@pytest.mark.parametrize('fname', ['x.bin']) +@pytest.mark.parametrize("fname", ["x.bin"]) def test_read_raw_unsupported_multi(fname, tmp_path): """Test handling of supported file types but with bad data.""" fname = tmp_path / fname - fname.write_text('') - with pytest.raises(RuntimeError, match='Could not read.*using any'): + fname.write_text("") + with pytest.raises(RuntimeError, match="Could not read.*using any"): read_raw(fname) -@pytest.mark.parametrize('fname', ['x.vmrk', 'y.amrk']) +@pytest.mark.parametrize("fname", ["x.vmrk", "y.amrk"]) def test_read_raw_suggested(fname): """Test handling of unsupported file types with suggested alternatives.""" - with pytest.raises(ValueError, match='Try reading'): + with pytest.raises(ValueError, match="Try reading"): read_raw(fname) _testing_mark = testing._pytest_mark() -@pytest.mark.parametrize('fname', [ - base / 'tests/data/test_raw.fif', - base / 'tests/data/test_raw.fif.gz', - base / 'edf/tests/data/test.edf', - base / 'edf/tests/data/test.bdf', - base / 'brainvision/tests/data/test.vhdr', - base / 'kit/tests/data/test.sqd', - pytest.param(test_base / 'KIT' / 'data_berlin.con', marks=_testing_mark), - pytest.param( - test_base / 'ARTEMIS123' / - 'Artemis_Data_2017-04-14-10h-38m-59s_Phantom_1k_HPI_1s.bin', - marks=_testing_mark), - pytest.param( - test_base / 'FIL' / - 'sub-noise_ses-001_task-noise220622_run-001_meg.bin', - marks=( - _testing_mark, - pytest.mark.filterwarnings( - 'ignore:.*problems later!:RuntimeWarning'))), -]) +@pytest.mark.parametrize( + "fname", + [ + base / "tests/data/test_raw.fif", + base / "tests/data/test_raw.fif.gz", + base / "edf/tests/data/test.edf", + base / "edf/tests/data/test.bdf", + base / "brainvision/tests/data/test.vhdr", + base / "kit/tests/data/test.sqd", + pytest.param(test_base / "KIT" / "data_berlin.con", marks=_testing_mark), + pytest.param( + test_base + / "ARTEMIS123" + / "Artemis_Data_2017-04-14-10h-38m-59s_Phantom_1k_HPI_1s.bin", + marks=_testing_mark, + ), + pytest.param( + test_base / "FIL" / "sub-noise_ses-001_task-noise220622_run-001_meg.bin", + marks=( + _testing_mark, + pytest.mark.filterwarnings("ignore:.*problems later!:RuntimeWarning"), + ), + ), + ], +) def test_read_raw_supported(fname): """Test supported file types.""" read_raw(fname) @@ -85,7 
+90,7 @@ def test_split_name_ext(): def test_read_raw_multiple_dots(tmp_path): """Test if file names with multiple dots work correctly.""" - src = base / 'edf/tests/data/test.edf' + src = base / "edf/tests/data/test.edf" dst = tmp_path / "test.this.file.edf" copyfile(src, dst) read_raw(dst) diff --git a/mne/io/tests/test_reference.py b/mne/io/tests/test_reference.py index 0cfb2a5349e..b6ab8af2515 100644 --- a/mne/io/tests/test_reference.py +++ b/mne/io/tests/test_reference.py @@ -12,12 +12,22 @@ from numpy.testing import assert_array_equal, assert_allclose, assert_equal import pytest -from mne import (pick_channels, pick_types, Epochs, read_events, - set_eeg_reference, set_bipolar_reference, - add_reference_channels, create_info, make_sphere_model, - make_forward_solution, setup_volume_source_space, - pick_channels_forward, read_evokeds, - find_events) +from mne import ( + pick_channels, + pick_types, + Epochs, + read_events, + set_eeg_reference, + set_bipolar_reference, + add_reference_channels, + create_info, + make_sphere_model, + make_forward_solution, + setup_volume_source_space, + pick_channels_forward, + read_evokeds, + find_events, +) from mne.epochs import BaseEpochs, make_fixed_length_epochs from mne.io import RawArray, read_raw_fif from mne.io.constants import FIFF @@ -37,9 +47,10 @@ def _test_reference(raw, reref, ref_data, ref_from): """Test whether a reference has been correctly applied.""" # Separate EEG channels from other channel types - picks_eeg = pick_types(raw.info, meg=False, eeg=True, exclude='bads') - picks_other = pick_types(raw.info, meg=True, eeg=False, eog=True, - stim=True, exclude='bads') + picks_eeg = pick_types(raw.info, meg=False, eeg=True, exclude="bads") + picks_other = pick_types( + raw.info, meg=True, eeg=False, eog=True, stim=True, exclude="bads" + ) # Calculate indices of reference channesl picks_ref = [raw.ch_names.index(ch) for ch in ref_from] @@ -78,16 +89,15 @@ def test_apply_reference(): raw = read_raw_fif(fif_fname, preload=True) # Rereference raw data by creating a copy of original data - reref, ref_data = _apply_reference( - raw.copy(), ref_from=['EEG 001', 'EEG 002']) - assert reref.info['custom_ref_applied'] - _test_reference(raw, reref, ref_data, ['EEG 001', 'EEG 002']) + reref, ref_data = _apply_reference(raw.copy(), ref_from=["EEG 001", "EEG 002"]) + assert reref.info["custom_ref_applied"] + _test_reference(raw, reref, ref_data, ["EEG 001", "EEG 002"]) # The CAR reference projection should have been removed by the function assert not _has_eeg_average_ref_proj(reref.info) # Test that data is modified in place when copy=False - reref, ref_data = _apply_reference(raw, ['EEG 001', 'EEG 002']) + reref, ref_data = _apply_reference(raw, ["EEG 001", "EEG 002"]) assert raw is reref # Test that disabling the reference does not change anything @@ -98,23 +108,28 @@ def test_apply_reference(): raw = read_raw_fif(fif_fname, preload=False) events = read_events(eve_fname) picks_eeg = pick_types(raw.info, meg=False, eeg=True) - epochs = Epochs(raw, events=events, event_id=1, tmin=-0.2, tmax=0.5, - picks=picks_eeg, preload=True) - reref, ref_data = _apply_reference( - epochs.copy(), ref_from=['EEG 001', 'EEG 002']) - assert reref.info['custom_ref_applied'] - _test_reference(epochs, reref, ref_data, ['EEG 001', 'EEG 002']) + epochs = Epochs( + raw, + events=events, + event_id=1, + tmin=-0.2, + tmax=0.5, + picks=picks_eeg, + preload=True, + ) + reref, ref_data = _apply_reference(epochs.copy(), ref_from=["EEG 001", "EEG 002"]) + assert 
reref.info["custom_ref_applied"] + _test_reference(epochs, reref, ref_data, ["EEG 001", "EEG 002"]) # Test re-referencing Evoked object evoked = epochs.average() - reref, ref_data = _apply_reference( - evoked.copy(), ref_from=['EEG 001', 'EEG 002']) - assert reref.info['custom_ref_applied'] - _test_reference(evoked, reref, ref_data, ['EEG 001', 'EEG 002']) + reref, ref_data = _apply_reference(evoked.copy(), ref_from=["EEG 001", "EEG 002"]) + assert reref.info["custom_ref_applied"] + _test_reference(evoked, reref, ref_data, ["EEG 001", "EEG 002"]) # Referencing needs data to be preloaded raw_np = read_raw_fif(fif_fname, preload=False) - pytest.raises(RuntimeError, _apply_reference, raw_np, ['EEG 001']) + pytest.raises(RuntimeError, _apply_reference, raw_np, ["EEG 001"]) # Test having inactive SSP projections that deal with channels involved # during re-referencing @@ -123,26 +138,26 @@ def test_apply_reference(): Projection( active=False, data=dict( - col_names=['EEG 001', 'EEG 002'], + col_names=["EEG 001", "EEG 002"], row_names=None, data=np.array([[1, 1]]), ncol=2, - nrow=1 + nrow=1, ), - desc='test', + desc="test", kind=1, ) ) # Projection concerns channels mentioned in projector - with pytest.raises(RuntimeError, match='Inactive signal space'): - _apply_reference(raw, ['EEG 001']) + with pytest.raises(RuntimeError, match="Inactive signal space"): + _apply_reference(raw, ["EEG 001"]) # Projection does not concern channels mentioned in projector, no error - _apply_reference(raw, ['EEG 003'], ['EEG 004']) + _apply_reference(raw, ["EEG 003"], ["EEG 004"]) # CSD cannot be rereferenced with raw.info._unlock(): - raw.info['custom_ref_applied'] = FIFF.FIFFV_MNE_CUSTOM_REF_CSD + raw.info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_CSD with pytest.raises(RuntimeError, match="Cannot set.* type 'CSD'"): raw.set_eeg_reference() @@ -152,66 +167,65 @@ def test_set_eeg_reference(): """Test rereference eeg data.""" raw = read_raw_fif(fif_fname, preload=True) with raw.info._unlock(): - raw.info['projs'] = [] + raw.info["projs"] = [] # Test setting an average reference projection assert not _has_eeg_average_ref_proj(raw.info) reref, ref_data = set_eeg_reference(raw, projection=True) assert _has_eeg_average_ref_proj(reref.info) - assert not reref.info['projs'][0]['active'] + assert not reref.info["projs"][0]["active"] assert ref_data is None reref.apply_proj() - eeg_chans = [raw.ch_names[ch] - for ch in pick_types(raw.info, meg=False, eeg=True)] - _test_reference(raw, reref, ref_data, - [ch for ch in eeg_chans if ch not in raw.info['bads']]) + eeg_chans = [raw.ch_names[ch] for ch in pick_types(raw.info, meg=False, eeg=True)] + _test_reference( + raw, reref, ref_data, [ch for ch in eeg_chans if ch not in raw.info["bads"]] + ) # Test setting an average reference when one was already present - with pytest.warns(RuntimeWarning, match='untouched'): + with pytest.warns(RuntimeWarning, match="untouched"): reref, ref_data = set_eeg_reference(raw, copy=False, projection=True) assert ref_data is None # Test setting an average reference on non-preloaded data raw_nopreload = read_raw_fif(fif_fname, preload=False) with raw_nopreload.info._unlock(): - raw_nopreload.info['projs'] = [] + raw_nopreload.info["projs"] = [] reref, ref_data = set_eeg_reference(raw_nopreload, projection=True) assert _has_eeg_average_ref_proj(reref.info) - assert not reref.info['projs'][0]['active'] + assert not reref.info["projs"][0]["active"] # Rereference raw data by creating a copy of original data - reref, ref_data = 
set_eeg_reference(raw, ['EEG 001', 'EEG 002'], copy=True) - assert reref.info['custom_ref_applied'] - _test_reference(raw, reref, ref_data, ['EEG 001', 'EEG 002']) + reref, ref_data = set_eeg_reference(raw, ["EEG 001", "EEG 002"], copy=True) + assert reref.info["custom_ref_applied"] + _test_reference(raw, reref, ref_data, ["EEG 001", "EEG 002"]) # Test that data is modified in place when copy=False - reref, ref_data = set_eeg_reference(raw, ['EEG 001', 'EEG 002'], - copy=False) + reref, ref_data = set_eeg_reference(raw, ["EEG 001", "EEG 002"], copy=False) assert raw is reref # Test moving from custom to average reference - reref, ref_data = set_eeg_reference(raw, ['EEG 001', 'EEG 002']) + reref, ref_data = set_eeg_reference(raw, ["EEG 001", "EEG 002"]) reref, _ = set_eeg_reference(reref, projection=True) assert _has_eeg_average_ref_proj(reref.info) - assert not reref.info['custom_ref_applied'] + assert not reref.info["custom_ref_applied"] # When creating an average reference fails, make sure the # custom_ref_applied flag remains untouched. reref = raw.copy() with reref.info._unlock(): - reref.info['custom_ref_applied'] = FIFF.FIFFV_MNE_CUSTOM_REF_ON + reref.info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_ON reref.pick_types(meg=True, eeg=False) # Cause making average ref fail # should have turned it off - assert reref.info['custom_ref_applied'] == FIFF.FIFFV_MNE_CUSTOM_REF_OFF - with pytest.raises(ValueError, match='found to rereference'): + assert reref.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_OFF + with pytest.raises(ValueError, match="found to rereference"): set_eeg_reference(reref, projection=True) # Test moving from average to custom reference reref, ref_data = set_eeg_reference(raw, projection=True) - reref, _ = set_eeg_reference(reref, ['EEG 001', 'EEG 002']) + reref, _ = set_eeg_reference(reref, ["EEG 001", "EEG 002"]) assert not _has_eeg_average_ref_proj(reref.info) - assert len(reref.info['projs']) == 0 - assert reref.info['custom_ref_applied'] == FIFF.FIFFV_MNE_CUSTOM_REF_ON + assert len(reref.info["projs"]) == 0 + assert reref.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_ON # Test that disabling the reference does not change the data assert _has_eeg_average_ref_proj(raw.info) @@ -227,100 +241,109 @@ def test_set_eeg_reference(): # Test that average reference gives identical results when calculated # via SSP projection (projection=True) or directly (projection=False) with raw.info._unlock(): - raw.info['projs'] = [] + raw.info["projs"] = [] reref_1, _ = set_eeg_reference(raw.copy(), projection=True) reref_1.apply_proj() reref_2, _ = set_eeg_reference(raw.copy(), projection=False) assert_allclose(reref_1._data, reref_2._data, rtol=1e-6, atol=1e-15) # Test average reference without projection - reref, ref_data = set_eeg_reference(raw.copy(), ref_channels="average", - projection=False) + reref, ref_data = set_eeg_reference( + raw.copy(), ref_channels="average", projection=False + ) _test_reference(raw, reref, ref_data, eeg_chans) with pytest.raises(ValueError, match='supported for ref_channels="averag'): set_eeg_reference(raw, [], True, True) with pytest.raises(ValueError, match='supported for ref_channels="averag'): - set_eeg_reference(raw, ['EEG 001'], True, True) - - -@pytest.mark.parametrize('ch_type, msg', - [('auto', ('ECoG',)), - ('ecog', ('ECoG',)), - ('dbs', ('DBS',)), - (['ecog', 'dbs'], ('ECoG', 'DBS'))]) -@pytest.mark.parametrize('projection', [False, True]) + set_eeg_reference(raw, ["EEG 001"], True, True) + + +@pytest.mark.parametrize( 
+ "ch_type, msg", + [ + ("auto", ("ECoG",)), + ("ecog", ("ECoG",)), + ("dbs", ("DBS",)), + (["ecog", "dbs"], ("ECoG", "DBS")), + ], +) +@pytest.mark.parametrize("projection", [False, True]) def test_set_eeg_reference_ch_type(ch_type, msg, projection): """Test setting EEG reference for ECoG or DBS.""" # gh-6454 # gh-8739 added DBS - ch_names = ['ECOG01', 'ECOG02', 'DBS01', 'DBS02', 'MISC'] + ch_names = ["ECOG01", "ECOG02", "DBS01", "DBS02", "MISC"] rng = np.random.RandomState(0) data = rng.randn(5, 1000) - raw = RawArray(data, create_info(ch_names, 1000., ['ecog'] * 2 - + ['dbs'] * 2 + ['misc'])) + raw = RawArray( + data, create_info(ch_names, 1000.0, ["ecog"] * 2 + ["dbs"] * 2 + ["misc"]) + ) - if ch_type == 'auto': + if ch_type == "auto": ref_ch = ch_names[:2] else: ref_ch = raw.copy().pick(picks=ch_type).ch_names with catch_logging() as log: - reref, ref_data = set_eeg_reference(raw.copy(), ch_type=ch_type, - projection=projection, - verbose=True) + reref, ref_data = set_eeg_reference( + raw.copy(), ch_type=ch_type, projection=projection, verbose=True + ) if not projection: assert f"Applying a custom {msg}" in log.getvalue() - assert reref.info['custom_ref_applied'] # gh-7350 + assert reref.info["custom_ref_applied"] # gh-7350 _test_reference(raw, reref, ref_data, ref_ch) match = "no EEG data found" if projection else "No channels supplied" with pytest.raises(ValueError, match=match): - set_eeg_reference(raw, ch_type='eeg', projection=projection) + set_eeg_reference(raw, ch_type="eeg", projection=projection) # gh-8739 - raw2 = RawArray(data, create_info(5, 1000., ['mag'] * 4 + ['misc'])) - with pytest.raises(ValueError, match='No EEG, ECoG, sEEG or DBS channels ' - 'found to rereference.'): - set_eeg_reference(raw2, ch_type='auto', projection=projection) + raw2 = RawArray(data, create_info(5, 1000.0, ["mag"] * 4 + ["misc"])) + with pytest.raises( + ValueError, match="No EEG, ECoG, sEEG or DBS channels " "found to rereference." + ): + set_eeg_reference(raw2, ch_type="auto", projection=projection) @testing.requires_testing_data def test_set_eeg_reference_rest(): """Test setting a REST reference.""" - raw = read_raw_fif(fif_fname).crop(0, 1).pick_types( - meg=False, eeg=True, exclude=()).load_data() - raw.info['bads'] = ['EEG 057'] # should be excluded - same = [raw.ch_names.index(raw.info['bads'][0])] + raw = ( + read_raw_fif(fif_fname) + .crop(0, 1) + .pick_types(meg=False, eeg=True, exclude=()) + .load_data() + ) + raw.info["bads"] = ["EEG 057"] # should be excluded + same = [raw.ch_names.index(raw.info["bads"][0])] picks = np.setdiff1d(np.arange(len(raw.ch_names)), same) trans = None - sphere = make_sphere_model('auto', 'auto', raw.info) - src = setup_volume_source_space(pos=20., sphere=sphere, exclude=30.) 
- assert src[0]['nuse'] == 223 # low but fast + sphere = make_sphere_model("auto", "auto", raw.info) + src = setup_volume_source_space(pos=20.0, sphere=sphere, exclude=30.0) + assert src[0]["nuse"] == 223 # low but fast fwd = make_forward_solution(raw.info, trans, src, sphere) orig_data = raw.get_data() - avg_data = raw.copy().set_eeg_reference('average').get_data() + avg_data = raw.copy().set_eeg_reference("average").get_data() assert_array_equal(avg_data[same], orig_data[same]) # not processed - raw.set_eeg_reference('REST', forward=fwd) + raw.set_eeg_reference("REST", forward=fwd) rest_data = raw.get_data() assert_array_equal(rest_data[same], orig_data[same]) # should be more similar to an avg ref than nose ref - orig_corr = np.corrcoef(rest_data[picks].ravel(), - orig_data[picks].ravel())[0, 1] - avg_corr = np.corrcoef(rest_data[picks].ravel(), - avg_data[picks].ravel())[0, 1] + orig_corr = np.corrcoef(rest_data[picks].ravel(), orig_data[picks].ravel())[0, 1] + avg_corr = np.corrcoef(rest_data[picks].ravel(), avg_data[picks].ravel())[0, 1] assert -0.6 < orig_corr < -0.5 assert 0.1 < avg_corr < 0.2 # and applying an avg ref after should work - avg_after = raw.set_eeg_reference('average').get_data() + avg_after = raw.set_eeg_reference("average").get_data() assert_allclose(avg_after, avg_data, atol=1e-12) with pytest.raises(TypeError, match='forward when ref_channels="REST"'): - raw.set_eeg_reference('REST') + raw.set_eeg_reference("REST") fwd_bad = pick_channels_forward(fwd, raw.ch_names[:-1]) - with pytest.raises(ValueError, match='Missing channels'): - raw.set_eeg_reference('REST', forward=fwd_bad) + with pytest.raises(ValueError, match="Missing channels"): + raw.set_eeg_reference("REST", forward=fwd_bad) # compare to FieldTrip evoked = read_evokeds(ave_fname, baseline=(None, 0))[0] - evoked.info['bads'] = [] + evoked.info["bads"] = [] evoked.pick_types(meg=False, eeg=True, exclude=()) assert len(evoked.ch_names) == 60 # Data obtained from FieldTrip with something like (after evoked.save'ing @@ -337,7 +360,7 @@ def test_set_eeg_reference_rest(): old = evoked.data[:, idx].ravel() exp_var = 1 - np.linalg.norm(want - old) / norm assert 0.006 < exp_var < 0.008 - evoked.set_eeg_reference('REST', forward=fwd) + evoked.set_eeg_reference("REST", forward=fwd) exp_var_old = 1 - np.linalg.norm(evoked.data[:, idx] - old) / norm assert 0.005 < exp_var_old <= 0.009 exp_var = 1 - np.linalg.norm(evoked.data[:, idx] - want) / norm @@ -345,118 +368,151 @@ def test_set_eeg_reference_rest(): @testing.requires_testing_data -@pytest.mark.parametrize('inst_type', ('raw', 'epochs', 'evoked')) +@pytest.mark.parametrize("inst_type", ("raw", "epochs", "evoked")) def test_set_bipolar_reference(inst_type): """Test bipolar referencing.""" raw = read_raw_fif(fif_fname, preload=True) raw.apply_proj() - if inst_type == 'raw': + if inst_type == "raw": inst = raw del raw - elif inst_type in ['epochs', 'evoked']: - events = find_events(raw, stim_channel='STI 014') + elif inst_type in ["epochs", "evoked"]: + events = find_events(raw, stim_channel="STI 014") epochs = Epochs(raw, events, tmin=-0.3, tmax=0.7, preload=True) inst = epochs - if inst_type == 'evoked': + if inst_type == "evoked": inst = epochs.average() del epochs - ch_info = {'kind': FIFF.FIFFV_EOG_CH, 'extra': 'some extra value'} - with pytest.raises(KeyError, match='key errantly present'): - set_bipolar_reference(inst, 'EEG 001', 'EEG 002', 'bipolar', ch_info) - ch_info.pop('extra') - reref = set_bipolar_reference( - inst, 'EEG 001', 'EEG 002', 'bipolar', 
ch_info) - assert reref.info['custom_ref_applied'] + ch_info = {"kind": FIFF.FIFFV_EOG_CH, "extra": "some extra value"} + with pytest.raises(KeyError, match="key errantly present"): + set_bipolar_reference(inst, "EEG 001", "EEG 002", "bipolar", ch_info) + ch_info.pop("extra") + reref = set_bipolar_reference(inst, "EEG 001", "EEG 002", "bipolar", ch_info) + assert reref.info["custom_ref_applied"] # Compare result to a manual calculation - a = inst.copy().pick_channels(['EEG 001', 'EEG 002']) + a = inst.copy().pick_channels(["EEG 001", "EEG 002"]) a = a._data[..., 0, :] - a._data[..., 1, :] - b = reref.copy().pick_channels(['bipolar'])._data[..., 0, :] + b = reref.copy().pick_channels(["bipolar"])._data[..., 0, :] assert_allclose(a, b) # Original channels should be replaced by a virtual one - assert 'EEG 001' not in reref.ch_names - assert 'EEG 002' not in reref.ch_names - assert 'bipolar' in reref.ch_names + assert "EEG 001" not in reref.ch_names + assert "EEG 002" not in reref.ch_names + assert "bipolar" in reref.ch_names # Check channel information - bp_info = reref.info['chs'][reref.ch_names.index('bipolar')] - an_info = inst.info['chs'][inst.ch_names.index('EEG 001')] + bp_info = reref.info["chs"][reref.ch_names.index("bipolar")] + an_info = inst.info["chs"][inst.ch_names.index("EEG 001")] for key in bp_info: - if key == 'coil_type': + if key == "coil_type": assert bp_info[key] == FIFF.FIFFV_COIL_EEG_BIPOLAR, key - elif key == 'kind': + elif key == "kind": assert bp_info[key] == FIFF.FIFFV_EOG_CH, key - elif key != 'ch_name': + elif key != "ch_name": assert_equal(bp_info[key], an_info[key], err_msg=key) # Minimalist call - reref = set_bipolar_reference(inst, 'EEG 001', 'EEG 002') - assert 'EEG 001-EEG 002' in reref.ch_names + reref = set_bipolar_reference(inst, "EEG 001", "EEG 002") + assert "EEG 001-EEG 002" in reref.ch_names # Minimalist call with twice the same anode - reref = set_bipolar_reference(inst, - ['EEG 001', 'EEG 001', 'EEG 002'], - ['EEG 002', 'EEG 003', 'EEG 003']) - assert 'EEG 001-EEG 002' in reref.ch_names - assert 'EEG 001-EEG 003' in reref.ch_names + reref = set_bipolar_reference( + inst, ["EEG 001", "EEG 001", "EEG 002"], ["EEG 002", "EEG 003", "EEG 003"] + ) + assert "EEG 001-EEG 002" in reref.ch_names + assert "EEG 001-EEG 003" in reref.ch_names # Set multiple references at once reref = set_bipolar_reference( inst, - ['EEG 001', 'EEG 003'], - ['EEG 002', 'EEG 004'], - ['bipolar1', 'bipolar2'], - [{'kind': FIFF.FIFFV_EOG_CH}, - {'kind': FIFF.FIFFV_EOG_CH}], + ["EEG 001", "EEG 003"], + ["EEG 002", "EEG 004"], + ["bipolar1", "bipolar2"], + [{"kind": FIFF.FIFFV_EOG_CH}, {"kind": FIFF.FIFFV_EOG_CH}], ) - a = inst.copy().pick_channels(['EEG 001', 'EEG 002', 'EEG 003', 'EEG 004']) + a = inst.copy().pick_channels(["EEG 001", "EEG 002", "EEG 003", "EEG 004"]) a = np.concatenate( - [a._data[..., :1, :] - a._data[..., 1:2, :], - a._data[..., 2:3, :] - a._data[..., 3:4, :]], - axis=-2 + [ + a._data[..., :1, :] - a._data[..., 1:2, :], + a._data[..., 2:3, :] - a._data[..., 3:4, :], + ], + axis=-2, ) - b = reref.copy().pick_channels(['bipolar1', 'bipolar2'])._data + b = reref.copy().pick_channels(["bipolar1", "bipolar2"])._data assert_allclose(a, b) # Test creating a bipolar reference that doesn't involve EEG channels: # it should not set the custom_ref_applied flag - reref = set_bipolar_reference(inst, 'MEG 0111', 'MEG 0112', - ch_info={'kind': FIFF.FIFFV_MEG_CH}, - verbose='error') - assert not reref.info['custom_ref_applied'] - assert 'MEG 0111-MEG 0112' in reref.ch_names + 
reref = set_bipolar_reference( + inst, + "MEG 0111", + "MEG 0112", + ch_info={"kind": FIFF.FIFFV_MEG_CH}, + verbose="error", + ) + assert not reref.info["custom_ref_applied"] + assert "MEG 0111-MEG 0112" in reref.ch_names # Test a battery of invalid inputs - pytest.raises(ValueError, set_bipolar_reference, inst, - 'EEG 001', ['EEG 002', 'EEG 003'], 'bipolar') - pytest.raises(ValueError, set_bipolar_reference, inst, - ['EEG 001', 'EEG 002'], 'EEG 003', 'bipolar') - pytest.raises(ValueError, set_bipolar_reference, inst, - 'EEG 001', 'EEG 002', ['bipolar1', 'bipolar2']) - pytest.raises(ValueError, set_bipolar_reference, inst, - 'EEG 001', 'EEG 002', 'bipolar', - ch_info=[{'foo': 'bar'}, {'foo': 'bar'}]) - pytest.raises(ValueError, set_bipolar_reference, inst, - 'EEG 001', 'EEG 002', ch_name='EEG 003') + pytest.raises( + ValueError, + set_bipolar_reference, + inst, + "EEG 001", + ["EEG 002", "EEG 003"], + "bipolar", + ) + pytest.raises( + ValueError, + set_bipolar_reference, + inst, + ["EEG 001", "EEG 002"], + "EEG 003", + "bipolar", + ) + pytest.raises( + ValueError, + set_bipolar_reference, + inst, + "EEG 001", + "EEG 002", + ["bipolar1", "bipolar2"], + ) + pytest.raises( + ValueError, + set_bipolar_reference, + inst, + "EEG 001", + "EEG 002", + "bipolar", + ch_info=[{"foo": "bar"}, {"foo": "bar"}], + ) + pytest.raises( + ValueError, set_bipolar_reference, inst, "EEG 001", "EEG 002", ch_name="EEG 003" + ) # Test if bad anode/cathode raises error if on_bad="raise" inst.info["bads"] = ["EEG 001"] - pytest.raises(ValueError, set_bipolar_reference, inst, - 'EEG 001', 'EEG 002', on_bad="raise") + pytest.raises( + ValueError, set_bipolar_reference, inst, "EEG 001", "EEG 002", on_bad="raise" + ) inst.info["bads"] = ["EEG 002"] - pytest.raises(ValueError, set_bipolar_reference, inst, - 'EEG 001', 'EEG 002', on_bad="raise") + pytest.raises( + ValueError, set_bipolar_reference, inst, "EEG 001", "EEG 002", on_bad="raise" + ) # Test if bad anode/cathode raises warning if on_bad="warn" inst.info["bads"] = ["EEG 001"] - pytest.warns(RuntimeWarning, set_bipolar_reference, inst, - 'EEG 001', 'EEG 002', on_bad="warn") + pytest.warns( + RuntimeWarning, set_bipolar_reference, inst, "EEG 001", "EEG 002", on_bad="warn" + ) inst.info["bads"] = ["EEG 002"] - pytest.warns(RuntimeWarning, set_bipolar_reference, inst, - 'EEG 001', 'EEG 002', on_bad="warn") + pytest.warns( + RuntimeWarning, set_bipolar_reference, inst, "EEG 001", "EEG 002", on_bad="warn" + ) def _check_channel_names(inst, ref_names): @@ -465,7 +521,7 @@ def _check_channel_names(inst, ref_names): ref_names = [ref_names] # Test that the names of the reference channels are present in `ch_names` - ref_idx = pick_channels(inst.info['ch_names'], ref_names) + ref_idx = pick_channels(inst.info["ch_names"], ref_names) assert len(ref_idx) == len(ref_names) # Test that the names of the reference channels are present in the `chs` @@ -479,61 +535,60 @@ def test_add_reference(): raw = read_raw_fif(fif_fname, preload=True) picks_eeg = pick_types(raw.info, meg=False, eeg=True) # check if channel already exists - pytest.raises(ValueError, add_reference_channels, - raw, raw.info['ch_names'][0]) + pytest.raises(ValueError, add_reference_channels, raw, raw.info["ch_names"][0]) # add reference channel to Raw - raw_ref = add_reference_channels(raw, 'Ref', copy=True) + raw_ref = add_reference_channels(raw, "Ref", copy=True) assert_equal(raw_ref._data.shape[0], raw._data.shape[0] + 1) assert_array_equal(raw._data[picks_eeg, :], raw_ref._data[picks_eeg, :]) - 
_check_channel_names(raw_ref, 'Ref') + _check_channel_names(raw_ref, "Ref") - orig_nchan = raw.info['nchan'] - raw = add_reference_channels(raw, 'Ref', copy=False) + orig_nchan = raw.info["nchan"] + raw = add_reference_channels(raw, "Ref", copy=False) assert_array_equal(raw._data, raw_ref._data) - assert_equal(raw.info['nchan'], orig_nchan + 1) - _check_channel_names(raw, 'Ref') + assert_equal(raw.info["nchan"], orig_nchan + 1) + _check_channel_names(raw, "Ref") # for Neuromag fif's, the reference electrode location is placed in # elements [3:6] of each "data" electrode location - assert_allclose(raw.info['chs'][-1]['loc'][:3], - raw.info['chs'][picks_eeg[0]]['loc'][3:6], 1e-6) + assert_allclose( + raw.info["chs"][-1]["loc"][:3], raw.info["chs"][picks_eeg[0]]["loc"][3:6], 1e-6 + ) - ref_idx = raw.ch_names.index('Ref') + ref_idx = raw.ch_names.index("Ref") ref_data, _ = raw[ref_idx] assert_array_equal(ref_data, 0) # add reference channel to Raw when no digitization points exist raw = read_raw_fif(fif_fname).crop(0, 1).load_data() picks_eeg = pick_types(raw.info, meg=False, eeg=True) - del raw.info['dig'] + del raw.info["dig"] - raw_ref = add_reference_channels(raw, 'Ref', copy=True) + raw_ref = add_reference_channels(raw, "Ref", copy=True) assert_equal(raw_ref._data.shape[0], raw._data.shape[0] + 1) assert_array_equal(raw._data[picks_eeg, :], raw_ref._data[picks_eeg, :]) - _check_channel_names(raw_ref, 'Ref') + _check_channel_names(raw_ref, "Ref") - orig_nchan = raw.info['nchan'] - raw = add_reference_channels(raw, 'Ref', copy=False) + orig_nchan = raw.info["nchan"] + raw = add_reference_channels(raw, "Ref", copy=False) assert_array_equal(raw._data, raw_ref._data) - assert_equal(raw.info['nchan'], orig_nchan + 1) - _check_channel_names(raw, 'Ref') + assert_equal(raw.info["nchan"], orig_nchan + 1) + _check_channel_names(raw, "Ref") # Test adding an existing channel as reference channel - pytest.raises(ValueError, add_reference_channels, raw, - raw.info['ch_names'][0]) + pytest.raises(ValueError, add_reference_channels, raw, raw.info["ch_names"][0]) # add two reference channels to Raw - raw_ref = add_reference_channels(raw, ['M1', 'M2'], copy=True) - _check_channel_names(raw_ref, ['M1', 'M2']) + raw_ref = add_reference_channels(raw, ["M1", "M2"], copy=True) + _check_channel_names(raw_ref, ["M1", "M2"]) assert_equal(raw_ref._data.shape[0], raw._data.shape[0] + 2) assert_array_equal(raw._data[picks_eeg, :], raw_ref._data[picks_eeg, :]) assert_array_equal(raw_ref._data[-2:, :], 0) - raw = add_reference_channels(raw, ['M1', 'M2'], copy=False) - _check_channel_names(raw, ['M1', 'M2']) - ref_idx = raw.ch_names.index('M1') - ref_idy = raw.ch_names.index('M2') + raw = add_reference_channels(raw, ["M1", "M2"], copy=False) + _check_channel_names(raw, ["M1", "M2"]) + ref_idx = raw.ch_names.index("M1") + ref_idy = raw.ch_names.index("M2") ref_data, _ = raw[[ref_idx, ref_idy]] assert_array_equal(ref_data, 0) @@ -541,116 +596,155 @@ def test_add_reference(): raw = read_raw_fif(fif_fname, preload=True) events = read_events(eve_fname) picks_eeg = pick_types(raw.info, meg=False, eeg=True) - epochs = Epochs(raw, events=events, event_id=1, tmin=-0.2, tmax=0.5, - picks=picks_eeg, preload=True) + epochs = Epochs( + raw, + events=events, + event_id=1, + tmin=-0.2, + tmax=0.5, + picks=picks_eeg, + preload=True, + ) # default: proj=True, after which adding a Ref channel is prohibited - pytest.raises(RuntimeError, add_reference_channels, epochs, 'Ref') + pytest.raises(RuntimeError, add_reference_channels, epochs, 
"Ref") # create epochs in delayed mode, allowing removal of CAR when re-reffing - epochs = Epochs(raw, events=events, event_id=1, tmin=-0.2, tmax=0.5, - picks=picks_eeg, preload=True, proj='delayed') - epochs_ref = add_reference_channels(epochs, 'Ref', copy=True) + epochs = Epochs( + raw, + events=events, + event_id=1, + tmin=-0.2, + tmax=0.5, + picks=picks_eeg, + preload=True, + proj="delayed", + ) + epochs_ref = add_reference_channels(epochs, "Ref", copy=True) assert_equal(epochs_ref._data.shape[1], epochs._data.shape[1] + 1) - _check_channel_names(epochs_ref, 'Ref') - ref_idx = epochs_ref.ch_names.index('Ref') + _check_channel_names(epochs_ref, "Ref") + ref_idx = epochs_ref.ch_names.index("Ref") ref_data = epochs_ref.get_data()[:, ref_idx, :] assert_array_equal(ref_data, 0) picks_eeg = pick_types(epochs.info, meg=False, eeg=True) - assert_array_equal(epochs.get_data()[:, picks_eeg, :], - epochs_ref.get_data()[:, picks_eeg, :]) + assert_array_equal( + epochs.get_data()[:, picks_eeg, :], epochs_ref.get_data()[:, picks_eeg, :] + ) # add two reference channels to epochs raw = read_raw_fif(fif_fname, preload=True) events = read_events(eve_fname) picks_eeg = pick_types(raw.info, meg=False, eeg=True) # create epochs in delayed mode, allowing removal of CAR when re-reffing - epochs = Epochs(raw, events=events, event_id=1, tmin=-0.2, tmax=0.5, - picks=picks_eeg, preload=True, proj='delayed') - with pytest.warns(RuntimeWarning, match='reference channels are ignored'): - epochs_ref = add_reference_channels(epochs, ['M1', 'M2'], copy=True) + epochs = Epochs( + raw, + events=events, + event_id=1, + tmin=-0.2, + tmax=0.5, + picks=picks_eeg, + preload=True, + proj="delayed", + ) + with pytest.warns(RuntimeWarning, match="reference channels are ignored"): + epochs_ref = add_reference_channels(epochs, ["M1", "M2"], copy=True) assert_equal(epochs_ref._data.shape[1], epochs._data.shape[1] + 2) - _check_channel_names(epochs_ref, ['M1', 'M2']) - ref_idx = epochs_ref.ch_names.index('M1') - ref_idy = epochs_ref.ch_names.index('M2') - assert_equal(epochs_ref.info['chs'][ref_idx]['ch_name'], 'M1') - assert_equal(epochs_ref.info['chs'][ref_idy]['ch_name'], 'M2') + _check_channel_names(epochs_ref, ["M1", "M2"]) + ref_idx = epochs_ref.ch_names.index("M1") + ref_idy = epochs_ref.ch_names.index("M2") + assert_equal(epochs_ref.info["chs"][ref_idx]["ch_name"], "M1") + assert_equal(epochs_ref.info["chs"][ref_idy]["ch_name"], "M2") ref_data = epochs_ref.get_data()[:, [ref_idx, ref_idy], :] assert_array_equal(ref_data, 0) picks_eeg = pick_types(epochs.info, meg=False, eeg=True) - assert_array_equal(epochs.get_data()[:, picks_eeg, :], - epochs_ref.get_data()[:, picks_eeg, :]) + assert_array_equal( + epochs.get_data()[:, picks_eeg, :], epochs_ref.get_data()[:, picks_eeg, :] + ) # add reference channel to evoked raw = read_raw_fif(fif_fname, preload=True) events = read_events(eve_fname) picks_eeg = pick_types(raw.info, meg=False, eeg=True) # create epochs in delayed mode, allowing removal of CAR when re-reffing - epochs = Epochs(raw, events=events, event_id=1, tmin=-0.2, tmax=0.5, - picks=picks_eeg, preload=True, proj='delayed') + epochs = Epochs( + raw, + events=events, + event_id=1, + tmin=-0.2, + tmax=0.5, + picks=picks_eeg, + preload=True, + proj="delayed", + ) evoked = epochs.average() - evoked_ref = add_reference_channels(evoked, 'Ref', copy=True) + evoked_ref = add_reference_channels(evoked, "Ref", copy=True) assert_equal(evoked_ref.data.shape[0], evoked.data.shape[0] + 1) - _check_channel_names(evoked_ref, 'Ref') - 
ref_idx = evoked_ref.ch_names.index('Ref') + _check_channel_names(evoked_ref, "Ref") + ref_idx = evoked_ref.ch_names.index("Ref") ref_data = evoked_ref.data[ref_idx, :] assert_array_equal(ref_data, 0) picks_eeg = pick_types(evoked.info, meg=False, eeg=True) - assert_array_equal(evoked.data[picks_eeg, :], - evoked_ref.data[picks_eeg, :]) + assert_array_equal(evoked.data[picks_eeg, :], evoked_ref.data[picks_eeg, :]) # add two reference channels to evoked raw = read_raw_fif(fif_fname, preload=True) events = read_events(eve_fname) picks_eeg = pick_types(raw.info, meg=False, eeg=True) # create epochs in delayed mode, allowing removal of CAR when re-reffing - epochs = Epochs(raw, events=events, event_id=1, tmin=-0.2, tmax=0.5, - picks=picks_eeg, preload=True, proj='delayed') + epochs = Epochs( + raw, + events=events, + event_id=1, + tmin=-0.2, + tmax=0.5, + picks=picks_eeg, + preload=True, + proj="delayed", + ) evoked = epochs.average() - with pytest.warns(RuntimeWarning, match='reference channels are ignored'): - evoked_ref = add_reference_channels(evoked, ['M1', 'M2'], copy=True) + with pytest.warns(RuntimeWarning, match="reference channels are ignored"): + evoked_ref = add_reference_channels(evoked, ["M1", "M2"], copy=True) assert_equal(evoked_ref.data.shape[0], evoked.data.shape[0] + 2) - _check_channel_names(evoked_ref, ['M1', 'M2']) - ref_idx = evoked_ref.ch_names.index('M1') - ref_idy = evoked_ref.ch_names.index('M2') + _check_channel_names(evoked_ref, ["M1", "M2"]) + ref_idx = evoked_ref.ch_names.index("M1") + ref_idy = evoked_ref.ch_names.index("M2") ref_data = evoked_ref.data[[ref_idx, ref_idy], :] assert_array_equal(ref_data, 0) picks_eeg = pick_types(evoked.info, meg=False, eeg=True) - assert_array_equal(evoked.data[picks_eeg, :], - evoked_ref.data[picks_eeg, :]) + assert_array_equal(evoked.data[picks_eeg, :], evoked_ref.data[picks_eeg, :]) # Test invalid inputs raw = read_raw_fif(fif_fname, preload=False) - with pytest.raises(RuntimeError, match='loaded'): - add_reference_channels(raw, ['Ref']) + with pytest.raises(RuntimeError, match="loaded"): + add_reference_channels(raw, ["Ref"]) raw.load_data() - with pytest.raises(ValueError, match='Channel.*already.*'): + with pytest.raises(ValueError, match="Channel.*already.*"): add_reference_channels(raw, raw.ch_names[:1]) - with pytest.raises(TypeError, match='instance of'): + with pytest.raises(TypeError, match="instance of"): add_reference_channels(raw, 1) # gh-10878 raw = read_raw_fif(raw_fname).crop(0, 1, include_tmax=False).load_data() - data = raw.copy().add_reference_channels(['REF']).pick_types(eeg=True) + data = raw.copy().add_reference_channels(["REF"]).pick_types(eeg=True) data = data.get_data() epochs = make_fixed_length_epochs(raw).load_data() - data_2 = epochs.copy().add_reference_channels(['REF']).pick_types(eeg=True) + data_2 = epochs.copy().add_reference_channels(["REF"]).pick_types(eeg=True) data_2 = data_2.get_data()[0] assert_allclose(data, data_2) evoked = epochs.average() - data_3 = evoked.copy().add_reference_channels(['REF']).pick_types(eeg=True) + data_3 = evoked.copy().add_reference_channels(["REF"]).pick_types(eeg=True) data_3 = data_3.get_data() assert_allclose(data, data_3) -@pytest.mark.parametrize('n_ref', (1, 2)) +@pytest.mark.parametrize("n_ref", (1, 2)) def test_add_reorder(n_ref): """Test that a reference channel can be added and then data reordered.""" # gh-8300 - raw = read_raw_fif(raw_fname).crop(0, 0.1).del_proj().pick('eeg') + raw = read_raw_fif(raw_fname).crop(0, 0.1).del_proj().pick("eeg") assert 
len(raw.ch_names) == 60 - chs = ['EEG %03d' % (60 + ii) for ii in range(1, n_ref)] + ['EEG 000'] - with pytest.raises(RuntimeError, match='preload'): + chs = ["EEG %03d" % (60 + ii) for ii in range(1, n_ref)] + ["EEG 000"] + with pytest.raises(RuntimeError, match="preload"): with _record_warnings(): # ignore multiple warning add_reference_channels(raw, chs, copy=False) raw.load_data() @@ -658,14 +752,14 @@ def test_add_reorder(n_ref): ctx = nullcontext() else: assert n_ref == 2 - ctx = pytest.warns(RuntimeWarning, match='locations of multiple') + ctx = pytest.warns(RuntimeWarning, match="locations of multiple") with ctx: add_reference_channels(raw, chs, copy=False) data = raw.get_data() - assert_array_equal(data[-1], 0.) + assert_array_equal(data[-1], 0.0) assert raw.ch_names[-n_ref:] == chs raw.reorder_channels(raw.ch_names[-1:] + raw.ch_names[:-1]) - assert raw.ch_names == ['EEG %03d' % ii for ii in range(60 + n_ref)] + assert raw.ch_names == ["EEG %03d" % ii for ii in range(60 + n_ref)] data_new = raw.get_data() data_new = np.concatenate([data_new[1:], data_new[:1]]) assert_allclose(data, data_new) @@ -673,22 +767,23 @@ def test_add_reorder(n_ref): def test_bipolar_combinations(): """Test bipolar channel generation.""" - ch_names = ['CH' + str(ni + 1) for ni in range(10)] + ch_names = ["CH" + str(ni + 1) for ni in range(10)] info = create_info( - ch_names=ch_names, sfreq=1000., ch_types=['eeg'] * len(ch_names)) + ch_names=ch_names, sfreq=1000.0, ch_types=["eeg"] * len(ch_names) + ) raw_data = np.random.randn(len(ch_names), 1000) raw = RawArray(raw_data, info) def _check_bipolar(raw_test, ch_a, ch_b): - picks = [raw_test.ch_names.index(ch_a + '-' + ch_b)] + picks = [raw_test.ch_names.index(ch_a + "-" + ch_b)] get_data_res = raw_test.get_data(picks=picks)[0, :] manual_a = raw_data[ch_names.index(ch_a), :] manual_b = raw_data[ch_names.index(ch_b), :] assert_array_equal(get_data_res, manual_a - manual_b) # test classic EOG/ECG bipolar reference (only two channels per pair). - raw_test = set_bipolar_reference(raw, ['CH2'], ['CH1'], copy=True) - _check_bipolar(raw_test, 'CH2', 'CH1') + raw_test = set_bipolar_reference(raw, ["CH2"], ["CH1"], copy=True) + _check_bipolar(raw_test, "CH2", "CH1") # test all combinations. a_channels, b_channels = zip(*itertools.combinations(ch_names, 2)) @@ -700,7 +795,8 @@ def _check_bipolar(raw_test, ch_a, ch_b): assert len(raw_test.ch_names) == len(a_channels) raw_test = set_bipolar_reference( - raw, a_channels, b_channels, drop_refs=False, copy=True) + raw, a_channels, b_channels, drop_refs=False, copy=True + ) # check if reference channels have been kept correctly. assert len(raw_test.ch_names) == len(a_channels) + len(ch_names) for idx, ch_label in enumerate(ch_names): @@ -708,19 +804,20 @@ def _check_bipolar(raw_test, ch_a, ch_b): assert_array_equal(raw_test.get_data(ch_label), manual_ch) # test bipolars with a channel in both list (anode & cathode). 
- raw_test = set_bipolar_reference( - raw, ['CH2', 'CH1'], ['CH1', 'CH2'], copy=True) - _check_bipolar(raw_test, 'CH2', 'CH1') - _check_bipolar(raw_test, 'CH1', 'CH2') + raw_test = set_bipolar_reference(raw, ["CH2", "CH1"], ["CH1", "CH2"], copy=True) + _check_bipolar(raw_test, "CH2", "CH1") + _check_bipolar(raw_test, "CH1", "CH2") # test if bipolar channel is bad if anode is a bad channel raw.info["bads"] = ["CH1"] - raw_test = set_bipolar_reference(raw, ['CH1'], ['CH2'], on_bad="ignore", - ch_name="bad_bipolar", copy=True) + raw_test = set_bipolar_reference( + raw, ["CH1"], ["CH2"], on_bad="ignore", ch_name="bad_bipolar", copy=True + ) assert raw_test.info["bads"] == ["bad_bipolar"] # test if bipolar channel is bad if cathode is a bad channel raw.info["bads"] = ["CH2"] - raw_test = set_bipolar_reference(raw, ['CH1'], ['CH2'], on_bad="ignore", - ch_name="bad_bipolar", copy=True) + raw_test = set_bipolar_reference( + raw, ["CH1"], ["CH2"], on_bad="ignore", ch_name="bad_bipolar", copy=True + ) assert raw_test.info["bads"] == ["bad_bipolar"] diff --git a/mne/io/tests/test_show_fiff.py b/mne/io/tests/test_show_fiff.py index 52beb9cdbed..7ca46685bf8 100644 --- a/mne/io/tests/test_show_fiff.py +++ b/mne/io/tests/test_show_fiff.py @@ -16,13 +16,20 @@ def test_show_fiff(): """Test show_fiff.""" # this is not exhaustive, but hopefully bugs will be found in use info = show_fiff(fname_evoked) - assert 'BAD' not in info - keys = ['FIFF_EPOCH', 'FIFFB_HPI_COIL', 'FIFFB_PROJ_ITEM', - 'FIFFB_PROCESSED_DATA', 'FIFFB_EVOKED', 'FIFF_NAVE', - 'FIFF_EPOCH', 'COORD_TRANS'] + assert "BAD" not in info + keys = [ + "FIFF_EPOCH", + "FIFFB_HPI_COIL", + "FIFFB_PROJ_ITEM", + "FIFFB_PROCESSED_DATA", + "FIFFB_EVOKED", + "FIFF_NAVE", + "FIFF_EPOCH", + "COORD_TRANS", + ] assert all(key in info for key in keys) info = show_fiff(fname_raw, read_limit=1024) - assert 'BAD' not in info + assert "BAD" not in info info = show_fiff(fname_c_annot) - assert 'BAD' not in info - assert '>B' in info, info + assert "BAD" not in info + assert ">B" in info, info diff --git a/mne/io/tests/test_utils.py b/mne/io/tests/test_utils.py index 601a9df4e9c..d495ffc86ef 100644 --- a/mne/io/tests/test_utils.py +++ b/mne/io/tests/test_utils.py @@ -8,11 +8,10 @@ def test_check_orig_units(): """Test the checking of original units.""" - orig_units = dict(FC1='nV', Hfp3erz='n/a', Pz='uV', greekMu='μV', - microSign='µV') + orig_units = dict(FC1="nV", Hfp3erz="n/a", Pz="uV", greekMu="μV", microSign="µV") orig_units = _check_orig_units(orig_units) - assert orig_units['FC1'] == 'nV' - assert orig_units['Hfp3erz'] == 'n/a' - assert orig_units['Pz'] == 'µV' - assert orig_units['greekMu'] == 'µV' - assert orig_units['microSign'] == 'µV' + assert orig_units["FC1"] == "nV" + assert orig_units["Hfp3erz"] == "n/a" + assert orig_units["Pz"] == "µV" + assert orig_units["greekMu"] == "µV" + assert orig_units["microSign"] == "µV" diff --git a/mne/io/tests/test_what.py b/mne/io/tests/test_what.py index 96f4cf0d42d..12eca74978b 100644 --- a/mne/io/tests/test_what.py +++ b/mne/io/tests/test_what.py @@ -23,30 +23,39 @@ def test_what(tmp_path, verbose_debug): """Test mne.what.""" # ICA ica = ICA(max_iter=1) - raw = RawArray(np.random.RandomState(0).randn(3, 10), - create_info(3, 1000., 'eeg')) + raw = RawArray(np.random.RandomState(0).randn(3, 10), create_info(3, 1000.0, "eeg")) with _record_warnings(): # convergence sometimes ica.fit(raw) fname = tmp_path / "x-ica.fif" ica.save(fname) - assert what(fname) == 'ica' + assert what(fname) == "ica" # test files fnames = 
glob.glob(str(data_path / "MEG" / "sample" / "*.fif")) - fnames += glob.glob( - str(data_path / "subjects" / "sample" / "bem" / "*.fif") - ) + fnames += glob.glob(str(data_path / "subjects" / "sample" / "bem" / "*.fif")) fnames = sorted(fnames) - want_dict = dict(eve='events', ave='evoked', cov='cov', inv='inverse', - fwd='forward', trans='transform', proj='proj', - raw='raw', meg='raw', sol='bem solution', - bem='bem surfaces', src='src', dense='bem surfaces', - sparse='bem surfaces', head='bem surfaces', - fiducials='fiducials') + want_dict = dict( + eve="events", + ave="evoked", + cov="cov", + inv="inverse", + fwd="forward", + trans="transform", + proj="proj", + raw="raw", + meg="raw", + sol="bem solution", + bem="bem surfaces", + src="src", + dense="bem surfaces", + sparse="bem surfaces", + head="bem surfaces", + fiducials="fiducials", + ) for fname in fnames: kind = Path(fname).stem.split("-")[-1] if len(kind) > 5: - kind = kind.split('_')[-1] + kind = kind.split("_")[-1] this = what(fname) assert this == want_dict[kind] fname = data_path / "MEG" / "sample" / "sample_audvis-ave_xfit.dip" - assert what(fname) == 'unknown' + assert what(fname) == "unknown" diff --git a/mne/io/tests/test_write.py b/mne/io/tests/test_write.py index a86e47d175c..2dbfe5aa743 100644 --- a/mne/io/tests/test_write.py +++ b/mne/io/tests/test_write.py @@ -11,10 +11,10 @@ def test_write_int(tmp_path): """Test that write_int raises an error on bad values.""" - with start_file(tmp_path / 'temp.fif') as fid: + with start_file(tmp_path / "temp.fif") as fid: write_int(fid, FIFF.FIFF_MNE_EVENT_LIST, [2147483647]) # 2 ** 31 - 1 write_int(fid, FIFF.FIFF_MNE_EVENT_LIST, []) # 2 ** 31 - 1 - with pytest.raises(TypeError, match=r'.*exceeds max.*EVENT_LIST\)'): + with pytest.raises(TypeError, match=r".*exceeds max.*EVENT_LIST\)"): write_int(fid, FIFF.FIFF_MNE_EVENT_LIST, [2147483648]) # 2 ** 31 - with pytest.raises(TypeError, match='Cannot safely write'): - write_int(fid, FIFF.FIFF_MNE_EVENT_LIST, [0.]) # float + with pytest.raises(TypeError, match="Cannot safely write"): + write_int(fid, FIFF.FIFF_MNE_EVENT_LIST, [0.0]) # float diff --git a/mne/io/tree.py b/mne/io/tree.py index b4ed4ee1c7b..b0f1f415dfa 100644 --- a/mne/io/tree.py +++ b/mne/io/tree.py @@ -34,11 +34,11 @@ def dir_tree_find(tree, kind): nodes += dir_tree_find(t, kind) else: # Am I desirable myself? 
- if tree['block'] == kind: + if tree["block"] == kind: nodes.append(tree) # Search the subtrees - for child in tree['children']: + for child in tree["children"]: nodes += dir_tree_find(child, kind) return nodes @@ -58,57 +58,60 @@ def make_dir_tree(fid, directory, start=0, indent=0, verbose=None): else: block = 0 - logger.debug(' ' * indent + 'start { %d' % block) + logger.debug(" " * indent + "start { %d" % block) this = start tree = dict() - tree['block'] = block - tree['id'] = None - tree['parent_id'] = None - tree['nent'] = 0 - tree['nchild'] = 0 - tree['directory'] = directory[this] - tree['children'] = [] + tree["block"] = block + tree["id"] = None + tree["parent_id"] = None + tree["nent"] = 0 + tree["nchild"] = 0 + tree["directory"] = directory[this] + tree["children"] = [] while this < len(directory): if directory[this].kind == FIFF_BLOCK_START: if this != start: child, this = make_dir_tree(fid, directory, this, indent + 1) - tree['nchild'] += 1 - tree['children'].append(child) + tree["nchild"] += 1 + tree["children"].append(child) elif directory[this].kind == FIFF_BLOCK_END: tag = read_tag(fid, directory[start].pos) if tag.data == block: break else: - tree['nent'] += 1 - if tree['nent'] == 1: - tree['directory'] = list() - tree['directory'].append(directory[this]) + tree["nent"] += 1 + if tree["nent"] == 1: + tree["directory"] = list() + tree["directory"].append(directory[this]) # Add the id information if available if block == 0: if directory[this].kind == FIFF_FILE_ID: tag = read_tag(fid, directory[this].pos) - tree['id'] = tag.data + tree["id"] = tag.data else: if directory[this].kind == FIFF_BLOCK_ID: tag = read_tag(fid, directory[this].pos) - tree['id'] = tag.data + tree["id"] = tag.data elif directory[this].kind == FIFF_PARENT_BLOCK_ID: tag = read_tag(fid, directory[this].pos) - tree['parent_id'] = tag.data + tree["parent_id"] = tag.data this += 1 # Eliminate the empty directory - if tree['nent'] == 0: - tree['directory'] = None - - logger.debug(' ' * (indent + 1) + 'block = %d nent = %d nchild = %d' - % (tree['block'], tree['nent'], tree['nchild'])) - logger.debug(' ' * indent + 'end } %d' % block) + if tree["nent"] == 0: + tree["directory"] = None + + logger.debug( + " " * (indent + 1) + + "block = %d nent = %d nchild = %d" + % (tree["block"], tree["nent"], tree["nchild"]) + ) + logger.debug(" " * indent + "end } %d" % block) last = this return tree, last @@ -116,6 +119,7 @@ def make_dir_tree(fid, directory, start=0, indent=0, verbose=None): ############################################################################### # Writing + def copy_tree(fidin, in_id, nodes, fidout): """Copy directory subtrees from fidin to fidout.""" if len(nodes) <= 0: @@ -125,29 +129,31 @@ def copy_tree(fidin, in_id, nodes, fidout): nodes = [nodes] for node in nodes: - start_block(fidout, node['block']) - if node['id'] is not None: + start_block(fidout, node["block"]) + if node["id"] is not None: if in_id is not None: write_id(fidout, FIFF.FIFF_PARENT_FILE_ID, in_id) write_id(fidout, FIFF.FIFF_BLOCK_ID, in_id) - write_id(fidout, FIFF.FIFF_PARENT_BLOCK_ID, node['id']) + write_id(fidout, FIFF.FIFF_PARENT_BLOCK_ID, node["id"]) - if node['directory'] is not None: - for d in node['directory']: + if node["directory"] is not None: + for d in node["directory"]: # Do not copy these tags - if d.kind == FIFF.FIFF_BLOCK_ID or \ - d.kind == FIFF.FIFF_PARENT_BLOCK_ID or \ - d.kind == FIFF.FIFF_PARENT_FILE_ID: + if ( + d.kind == FIFF.FIFF_BLOCK_ID + or d.kind == FIFF.FIFF_PARENT_BLOCK_ID + or d.kind == 
FIFF.FIFF_PARENT_FILE_ID + ): continue # Read and write tags, pass data through transparently fidin.seek(d.pos, 0) - tag = Tag(*np.fromfile(fidin, ('>i4,>I4,>i4,>i4'), 1)[0]) - tag.data = np.fromfile(fidin, '>B', tag.size) - _write(fidout, tag.data, tag.kind, 1, tag.type, '>B') + tag = Tag(*np.fromfile(fidin, (">i4,>I4,>i4,>i4"), 1)[0]) + tag.data = np.fromfile(fidin, ">B", tag.size) + _write(fidout, tag.data, tag.kind, 1, tag.type, ">B") - for child in node['children']: + for child in node["children"]: copy_tree(fidin, in_id, child, fidout) - end_block(fidout, node['block']) + end_block(fidout, node["block"]) diff --git a/mne/io/utils.py b/mne/io/utils.py index f9d01d9bae4..f72ea4e20a2 100644 --- a/mne/io/utils.py +++ b/mne/io/utils.py @@ -43,41 +43,46 @@ def _check_orig_units(orig_units): valid_units_lowered = [unit.lower() for unit in valid_units] orig_units_remapped = dict(orig_units) for ch_name, unit in orig_units.items(): - # Be lenient: we ignore case for now. if unit.lower() in valid_units_lowered: continue # Common "invalid units" can be remapped to their valid equivalent remap_dict = dict() - remap_dict['uv'] = 'µV' - remap_dict['μv'] = 'µV' # greek letter mu vs micro sign. use micro - remap_dict['\x83\xeav'] = 'µV' # for shift-jis mu, use micro + remap_dict["uv"] = "µV" + remap_dict["μv"] = "µV" # greek letter mu vs micro sign. use micro + remap_dict["\x83\xeav"] = "µV" # for shift-jis mu, use micro if unit.lower() in remap_dict: orig_units_remapped[ch_name] = remap_dict[unit.lower()] continue # Some units cannot be saved, they are invalid: assign "n/a" - orig_units_remapped[ch_name] = 'n/a' + orig_units_remapped[ch_name] = "n/a" return orig_units_remapped -def _find_channels(ch_names, ch_type='EOG'): +def _find_channels(ch_names, ch_type="EOG"): """Find EOG channel.""" substrings = (ch_type,) substrings = [s.upper() for s in substrings] - if ch_type == 'EOG': - substrings = ('EOG', 'EYE') - eog_idx = [idx for idx, ch in enumerate(ch_names) if - any(substring in ch.upper() for substring in substrings)] + if ch_type == "EOG": + substrings = ("EOG", "EYE") + eog_idx = [ + idx + for idx, ch in enumerate(ch_names) + if any(substring in ch.upper() for substring in substrings) + ] return eog_idx def _mult_cal_one(data_view, one, idx, cals, mult): """Take a chunk of raw data, multiply by mult or cals, and store.""" one = np.asarray(one, dtype=data_view.dtype) - assert data_view.shape[1] == one.shape[1], (data_view.shape[1], one.shape[1]) # noqa: E501 + assert data_view.shape[1] == one.shape[1], ( + data_view.shape[1], + one.shape[1], + ) # noqa: E501 if mult is not None: mult.ndim == one.ndim == 2 data_view[:] = mult @ one[idx] @@ -160,7 +165,7 @@ def _blk_read_lims(start, stop, buf_len): """ # noqa: E501 # this is used to deal with indexing in the middle of a sampling period assert all(isinstance(x, int) for x in (start, stop, buf_len)) - block_start_idx = (start // buf_len) + block_start_idx = start // buf_len block_start = block_start_idx * buf_len last_used_samp = stop - 1 block_stop = last_used_samp - last_used_samp % buf_len + buf_len @@ -192,16 +197,28 @@ def _blk_read_lims(start, stop, buf_len): def _file_size(fname): """Get the file size in bytes.""" - with open(fname, 'rb') as f: + with open(fname, "rb") as f: f.seek(0, os.SEEK_END) return f.tell() -def _read_segments_file(raw, data, idx, fi, start, stop, cals, mult, - dtype, n_channels=None, offset=0, trigger_ch=None): +def _read_segments_file( + raw, + data, + idx, + fi, + start, + stop, + cals, + mult, + dtype, + 
n_channels=None, + offset=0, + trigger_ch=None, +): """Read a chunk of raw data.""" if n_channels is None: - n_channels = raw._raw_extras[fi]['orig_nchan'] + n_channels = raw._raw_extras[fi]["orig_nchan"] n_bytes = np.dtype(dtype).itemsize # data_offset and data_left count data samples (channels x time points), @@ -212,17 +229,19 @@ def _read_segments_file(raw, data, idx, fi, start, stop, cals, mult, # Read up to 100 MB of data at a time, block_size is in data samples block_size = ((int(100e6) // n_bytes) // n_channels) * n_channels block_size = min(data_left, block_size) - with open(raw._filenames[fi], 'rb', buffering=0) as fid: + with open(raw._filenames[fi], "rb", buffering=0) as fid: fid.seek(data_offset) # extract data in chunks for sample_start in np.arange(0, data_left, block_size) // n_channels: count = min(block_size, data_left - sample_start * n_channels) block = np.fromfile(fid, dtype, count) if block.size != count: - raise RuntimeError('Incorrect number of samples (%s != %s), ' - 'please report this error to MNE-Python ' - 'developers' % (block.size, count)) - block = block.reshape(n_channels, -1, order='F') + raise RuntimeError( + "Incorrect number of samples (%s != %s), " + "please report this error to MNE-Python " + "developers" % (block.size, count) + ) + block = block.reshape(n_channels, -1, order="F") n_samples = block.shape[1] # = count // n_channels sample_stop = sample_start + n_samples if trigger_ch is not None: @@ -234,13 +253,12 @@ def _read_segments_file(raw, data, idx, fi, start, stop, cals, mult, def read_str(fid, count=1): """Read string from a binary file in a python version compatible way.""" - dtype = np.dtype('>S%i' % count) + dtype = np.dtype(">S%i" % count) string = fid.read(dtype.itemsize) data = np.frombuffer(string, dtype=dtype)[0] - bytestr = b''.join([data[0:data.index(b'\x00') if - b'\x00' in data else count]]) + bytestr = b"".join([data[0 : data.index(b"\x00") if b"\x00" in data else count]]) - return str(bytestr.decode('ascii')) # Return native str type for Py2/3 + return str(bytestr.decode("ascii")) # Return native str type for Py2/3 def _create_chs(ch_names, cals, ch_coil, ch_kind, eog, ecg, emg, misc): @@ -263,13 +281,21 @@ def _create_chs(ch_names, cals, ch_coil, ch_kind, eog, ecg, emg, misc): coil_type = ch_coil kind = ch_kind - chan_info = {'cal': cals[idx], 'logno': idx + 1, 'scanno': idx + 1, - 'range': 1.0, 'unit_mul': FIFF.FIFF_UNITM_NONE, - 'ch_name': ch_name, 'unit': FIFF.FIFF_UNIT_V, - 'coord_frame': FIFF.FIFFV_COORD_HEAD, - 'coil_type': coil_type, 'kind': kind, 'loc': np.zeros(12)} + chan_info = { + "cal": cals[idx], + "logno": idx + 1, + "scanno": idx + 1, + "range": 1.0, + "unit_mul": FIFF.FIFF_UNITM_NONE, + "ch_name": ch_name, + "unit": FIFF.FIFF_UNIT_V, + "coord_frame": FIFF.FIFFV_COORD_HEAD, + "coil_type": coil_type, + "kind": kind, + "loc": np.zeros(12), + } if coil_type == FIFF.FIFFV_COIL_EEG: - chan_info['loc'][:3] = np.nan + chan_info["loc"][:3] = np.nan chs.append(chan_info) return chs @@ -295,7 +321,7 @@ def _synthesize_stim_channel(events, n_samples): # create output buffer stim_channel = np.zeros(n_samples, int) for onset, duration, trigger in events: - stim_channel[onset:onset + duration] = trigger + stim_channel[onset : onset + duration] = trigger return stim_channel @@ -304,13 +330,15 @@ def _construct_bids_filename(base, ext, part_idx, validate=True): # insert index in filename dirname = op.dirname(base) base = op.basename(base) - deconstructed_base = base.split('_') + deconstructed_base = base.split("_") if 
len(deconstructed_base) < 2 and validate: - raise ValueError('Filename base must end with an underscore followed ' - f'by the modality (e.g., _eeg or _meg), got {base}') + raise ValueError( + "Filename base must end with an underscore followed " + f"by the modality (e.g., _eeg or _meg), got {base}" + ) suffix = deconstructed_base[-1] - base = '_'.join(deconstructed_base[:-1]) - use_fname = '{}_split-{:02}_{}{}'.format(base, part_idx, suffix, ext) + base = "_".join(deconstructed_base[:-1]) + use_fname = "{}_split-{:02}_{}{}".format(base, part_idx, suffix, ext) if dirname: use_fname = op.join(dirname, use_fname) return use_fname diff --git a/mne/io/what.py b/mne/io/what.py index fda10db46b0..a8d27740c4e 100644 --- a/mne/io/what.py +++ b/mne/io/what.py @@ -38,32 +38,33 @@ def what(fname): from ..event import read_events from ..proj import read_proj from .meas_info import read_fiducials - _check_fname(fname, overwrite='read', must_exist=True) + + _check_fname(fname, overwrite="read", must_exist=True) checks = OrderedDict() - checks['raw'] = read_raw_fif - checks['ica'] = read_ica - checks['epochs'] = read_epochs - checks['evoked'] = read_evokeds - checks['forward'] = read_forward_solution - checks['inverse'] = read_inverse_operator - checks['src'] = read_source_spaces - checks['bem solution'] = read_bem_solution - checks['bem surfaces'] = read_bem_surfaces - checks['cov'] = read_cov - checks['transform'] = read_trans - checks['events'] = read_events - checks['fiducials'] = read_fiducials - checks['proj'] = read_proj + checks["raw"] = read_raw_fif + checks["ica"] = read_ica + checks["epochs"] = read_epochs + checks["evoked"] = read_evokeds + checks["forward"] = read_forward_solution + checks["inverse"] = read_inverse_operator + checks["src"] = read_source_spaces + checks["bem solution"] = read_bem_solution + checks["bem surfaces"] = read_bem_surfaces + checks["cov"] = read_cov + checks["transform"] = read_trans + checks["events"] = read_events + checks["fiducials"] = read_fiducials + checks["proj"] = read_proj for what, func in checks.items(): args = signature(func).parameters - assert 'verbose' in args, func - kwargs = dict(verbose='error') - if 'preload' in args: - kwargs['preload'] = False + assert "verbose" in args, func + kwargs = dict(verbose="error") + if "preload" in args: + kwargs["preload"] = False try: func(fname, **kwargs) except Exception as exp: - logger.debug('Not %s: %s' % (what, exp)) + logger.debug("Not %s: %s" % (what, exp)) else: return what - return 'unknown' + return "unknown" diff --git a/mne/io/write.py b/mne/io/write.py index fb832cec9ca..fcaf7261bad 100644 --- a/mne/io/write.py +++ b/mne/io/write.py @@ -20,7 +20,7 @@ # to treat as meas_date=None. This one should be impossible for systems # to write -- the second field is microseconds, so anything >= 1e6 # should be moved into the first field (seconds). -DATE_NONE = (0, 2 ** 31 - 1) +DATE_NONE = (0, 2**31 - 1) def _write(fid, data, kind, data_size, FIFFT_TYPE, dtype): @@ -31,10 +31,10 @@ def _write(fid, data, kind, data_size, FIFFT_TYPE, dtype): # XXX for string types the data size is used as # computed in ``write_string``. 
- fid.write(np.array(kind, dtype='>i4').tobytes()) - fid.write(np.array(FIFFT_TYPE, dtype='>i4').tobytes()) - fid.write(np.array(data_size, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype='>i4').tobytes()) + fid.write(np.array(kind, dtype=">i4").tobytes()) + fid.write(np.array(FIFFT_TYPE, dtype=">i4").tobytes()) + fid.write(np.array(data_size, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype=">i4").tobytes()) fid.write(np.array(data, dtype=dtype).tobytes()) @@ -43,12 +43,11 @@ def _get_split_size(split_size): if isinstance(split_size, str): exp = dict(MB=20, GB=30).get(split_size[-2:], None) if exp is None: - raise ValueError('split_size has to end with either' - '"MB" or "GB"') - split_size = int(float(split_size[:-2]) * 2 ** exp) + raise ValueError("split_size has to end with either" '"MB" or "GB"') + split_size = int(float(split_size[:-2]) * 2**exp) if split_size > 2147483648: - raise ValueError('split_size cannot be larger than 2GB') + raise ValueError("split_size cannot be larger than 2GB") return split_size @@ -57,11 +56,11 @@ def _get_split_size(split_size): def write_nop(fid, last=False): """Write a FIFF_NOP.""" - fid.write(np.array(FIFF.FIFF_NOP, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFT_VOID, dtype='>i4').tobytes()) - fid.write(np.array(0, dtype='>i4').tobytes()) + fid.write(np.array(FIFF.FIFF_NOP, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFT_VOID, dtype=">i4").tobytes()) + fid.write(np.array(0, dtype=">i4").tobytes()) next_ = FIFF.FIFFV_NEXT_NONE if last else FIFF.FIFFV_NEXT_SEQ - fid.write(np.array(next_, dtype='>i4').tobytes()) + fid.write(np.array(next_, dtype=">i4").tobytes()) INT32_MAX = 2147483647 @@ -71,51 +70,50 @@ def write_int(fid, kind, data): """Write a 32-bit integer tag to a fif file.""" data_size = 4 data = np.asarray(data) - if data.dtype.kind not in 'uib' and data.size > 0: - raise TypeError( - f'Cannot safely write data with dtype {data.dtype} as int') + if data.dtype.kind not in "uib" and data.size > 0: + raise TypeError(f"Cannot safely write data with dtype {data.dtype} as int") max_val = data.max() if data.size > 0 else 0 if max_val > INT32_MAX: raise TypeError( - f'Value {max_val} exceeds maximum allowed ({INT32_MAX}) for ' - f'tag {kind}') - data = data.astype('>i4').T - _write(fid, data, kind, data_size, FIFF.FIFFT_INT, '>i4') + f"Value {max_val} exceeds maximum allowed ({INT32_MAX}) for " f"tag {kind}" + ) + data = data.astype(">i4").T + _write(fid, data, kind, data_size, FIFF.FIFFT_INT, ">i4") def write_double(fid, kind, data): """Write a double-precision floating point tag to a fif file.""" data_size = 8 - data = np.array(data, dtype='>f8').T - _write(fid, data, kind, data_size, FIFF.FIFFT_DOUBLE, '>f8') + data = np.array(data, dtype=">f8").T + _write(fid, data, kind, data_size, FIFF.FIFFT_DOUBLE, ">f8") def write_float(fid, kind, data): """Write a single-precision floating point tag to a fif file.""" data_size = 4 - data = np.array(data, dtype='>f4').T - _write(fid, data, kind, data_size, FIFF.FIFFT_FLOAT, '>f4') + data = np.array(data, dtype=">f4").T + _write(fid, data, kind, data_size, FIFF.FIFFT_FLOAT, ">f4") def write_dau_pack16(fid, kind, data): """Write a dau_pack16 tag to a fif file.""" data_size = 2 - data = np.array(data, dtype='>i2').T - _write(fid, data, kind, data_size, FIFF.FIFFT_DAU_PACK16, '>i2') + data = np.array(data, dtype=">i2").T + _write(fid, data, kind, data_size, FIFF.FIFFT_DAU_PACK16, ">i2") def write_complex64(fid, kind, data): """Write a 64 bit complex 
floating point tag to a fif file.""" data_size = 8 - data = np.array(data, dtype='>c8').T - _write(fid, data, kind, data_size, FIFF.FIFFT_COMPLEX_FLOAT, '>c8') + data = np.array(data, dtype=">c8").T + _write(fid, data, kind, data_size, FIFF.FIFFT_COMPLEX_FLOAT, ">c8") def write_complex128(fid, kind, data): """Write a 128 bit complex floating point tag to a fif file.""" data_size = 16 - data = np.array(data, dtype='>c16').T - _write(fid, data, kind, data_size, FIFF.FIFFT_COMPLEX_FLOAT, '>c16') + data = np.array(data, dtype=">c16").T + _write(fid, data, kind, data_size, FIFF.FIFFT_COMPLEX_FLOAT, ">c16") def write_julian(fid, kind, data): @@ -123,15 +121,15 @@ def write_julian(fid, kind, data): assert len(data) == 3 data_size = 4 jd = np.sum(_cal_to_julian(*data)) - data = np.array(jd, dtype='>i4') - _write(fid, data, kind, data_size, FIFF.FIFFT_JULIAN, '>i4') + data = np.array(jd, dtype=">i4") + _write(fid, data, kind, data_size, FIFF.FIFFT_JULIAN, ">i4") def write_string(fid, kind, data): """Write a string tag.""" - str_data = str(data).encode('latin1') + str_data = str(data).encode("latin1") data_size = len(str_data) # therefore compute size here - my_dtype = '>a' # py2/3 compatible on writing -- don't ask me why + my_dtype = ">a" # py2/3 compatible on writing -- don't ask me why if data_size > 0: _write(fid, str_data, kind, data_size, FIFF.FIFFT_STRING, my_dtype) @@ -143,28 +141,27 @@ def write_name_list(fid, kind, data): ---------- data : list of strings """ - write_string(fid, kind, ':'.join(data)) + write_string(fid, kind, ":".join(data)) def write_name_list_sanitized(fid, kind, lst, name): """Write a sanitized, colon-separated list of names.""" - write_string(fid, kind, _safe_name_list(lst, 'write', name)) + write_string(fid, kind, _safe_name_list(lst, "write", name)) def _safe_name_list(lst, operation, name): - if operation == 'write': + if operation == "write": assert isinstance(lst, (list, tuple, np.ndarray)), type(lst) - if any('{COLON}' in val for val in lst): - raise ValueError( - f'The substring "{{COLON}}" in {name} not supported.') - return ':'.join(val.replace(':', '{COLON}') for val in lst) + if any("{COLON}" in val for val in lst): + raise ValueError(f'The substring "{{COLON}}" in {name} not supported.') + return ":".join(val.replace(":", "{COLON}") for val in lst) else: # take a sanitized string and return a list of strings - assert operation == 'read' + assert operation == "read" assert lst is None or isinstance(lst, str) if not lst: # None or empty string return [] - return [val.replace('{COLON}', ':') for val in lst.split(':')] + return [val.replace("{COLON}", ":") for val in lst.split(":")] def write_float_matrix(fid, kind, mat): @@ -174,16 +171,16 @@ def write_float_matrix(fid, kind, mat): data_size = 4 * mat.size + 4 * (mat.ndim + 1) - fid.write(np.array(kind, dtype='>i4').tobytes()) - fid.write(np.array(FIFFT_MATRIX_FLOAT, dtype='>i4').tobytes()) - fid.write(np.array(data_size, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype='>i4').tobytes()) - fid.write(np.array(mat, dtype='>f4').tobytes()) + fid.write(np.array(kind, dtype=">i4").tobytes()) + fid.write(np.array(FIFFT_MATRIX_FLOAT, dtype=">i4").tobytes()) + fid.write(np.array(data_size, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype=">i4").tobytes()) + fid.write(np.array(mat, dtype=">f4").tobytes()) dims = np.empty(mat.ndim + 1, dtype=np.int32) - dims[:mat.ndim] = mat.shape[::-1] + dims[: mat.ndim] = mat.shape[::-1] dims[-1] = mat.ndim - fid.write(np.array(dims, 
dtype='>i4').tobytes()) + fid.write(np.array(dims, dtype=">i4").tobytes()) check_fiff_length(fid) @@ -194,16 +191,16 @@ def write_double_matrix(fid, kind, mat): data_size = 8 * mat.size + 4 * (mat.ndim + 1) - fid.write(np.array(kind, dtype='>i4').tobytes()) - fid.write(np.array(FIFFT_MATRIX_DOUBLE, dtype='>i4').tobytes()) - fid.write(np.array(data_size, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype='>i4').tobytes()) - fid.write(np.array(mat, dtype='>f8').tobytes()) + fid.write(np.array(kind, dtype=">i4").tobytes()) + fid.write(np.array(FIFFT_MATRIX_DOUBLE, dtype=">i4").tobytes()) + fid.write(np.array(data_size, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype=">i4").tobytes()) + fid.write(np.array(mat, dtype=">f8").tobytes()) dims = np.empty(mat.ndim + 1, dtype=np.int32) - dims[:mat.ndim] = mat.shape[::-1] + dims[: mat.ndim] = mat.shape[::-1] dims[-1] = mat.ndim - fid.write(np.array(dims, dtype='>i4').tobytes()) + fid.write(np.array(dims, dtype=">i4").tobytes()) check_fiff_length(fid) @@ -214,17 +211,17 @@ def write_int_matrix(fid, kind, mat): data_size = 4 * mat.size + 4 * 3 - fid.write(np.array(kind, dtype='>i4').tobytes()) - fid.write(np.array(FIFFT_MATRIX_INT, dtype='>i4').tobytes()) - fid.write(np.array(data_size, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype='>i4').tobytes()) - fid.write(np.array(mat, dtype='>i4').tobytes()) + fid.write(np.array(kind, dtype=">i4").tobytes()) + fid.write(np.array(FIFFT_MATRIX_INT, dtype=">i4").tobytes()) + fid.write(np.array(data_size, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype=">i4").tobytes()) + fid.write(np.array(mat, dtype=">i4").tobytes()) dims = np.empty(3, dtype=np.int32) dims[0] = mat.shape[1] dims[1] = mat.shape[0] dims[2] = 2 - fid.write(np.array(dims, dtype='>i4').tobytes()) + fid.write(np.array(dims, dtype=">i4").tobytes()) check_fiff_length(fid) @@ -235,16 +232,16 @@ def write_complex_float_matrix(fid, kind, mat): data_size = 4 * 2 * mat.size + 4 * (mat.ndim + 1) - fid.write(np.array(kind, dtype='>i4').tobytes()) - fid.write(np.array(FIFFT_MATRIX_COMPLEX_FLOAT, dtype='>i4').tobytes()) - fid.write(np.array(data_size, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype='>i4').tobytes()) - fid.write(np.array(mat, dtype='>c8').tobytes()) + fid.write(np.array(kind, dtype=">i4").tobytes()) + fid.write(np.array(FIFFT_MATRIX_COMPLEX_FLOAT, dtype=">i4").tobytes()) + fid.write(np.array(data_size, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype=">i4").tobytes()) + fid.write(np.array(mat, dtype=">c8").tobytes()) dims = np.empty(mat.ndim + 1, dtype=np.int32) - dims[:mat.ndim] = mat.shape[::-1] + dims[: mat.ndim] = mat.shape[::-1] dims[-1] = mat.ndim - fid.write(np.array(dims, dtype='>i4').tobytes()) + fid.write(np.array(dims, dtype=">i4").tobytes()) check_fiff_length(fid) @@ -255,16 +252,16 @@ def write_complex_double_matrix(fid, kind, mat): data_size = 8 * 2 * mat.size + 4 * (mat.ndim + 1) - fid.write(np.array(kind, dtype='>i4').tobytes()) - fid.write(np.array(FIFFT_MATRIX_COMPLEX_DOUBLE, dtype='>i4').tobytes()) - fid.write(np.array(data_size, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype='>i4').tobytes()) - fid.write(np.array(mat, dtype='>c16').tobytes()) + fid.write(np.array(kind, dtype=">i4").tobytes()) + fid.write(np.array(FIFFT_MATRIX_COMPLEX_DOUBLE, dtype=">i4").tobytes()) + fid.write(np.array(data_size, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, 
dtype=">i4").tobytes()) + fid.write(np.array(mat, dtype=">c16").tobytes()) dims = np.empty(mat.ndim + 1, dtype=np.int32) - dims[:mat.ndim] = mat.shape[::-1] + dims[: mat.ndim] = mat.shape[::-1] dims[-1] = mat.ndim - fid.write(np.array(dims, dtype='>i4').tobytes()) + fid.write(np.array(dims, dtype=">i4").tobytes()) check_fiff_length(fid) @@ -276,39 +273,45 @@ def get_machid(): ids : array (length 2, int32) The machine identifier used in MNE. """ - mac = b'%012x' % uuid.getnode() # byte conversion for Py3 - mac = re.findall(b'..', mac) # split string - mac += [b'00', b'00'] # add two more fields + mac = b"%012x" % uuid.getnode() # byte conversion for Py3 + mac = re.findall(b"..", mac) # split string + mac += [b"00", b"00"] # add two more fields # Convert to integer in reverse-order (for some reason) from codecs import encode - mac = b''.join([encode(h, 'hex_codec') for h in mac[::-1]]) + + mac = b"".join([encode(h, "hex_codec") for h in mac[::-1]]) ids = np.flipud(np.frombuffer(mac, np.int32, count=2)) return ids def get_new_file_id(): """Create a new file ID tag.""" - secs, usecs = divmod(time.time(), 1.) + secs, usecs = divmod(time.time(), 1.0) secs, usecs = int(secs), int(usecs * 1e6) - return {'machid': get_machid(), 'version': FIFF.FIFFC_VERSION, - 'secs': secs, 'usecs': usecs} + return { + "machid": get_machid(), + "version": FIFF.FIFFC_VERSION, + "secs": secs, + "usecs": usecs, + } def write_id(fid, kind, id_=None): """Write fiff id.""" id_ = _generate_meas_id() if id_ is None else id_ - data_size = 5 * 4 # The id comprises five integers - fid.write(np.array(kind, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFT_ID_STRUCT, dtype='>i4').tobytes()) - fid.write(np.array(data_size, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype='>i4').tobytes()) + data_size = 5 * 4 # The id comprises five integers + fid.write(np.array(kind, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFT_ID_STRUCT, dtype=">i4").tobytes()) + fid.write(np.array(data_size, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype=">i4").tobytes()) # Collect the bits together for one write - arr = np.array([id_['version'], - id_['machid'][0], id_['machid'][1], - id_['secs'], id_['usecs']], dtype='>i4') + arr = np.array( + [id_["version"], id_["machid"][0], id_["machid"][1], id_["secs"], id_["usecs"]], + dtype=">i4", + ) fid.write(arr.tobytes()) @@ -335,18 +338,18 @@ def start_file(fname, id_=None): ID to use for the FIFF_FILE_ID. """ if _file_like(fname): - logger.debug('Writing using %s I/O' % type(fname)) + logger.debug("Writing using %s I/O" % type(fname)) fid = fname fid.seek(0) else: fname = str(fname) - if op.splitext(fname)[1].lower() == '.gz': - logger.debug('Writing using gzip') + if op.splitext(fname)[1].lower() == ".gz": + logger.debug("Writing using gzip") # defaults to compression level 9, which is barely smaller but much # slower. 2 offers a good compromise. 
fid = GzipFile(fname, "wb", compresslevel=2) else: - logger.debug('Writing using normal I/O') + logger.debug("Writing using normal I/O") fid = open(fname, "wb") # Write the compulsory items write_id(fid, FIFF.FIFF_FILE_ID, id_) @@ -368,9 +371,11 @@ def check_fiff_length(fid, close=True): if fid.tell() > 2147483648: # 2 ** 31, FIFF uses signed 32-bit locations if close: fid.close() - raise OSError('FIFF file exceeded 2GB limit, please split file, reduce' - ' split_size (if possible), or save to a different ' - 'format') + raise OSError( + "FIFF file exceeded 2GB limit, please split file, reduce" + " split_size (if possible), or save to a different " + "format" + ) def end_file(fid): @@ -383,53 +388,53 @@ def end_file(fid): def write_coord_trans(fid, trans): """Write a coordinate transformation structure.""" data_size = 4 * 2 * 12 + 4 * 2 - fid.write(np.array(FIFF.FIFF_COORD_TRANS, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFT_COORD_TRANS_STRUCT, dtype='>i4').tobytes()) - fid.write(np.array(data_size, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype='>i4').tobytes()) - fid.write(np.array(trans['from'], dtype='>i4').tobytes()) - fid.write(np.array(trans['to'], dtype='>i4').tobytes()) + fid.write(np.array(FIFF.FIFF_COORD_TRANS, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFT_COORD_TRANS_STRUCT, dtype=">i4").tobytes()) + fid.write(np.array(data_size, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype=">i4").tobytes()) + fid.write(np.array(trans["from"], dtype=">i4").tobytes()) + fid.write(np.array(trans["to"], dtype=">i4").tobytes()) # The transform... - rot = trans['trans'][:3, :3] - move = trans['trans'][:3, 3] - fid.write(np.array(rot, dtype='>f4').tobytes()) - fid.write(np.array(move, dtype='>f4').tobytes()) + rot = trans["trans"][:3, :3] + move = trans["trans"][:3, 3] + fid.write(np.array(rot, dtype=">f4").tobytes()) + fid.write(np.array(move, dtype=">f4").tobytes()) # ...and its inverse - trans_inv = np.linalg.inv(trans['trans']) + trans_inv = np.linalg.inv(trans["trans"]) rot = trans_inv[:3, :3] move = trans_inv[:3, 3] - fid.write(np.array(rot, dtype='>f4').tobytes()) - fid.write(np.array(move, dtype='>f4').tobytes()) + fid.write(np.array(rot, dtype=">f4").tobytes()) + fid.write(np.array(move, dtype=">f4").tobytes()) def write_ch_info(fid, ch): """Write a channel information record to a fif file.""" data_size = 4 * 13 + 4 * 7 + 16 - fid.write(np.array(FIFF.FIFF_CH_INFO, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFT_CH_INFO_STRUCT, dtype='>i4').tobytes()) - fid.write(np.array(data_size, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype='>i4').tobytes()) + fid.write(np.array(FIFF.FIFF_CH_INFO, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFT_CH_INFO_STRUCT, dtype=">i4").tobytes()) + fid.write(np.array(data_size, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype=">i4").tobytes()) # Start writing fiffChInfoRec - fid.write(np.array(ch['scanno'], dtype='>i4').tobytes()) - fid.write(np.array(ch['logno'], dtype='>i4').tobytes()) - fid.write(np.array(ch['kind'], dtype='>i4').tobytes()) - fid.write(np.array(ch['range'], dtype='>f4').tobytes()) - fid.write(np.array(ch['cal'], dtype='>f4').tobytes()) - fid.write(np.array(ch['coil_type'], dtype='>i4').tobytes()) - fid.write(np.array(ch['loc'], dtype='>f4').tobytes()) # writing 12 values + fid.write(np.array(ch["scanno"], dtype=">i4").tobytes()) + fid.write(np.array(ch["logno"], dtype=">i4").tobytes()) + 
fid.write(np.array(ch["kind"], dtype=">i4").tobytes()) + fid.write(np.array(ch["range"], dtype=">f4").tobytes()) + fid.write(np.array(ch["cal"], dtype=">f4").tobytes()) + fid.write(np.array(ch["coil_type"], dtype=">i4").tobytes()) + fid.write(np.array(ch["loc"], dtype=">f4").tobytes()) # writing 12 values # unit and unit multiplier - fid.write(np.array(ch['unit'], dtype='>i4').tobytes()) - fid.write(np.array(ch['unit_mul'], dtype='>i4').tobytes()) + fid.write(np.array(ch["unit"], dtype=">i4").tobytes()) + fid.write(np.array(ch["unit_mul"], dtype=">i4").tobytes()) # Finally channel name - ch_name = ch['ch_name'][:15] - fid.write(np.array(ch_name, dtype='>c').tobytes()) - fid.write(b'\0' * (16 - len(ch_name))) + ch_name = ch["ch_name"][:15] + fid.write(np.array(ch_name, dtype=">c").tobytes()) + fid.write(b"\0" * (16 - len(ch_name))) def write_dig_points(fid, dig, block=False, coord_frame=None): @@ -441,42 +446,49 @@ def write_dig_points(fid, dig, block=False, coord_frame=None): if coord_frame is not None: write_int(fid, FIFF.FIFF_MNE_COORD_FRAME, coord_frame) for d in dig: - fid.write(np.array(FIFF.FIFF_DIG_POINT, '>i4').tobytes()) - fid.write(np.array(FIFF.FIFFT_DIG_POINT_STRUCT, '>i4').tobytes()) - fid.write(np.array(data_size, dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, '>i4').tobytes()) + fid.write(np.array(FIFF.FIFF_DIG_POINT, ">i4").tobytes()) + fid.write(np.array(FIFF.FIFFT_DIG_POINT_STRUCT, ">i4").tobytes()) + fid.write(np.array(data_size, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, ">i4").tobytes()) # Start writing fiffDigPointRec - fid.write(np.array(d['kind'], '>i4').tobytes()) - fid.write(np.array(d['ident'], '>i4').tobytes()) - fid.write(np.array(d['r'][:3], '>f4').tobytes()) + fid.write(np.array(d["kind"], ">i4").tobytes()) + fid.write(np.array(d["ident"], ">i4").tobytes()) + fid.write(np.array(d["r"][:3], ">f4").tobytes()) if block: end_block(fid, FIFF.FIFFB_ISOTRAK) def write_float_sparse_rcs(fid, kind, mat): """Write a single-precision sparse compressed row matrix tag.""" - return write_float_sparse(fid, kind, mat, fmt='csr') + return write_float_sparse(fid, kind, mat, fmt="csr") def write_float_sparse_ccs(fid, kind, mat): """Write a single-precision sparse compressed column matrix tag.""" - return write_float_sparse(fid, kind, mat, fmt='csc') + return write_float_sparse(fid, kind, mat, fmt="csc") -def write_float_sparse(fid, kind, mat, fmt='auto'): +def write_float_sparse(fid, kind, mat, fmt="auto"): """Write a single-precision floating-point sparse matrix tag.""" from scipy import sparse from .tag import _matrix_coding_CCS, _matrix_coding_RCS - if fmt == 'auto': - fmt = 'csr' if isinstance(mat, sparse.csr_matrix) else 'csc' - if fmt == 'csr': + + if fmt == "auto": + fmt = "csr" if isinstance(mat, sparse.csr_matrix) else "csc" + if fmt == "csr": need = sparse.csr_matrix bits = _matrix_coding_RCS else: need = sparse.csc_matrix bits = _matrix_coding_CCS if not isinstance(mat, need): - raise TypeError('Must write %s, got %s' % (fmt.upper(), type(mat),)) + raise TypeError( + "Must write %s, got %s" + % ( + fmt.upper(), + type(mat), + ) + ) FIFFT_MATRIX = bits << 16 FIFFT_MATRIX_FLOAT_RCS = FIFF.FIFFT_FLOAT | FIFFT_MATRIX @@ -484,24 +496,24 @@ def write_float_sparse(fid, kind, mat, fmt='auto'): nrow = mat.shape[0] data_size = 4 * nnzm + 4 * nnzm + 4 * (nrow + 1) + 4 * 4 - fid.write(np.array(kind, dtype='>i4').tobytes()) - fid.write(np.array(FIFFT_MATRIX_FLOAT_RCS, dtype='>i4').tobytes()) - fid.write(np.array(data_size, 
dtype='>i4').tobytes()) - fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype='>i4').tobytes()) + fid.write(np.array(kind, dtype=">i4").tobytes()) + fid.write(np.array(FIFFT_MATRIX_FLOAT_RCS, dtype=">i4").tobytes()) + fid.write(np.array(data_size, dtype=">i4").tobytes()) + fid.write(np.array(FIFF.FIFFV_NEXT_SEQ, dtype=">i4").tobytes()) - fid.write(np.array(mat.data, dtype='>f4').tobytes()) - fid.write(np.array(mat.indices, dtype='>i4').tobytes()) - fid.write(np.array(mat.indptr, dtype='>i4').tobytes()) + fid.write(np.array(mat.data, dtype=">f4").tobytes()) + fid.write(np.array(mat.indices, dtype=">i4").tobytes()) + fid.write(np.array(mat.indptr, dtype=">i4").tobytes()) dims = [nnzm, mat.shape[0], mat.shape[1], 2] - fid.write(np.array(dims, dtype='>i4').tobytes()) + fid.write(np.array(dims, dtype=">i4").tobytes()) check_fiff_length(fid) def _generate_meas_id(): """Generate a new meas_id dict.""" id_ = dict() - id_['version'] = FIFF.FIFFC_VERSION - id_['machid'] = get_machid() - id_['secs'], id_['usecs'] = DATE_NONE + id_["version"] = FIFF.FIFFC_VERSION + id_["machid"] = get_machid() + id_["secs"], id_["usecs"] = DATE_NONE return id_ diff --git a/mne/label.py b/mne/label.py index 0a1fa9710a0..7167acff7fc 100644 --- a/mne/label.py +++ b/mne/label.py @@ -15,17 +15,35 @@ from .morph_map import read_morph_map from .parallel import parallel_func -from .source_estimate import (SourceEstimate, VolSourceEstimate, - _center_of_mass, extract_label_time_course, - spatial_src_adjacency) -from .source_space import (add_source_space_distances, SourceSpaces, - _ensure_src) +from .source_estimate import ( + SourceEstimate, + VolSourceEstimate, + _center_of_mass, + extract_label_time_course, + spatial_src_adjacency, +) +from .source_space import add_source_space_distances, SourceSpaces, _ensure_src from .stats.cluster_level import _find_clusters, _get_components -from .surface import (complete_surface_info, read_surface, fast_cross_3d, - _mesh_borders, mesh_edges, mesh_dist) -from .utils import (get_subjects_dir, _check_subject, logger, verbose, warn, - check_random_state, _validate_type, fill_doc, - _check_option, _check_fname) +from .surface import ( + complete_surface_info, + read_surface, + fast_cross_3d, + _mesh_borders, + mesh_edges, + mesh_dist, +) +from .utils import ( + get_subjects_dir, + _check_subject, + logger, + verbose, + warn, + check_random_state, + _validate_type, + fill_doc, + _check_option, + _check_fname, +) def _blend_colors(color_1, color_2): @@ -58,14 +76,14 @@ def _blend_colors(color_1, color_2): h_2, s_2, v_2 = rgb_to_hsv(r_2, g_2, b_2) hue_diff = abs(h_1 - h_2) if hue_diff < 0.5: - h = min(h_1, h_2) + hue_diff / 2. + h = min(h_1, h_2) + hue_diff / 2.0 else: - h = max(h_1, h_2) + (1. - hue_diff) / 2. - h %= 1. - s = (s_1 + s_2) / 2. - v = (v_1 + v_2) / 2. + h = max(h_1, h_2) + (1.0 - hue_diff) / 2.0 + h %= 1.0 + s = (s_1 + s_2) / 2.0 + v = (v_1 + v_2) / 2.0 r, g, b = hsv_to_rgb(h, s, v) - a = (a_1 + a_2) / 2. + a = (a_1 + a_2) / 2.0 color = (r, g, b, a) return color @@ -88,7 +106,7 @@ def _split_colors(color, n): """ r, g, b, a = color h, s, v = rgb_to_hsv(r, g, b) - gradient_range = np.sqrt(n / 10.) 
+ gradient_range = np.sqrt(n / 10.0) if v > 0.5: v_max = min(0.95, v + gradient_range / 2) v_min = max(0.05, v_max - gradient_range) @@ -98,11 +116,19 @@ def _split_colors(color, n): hsv_colors = ((h, s, v_) for v_ in np.linspace(v_min, v_max, n)) rgb_colors = (hsv_to_rgb(h_, s_, v_) for h_, s_, v_ in hsv_colors) - rgba_colors = ((r_, g_, b_, a,) for r_, g_, b_ in rgb_colors) + rgba_colors = ( + ( + r_, + g_, + b_, + a, + ) + for r_, g_, b_ in rgb_colors + ) return tuple(rgba_colors) -def _n_colors(n, bytes_=False, cmap='hsv'): +def _n_colors(n, bytes_=False, cmap="hsv"): """Produce a list of n unique RGBA color tuples based on a colormap. Parameters @@ -120,12 +146,12 @@ def _n_colors(n, bytes_=False, cmap='hsv'): colors : array, shape (n, 4) RGBA color values. """ - n_max = 2 ** 10 + n_max = 2**10 if n > n_max: - raise NotImplementedError("Can't produce more than %i unique " - "colors" % n_max) + raise NotImplementedError("Can't produce more than %i unique " "colors" % n_max) from .viz.utils import _get_cmap + cm = _get_cmap(cmap) pos = np.linspace(0, 1, n, False) colors = cm(pos, bytes=bytes_) @@ -133,9 +159,10 @@ def _n_colors(n, bytes_=False, cmap='hsv'): # make sure colors are unique for ii, c in enumerate(colors): if np.any(np.all(colors[:ii] == c, 1)): - raise RuntimeError('Could not get %d unique colors from %s ' - 'colormap. Try using a different colormap.' - % (n, cmap)) + raise RuntimeError( + "Could not get %d unique colors from %s " + "colormap. Try using a different colormap." % (n, cmap) + ) return colors @@ -195,18 +222,30 @@ class Label: """ @verbose - def __init__(self, vertices=(), pos=None, values=None, hemi=None, - comment="", name=None, filename=None, subject=None, - color=None, *, verbose=None): # noqa: D102 + def __init__( + self, + vertices=(), + pos=None, + values=None, + hemi=None, + comment="", + name=None, + filename=None, + subject=None, + color=None, + *, + verbose=None, + ): # noqa: D102 # check parameters if not isinstance(hemi, str): - raise ValueError('hemi must be a string, not %s' % type(hemi)) + raise ValueError("hemi must be a string, not %s" % type(hemi)) vertices = np.asarray(vertices, int) if np.any(np.diff(vertices.astype(int)) <= 0): - raise ValueError('Vertices must be ordered in increasing order.') + raise ValueError("Vertices must be ordered in increasing order.") if color is not None: from matplotlib.colors import colorConverter + color = colorConverter.to_rgba(color) if values is None: @@ -220,8 +259,10 @@ def __init__(self, vertices=(), pos=None, values=None, hemi=None, pos = np.asarray(pos) if not (len(vertices) == len(values) == len(pos)): - raise ValueError("vertices, values and pos need to have same " - "length (number of vertices)") + raise ValueError( + "vertices, values and pos need to have same " + "length (number of vertices)" + ) # name if name is None and filename is not None: @@ -238,30 +279,32 @@ def __init__(self, vertices=(), pos=None, values=None, hemi=None, self.filename = filename def __setstate__(self, state): # noqa: D105 - self.vertices = state['vertices'] - self.pos = state['pos'] - self.values = state['values'] - self.hemi = state['hemi'] - self.comment = state['comment'] - self.subject = state.get('subject', None) - self.color = state.get('color', None) - self.name = state['name'] - self.filename = state['filename'] + self.vertices = state["vertices"] + self.pos = state["pos"] + self.values = state["values"] + self.hemi = state["hemi"] + self.comment = state["comment"] + self.subject = state.get("subject", None) + 
self.color = state.get("color", None) + self.name = state["name"] + self.filename = state["filename"] def __getstate__(self): # noqa: D105 - out = dict(vertices=self.vertices, - pos=self.pos, - values=self.values, - hemi=self.hemi, - comment=self.comment, - subject=self.subject, - color=self.color, - name=self.name, - filename=self.filename) + out = dict( + vertices=self.vertices, + pos=self.pos, + values=self.values, + hemi=self.hemi, + comment=self.comment, + subject=self.subject, + color=self.color, + name=self.name, + filename=self.filename, + ) return out def __repr__(self): # noqa: D105 - name = 'unknown, ' if self.subject is None else self.subject + ', ' + name = "unknown, " if self.subject is None else self.subject + ", " name += repr(self.name) if self.name is not None else "unnamed" n_vert = len(self) return "