diff --git a/docs/source/v1.9.md.inc b/docs/source/v1.9.md.inc
index 67c724e07..b9d56ecc1 100644
--- a/docs/source/v1.9.md.inc
+++ b/docs/source/v1.9.md.inc
@@ -14,9 +14,9 @@
 
 [//]: # (- Whatever (#000 by @whoever))
 
-[//]: # (### :bug: Bug fixes)
+### :bug: Bug fixes
 
-[//]: # (- Whatever (#000 by @whoever))
+- Fix incorrect inclusion of autobad channels when disabling `find_bad_channels_meg` for example after enabling it (#902 by @larsoner)
 
 ### :medical_symbol: Code health and infrastructure
 
diff --git a/mne_bids_pipeline/_import_data.py b/mne_bids_pipeline/_import_data.py
index aaf7b56e3..b0de58216 100644
--- a/mne_bids_pipeline/_import_data.py
+++ b/mne_bids_pipeline/_import_data.py
@@ -816,12 +816,12 @@ def _import_data_kwargs(*, config: SimpleNamespace, subject: str) -> dict:
     )
 
 
-def _get_run_type(
-    run: Optional[str],
-    task: Optional[str],
-) -> str:
+def _read_raw_msg(
+    bids_path_in: BIDSPath, run: Optional[str], task: Optional[str]
+) -> tuple[str, str]:
+    """Return the "Reading ..." log message and the run type for a raw file."""
     if run is None and task in ("noise", "rest"):
         run_type = dict(rest="resting-state", noise="empty-room")[task]
     else:
         run_type = "experimental"
-    return run_type
+    return f"Reading {run_type} recording: {bids_path_in.basename}", run_type
diff --git a/mne_bids_pipeline/_run.py b/mne_bids_pipeline/_run.py
index ca3bd1faf..2077f21ff 100644
--- a/mne_bids_pipeline/_run.py
+++ b/mne_bids_pipeline/_run.py
@@ -23,8 +23,10 @@
 
 
 def failsafe_run(
+    *,
     get_input_fnames: Optional[Callable] = None,
     get_output_fnames: Optional[Callable] = None,
+    require_output: bool = True,
 ) -> Callable:
     def failsafe_run_decorator(func):
         @functools.wraps(func)  # Preserve "identity" of original function
@@ -36,6 +38,8 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs):
                 exec_params=exec_params,
                 get_input_fnames=get_input_fnames,
                 get_output_fnames=get_output_fnames,
+                require_output=require_output,
+                func_name=f"{__mne_bids_pipeline_step__}::{func.__name__}",
             )
             t0 = time.time()
             log_info = pd.concat(
@@ -116,7 +120,15 @@ def
hash_file_path(path: pathlib.Path) -> str: class ConditionalStepMemory: - def __init__(self, *, exec_params, get_input_fnames, get_output_fnames): + def __init__( + self, + *, + exec_params: SimpleNamespace, + get_input_fnames: Optional[Callable], + get_output_fnames: Optional[Callable], + require_output: bool, + func_name: str, + ): memory_location = exec_params.memory_location if memory_location is True: use_location = exec_params.deriv_root / exec_params.memory_subdir @@ -134,6 +146,8 @@ def __init__(self, *, exec_params, get_input_fnames, get_output_fnames): self.get_input_fnames = get_input_fnames self.get_output_fnames = get_output_fnames self.memory_file_method = exec_params.memory_file_method + self.require_output = require_output + self.func_name = func_name def cache(self, func): def wrapper(*args, **kwargs): @@ -262,9 +276,19 @@ def wrapper(*args, **kwargs): # https://joblib.readthedocs.io/en/latest/memory.html#joblib.memory.MemorizedFunc.call # noqa: E501 if force_run or unknown_inputs or bad_out_files: - memorized_func.call(*args, **kwargs) + out_files = memorized_func.call(*args, **kwargs) + else: + out_files = memorized_func(*args, **kwargs) + if self.require_output: + assert isinstance(out_files, dict) and len(out_files), ( + f"Internal error: step must return non-empty out_files dict, got " + f"{type(out_files).__name__} for:\n{self.func_name}" + ) else: - memorized_func(*args, **kwargs) + assert out_files is None, ( + f"Internal error: step must return None, got {type(out_files)} " + f"for:\n{self.func_name}" + ) return wrapper diff --git a/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py b/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py index 2f17b0c77..e9cf373af 100644 --- a/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py +++ b/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py @@ -11,16 +11,14 @@ from ..._config_utils import _bids_kwargs, get_sessions, get_subjects from ..._logging import gen_log_kwargs, logger 
-from ..._run import failsafe_run +from ..._run import _prep_out_files, failsafe_run -def init_dataset(cfg) -> None: +@failsafe_run() +def init_dataset(cfg: SimpleNamespace, exec_params: SimpleNamespace) -> dict: """Prepare the pipeline directory in /derivatives.""" - fname_json = cfg.deriv_root / "dataset_description.json" - if fname_json.is_file(): - msg = "Output directories already exist …" - logger.info(**gen_log_kwargs(message=msg, emoji="✅")) - return + out_files = dict() + out_files["json"] = cfg.deriv_root / "dataset_description.json" logger.info(**gen_log_kwargs(message="Initializing output directories.")) cfg.deriv_root.mkdir(exist_ok=True, parents=True) @@ -38,10 +36,12 @@ def init_dataset(cfg) -> None: "URL": "n/a", } - _write_json(fname_json, ds_json, overwrite=True) + _write_json(out_files["json"], ds_json, overwrite=True) + return _prep_out_files( + exec_params=exec_params, out_files=out_files, bids_only=False + ) -@failsafe_run() def init_subject_dirs( *, cfg: SimpleNamespace, @@ -73,7 +73,7 @@ def get_config( def main(*, config): """Initialize the output directories.""" - init_dataset(cfg=get_config(config=config)) + init_dataset(cfg=get_config(config=config), exec_params=config.exec_params) # Don't bother with parallelization here as I/O operations are generally # not well parallelized (and this should be very fast anyway) for subject in get_subjects(config): diff --git a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py index 1cbeca387..c928a5b40 100644 --- a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py +++ b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py @@ -5,7 +5,6 @@ import mne import pandas as pd -from mne_bids import BIDSPath from ..._config_utils import ( _do_mf_autobad, @@ -21,6 +20,7 @@ _get_mf_reference_run_path, _get_run_rest_noise_path, _import_data_kwargs, + _read_raw_msg, import_er_data, import_experimental_data, ) @@ -79,25 +79,96 @@ def 
assess_data_quality( out_files = dict() key = f"raw_task-{task}_run-{run}" bids_path_in = in_files.pop(key) + if key == "raw_task-noise_run-None": + bids_path_ref_in = in_files.pop("raw_ref_run") + else: + bids_path_ref_in = None + msg, _ = _read_raw_msg(bids_path_in=bids_path_in, run=run, task=task) + logger.info(**gen_log_kwargs(message=msg)) + if run is None and task == "noise": + raw = import_er_data( + cfg=cfg, + bids_path_er_in=bids_path_in, + bids_path_er_bads_in=None, + bids_path_ref_in=bids_path_ref_in, + bids_path_ref_bads_in=None, + prepare_maxwell_filter=True, + ) + else: + data_is_rest = run is None and task == "rest" + raw = import_experimental_data( + bids_path_in=bids_path_in, + bids_path_bads_in=None, + cfg=cfg, + data_is_rest=data_is_rest, + ) + preexisting_bads = set(raw.info["bads"]) + if _do_mf_autobad(cfg=cfg): - if key == "raw_task-noise_run-None": - bids_path_ref_in = in_files.pop("raw_ref_run") - else: - bids_path_ref_in = None - auto_scores = _find_bads_maxwell( + ( + auto_noisy_chs, + auto_flat_chs, + auto_scores, + ) = _find_bads_maxwell( cfg=cfg, exec_params=exec_params, - bids_path_in=bids_path_in, - bids_path_ref_in=bids_path_ref_in, + raw=raw, subject=subject, session=session, run=run, task=task, - out_files=out_files, ) + bads = sorted(set(raw.info["bads"] + auto_noisy_chs + auto_flat_chs)) + msg = f"Found {len(bads)} channel{_pl(bads)} as bad." + raw.info["bads"] = bads + del bads + logger.info(**gen_log_kwargs(message=msg)) else: - auto_scores = None - del key + auto_scores = auto_noisy_chs = auto_flat_chs = None + del key, raw + + # Always output the scores and bads TSV + out_files["auto_scores"] = bids_path_in.copy().update( + suffix="scores", + extension=".json", + root=cfg.deriv_root, + split=None, + check=False, + session=session, + subject=subject, + ) + _write_json(out_files["auto_scores"], auto_scores) + + # Write the bad channels to disk. 
+ out_files["bads_tsv"] = _bads_path( + cfg=cfg, + bids_path_in=bids_path_in, + subject=subject, + session=session, + ) + bads_for_tsv = [] + reasons = [] + + if auto_flat_chs: + bads_for_tsv.extend(auto_flat_chs) + reasons.extend(["auto-flat"] * len(auto_flat_chs)) + preexisting_bads -= set(auto_flat_chs) + + if auto_noisy_chs is not None: + bads_for_tsv.extend(auto_noisy_chs) + reasons.extend(["auto-noisy"] * len(auto_noisy_chs)) + preexisting_bads -= set(auto_noisy_chs) + + preexisting_bads = sorted(preexisting_bads) + if preexisting_bads: + bads_for_tsv.extend(preexisting_bads) + reasons.extend( + ["pre-existing (before MNE-BIDS-pipeline was run)"] * len(preexisting_bads) + ) + + tsv_data = pd.DataFrame(dict(name=bads_for_tsv, reason=reasons)) + tsv_data = tsv_data.sort_values(by="name") + tsv_data.to_csv(out_files["bads_tsv"], sep="\t", index=False) # Report with _open_report( @@ -145,45 +216,25 @@ def _find_bads_maxwell( *, cfg: SimpleNamespace, exec_params: SimpleNamespace, - bids_path_in: BIDSPath, - bids_path_ref_in: Optional[BIDSPath], + raw: mne.io.Raw, subject: str, session: Optional[str], run: Optional[str], task: Optional[str], - out_files: dict, ): - if cfg.find_flat_channels_meg and not cfg.find_noisy_channels_meg: - msg = "Finding flat channels." - elif cfg.find_noisy_channels_meg and not cfg.find_flat_channels_meg: - msg = "Finding noisy channels using Maxwell filtering." + if cfg.find_flat_channels_meg: + if cfg.find_noisy_channels_meg: + msg = "Finding flat channels and noisy channels using Maxwell filtering." + else: + msg = "Finding flat channels." else: - msg = "Finding flat channels and noisy channels using Maxwell filtering." + assert cfg.find_noisy_channels_meg + msg = "Finding noisy channels using Maxwell filtering." 
logger.info(**gen_log_kwargs(message=msg)) - if run is None and task == "noise": - raw = import_er_data( - cfg=cfg, - bids_path_er_in=bids_path_in, - bids_path_er_bads_in=None, - bids_path_ref_in=bids_path_ref_in, - bids_path_ref_bads_in=None, - prepare_maxwell_filter=True, - ) - else: - data_is_rest = run is None and task == "rest" - raw = import_experimental_data( - bids_path_in=bids_path_in, - bids_path_bads_in=None, - cfg=cfg, - data_is_rest=data_is_rest, - ) - # Filter the data manually before passing it to find_bad_channels_maxwell() # This reduces memory usage, as we can control the number of jobs used # during filtering. - preexisting_bads = raw.info["bads"].copy() - bads = preexisting_bads.copy() raw_filt = raw.copy().filter(l_freq=None, h_freq=40, n_jobs=1) ( auto_noisy_chs, @@ -209,7 +260,8 @@ def _find_bads_maxwell( else: msg = "Found no flat channels." logger.info(**gen_log_kwargs(message=msg)) - bads.extend(auto_flat_chs) + else: + auto_flat_chs = [] if cfg.find_noisy_channels_meg: if auto_noisy_chs: @@ -222,56 +274,8 @@ def _find_bads_maxwell( msg = "Found no noisy channels." logger.info(**gen_log_kwargs(message=msg)) - bads.extend(auto_noisy_chs) - - bads = sorted(set(bads)) - msg = f"Found {len(bads)} channel{_pl(bads)} as bad." - raw.info["bads"] = bads - del bads - logger.info(**gen_log_kwargs(message=msg)) - - if cfg.find_noisy_channels_meg: - out_files["auto_scores"] = bids_path_in.copy().update( - suffix="scores", - extension=".json", - root=cfg.deriv_root, - split=None, - check=False, - session=session, - subject=subject, - ) - _write_json(out_files["auto_scores"], auto_scores) - - # Write the bad channels to disk. 
- out_files["bads_tsv"] = _bads_path( - cfg=cfg, - bids_path_in=bids_path_in, - subject=subject, - session=session, - ) - bads_for_tsv = [] - reasons = [] - - if cfg.find_flat_channels_meg: - bads_for_tsv.extend(auto_flat_chs) - reasons.extend(["auto-flat"] * len(auto_flat_chs)) - preexisting_bads = set(preexisting_bads) - set(auto_flat_chs) - - if cfg.find_noisy_channels_meg: - bads_for_tsv.extend(auto_noisy_chs) - reasons.extend(["auto-noisy"] * len(auto_noisy_chs)) - preexisting_bads = set(preexisting_bads) - set(auto_noisy_chs) - - preexisting_bads = list(preexisting_bads) - if preexisting_bads: - bads_for_tsv.extend(preexisting_bads) - reasons.extend( - ["pre-existing (before MNE-BIDS-pipeline was run)"] * len(preexisting_bads) - ) - - tsv_data = pd.DataFrame(dict(name=bads_for_tsv, reason=reasons)) - tsv_data = tsv_data.sort_values(by="name") - tsv_data.to_csv(out_files["bads_tsv"], sep="\t", index=False) + else: + auto_noisy_chs = [] # Interaction if exec_params.interactive and cfg.find_noisy_channels_meg: @@ -280,7 +284,7 @@ def _find_bads_maxwell( plot_auto_scores(auto_scores, ch_types=cfg.ch_types) plt.show() - return auto_scores + return auto_noisy_chs, auto_flat_chs, auto_scores def get_config( diff --git a/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py b/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py index fd9c6c874..6248b7bad 100644 --- a/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py +++ b/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py @@ -30,8 +30,8 @@ ) from ..._import_data import ( _get_run_rest_noise_path, - _get_run_type, _import_data_kwargs, + _read_raw_msg, import_er_data, import_experimental_data, ) @@ -167,9 +167,7 @@ def filter_data( in_key = f"raw_task-{task}_run-{run}" bids_path_in = in_files.pop(in_key) bids_path_bads_in = in_files.pop(f"{in_key}-bads", None) - - run_type = _get_run_type(run=run, task=task) - msg = f"Reading {run_type} recording: " f"{bids_path_in.basename}" + 
msg, run_type = _read_raw_msg(bids_path_in=bids_path_in, run=run, task=task)
     logger.info(**gen_log_kwargs(message=msg))
     if cfg.use_maxwell_filter:
         raw = mne.io.read_raw_fif(bids_path_in)
diff --git a/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py b/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py
index 9fce737cc..cf1a9b932 100644
--- a/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py
+++ b/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py
@@ -12,7 +12,7 @@
     get_sessions,
     get_subjects,
 )
-from ..._import_data import _get_run_rest_noise_path, _get_run_type, _import_data_kwargs
+from ..._import_data import _get_run_rest_noise_path, _import_data_kwargs, _read_raw_msg
 from ..._logging import gen_log_kwargs, logger
 from ..._parallel import get_parallel_backend, parallel_func
 from ..._report import _add_raw, _open_report
@@ -59,8 +59,7 @@ def run_regress_artifact(
     in_key = f"raw_task-{task}_run-{run}"
     bids_path_in = in_files.pop(in_key)
     out_files[in_key] = bids_path_in.copy().update(processing="regress")
-    run_type = _get_run_type(run=run, task=task)
-    msg = f"Reading {run_type} recording: " f"{bids_path_in.basename}"
+    msg, _ = _read_raw_msg(bids_path_in=bids_path_in, run=run, task=task)
     logger.info(**gen_log_kwargs(message=msg))
     raw = mne.io.read_raw_fif(bids_path_in).load_data()
     projs = raw.info["projs"]
diff --git a/mne_bids_pipeline/steps/sensor/_06_make_cov.py b/mne_bids_pipeline/steps/sensor/_06_make_cov.py
index e3c8cdc9e..6a3e8f026 100644
--- a/mne_bids_pipeline/steps/sensor/_06_make_cov.py
+++ b/mne_bids_pipeline/steps/sensor/_06_make_cov.py
@@ -79,8 +79,8 @@ def get_input_fnames_cov(
             root=cfg.deriv_root,
             check=False,
         )
-        run_type = "resting-state" if cfg.noise_cov == "rest" else "empty-room"
-        if run_type == "resting-state":
+        cov_type = "resting-state" if cfg.noise_cov == "rest" else "empty-room"
+        if cov_type == "resting-state":
             bids_path_raw_noise.task = "rest"
         else:
             bids_path_raw_noise.task = "noise"
@@
-134,8 +134,8 @@ def compute_cov_from_raw(
     out_files: dict,
 ) -> mne.Covariance:
     fname_raw = in_files.pop("raw")
-    run_type = "resting-state" if fname_raw.task == "rest" else "empty-room"
-    msg = f"Computing regularized covariance based on {run_type} recording."
+    cov_type = "resting-state" if fname_raw.task == "rest" else "empty-room"
+    msg = f"Computing regularized covariance based on {cov_type} recording."
     logger.info(**gen_log_kwargs(message=msg))
     msg = f"Input: {fname_raw.basename}"
     logger.info(**gen_log_kwargs(message=msg))
diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py
index 923b61ccb..d7c70da0c 100644
--- a/mne_bids_pipeline/steps/sensor/_99_group_average.py
+++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py
@@ -655,6 +655,7 @@ def average_full_epochs_report(
     in_files: dict,
 ) -> dict:
     """Add decoding results to the grand average report."""
+    out_files = dict()
     with _open_report(
         cfg=cfg, exec_params=exec_params, subject=subject, session=session
     ) as report:
@@ -697,7 +698,7 @@ def average_full_epochs_report(
         )
         # close figure to save memory
         plt.close(fig)
-    return _prep_out_files(exec_params=exec_params, out_files=dict())
+    return _prep_out_files(exec_params=exec_params, out_files=out_files)
 
 
 @failsafe_run(
diff --git a/mne_bids_pipeline/tests/test_functions.py b/mne_bids_pipeline/tests/test_functions.py
new file mode 100644
index 000000000..f4d64adf4
--- /dev/null
+++ b/mne_bids_pipeline/tests/test_functions.py
@@ -0,0 +1,65 @@
+"""Test some properties of our core processing-step functions."""
+
+import ast
+import inspect
+
+import pytest
+
+from mne_bids_pipeline._config_utils import _get_step_modules
+
+# mne_bids_pipeline.init._01_init_derivatives_dir:
+FLAT_MODULES = {x.__name__: x for x in sum(_get_step_modules().values(), ())}
+
+
+@pytest.mark.parametrize("module_name", list(FLAT_MODULES))
+def test_all_functions_return(module_name):
+    """Test that all functions decorated with failsafe_run return a dict."""
+    # Find the functions within the module that use the failsafe_run decorator
+    module = FLAT_MODULES[module_name]
+    funcs = list()
+    for name in dir(module):
+        obj = getattr(module, name)
+        if not callable(obj):
+            continue
+        if getattr(obj, "__module__", None) != module_name:
+            continue
+        if not hasattr(obj, "__wrapped__"):
+            continue
+        # All our failsafe_run decorated functions should look like this
+        assert "__mne_bids_pipeline_failsafe_wrapper__" in repr(obj.__code__)
+        funcs.append(obj)
+    # Some module names we know don't have any
+    if module_name.split(".")[-1] in ("_01_recon_all",):
+        assert len(funcs) == 0
+        return
+
+    assert len(funcs) != 0, f"No failsafe_runs functions found in {module_name}"
+
+    # Adapted from numpydoc RT01 validation
+    def get_returns_not_on_nested_functions(node):
+        returns = [node] if isinstance(node, ast.Return) else []
+        for child in ast.iter_child_nodes(node):
+            # Ignore nested functions and its subtrees.
+            if not isinstance(child, ast.FunctionDef):
+                child_returns = get_returns_not_on_nested_functions(child)
+                returns.extend(child_returns)
+        return returns
+
+    for func in funcs:
+        what = f"{module_name}.{func.__name__}"
+        tree = ast.parse(inspect.getsource(func.__wrapped__)).body
+        if func.__closure__[-1].cell_contents is False:
+            continue  # last closure node is require_output=False
+        assert tree, f"Failed to parse source code for {what}"
+        returns = get_returns_not_on_nested_functions(tree[0])
+        return_values = [r.value for r in returns]
+        # Replace Constant nodes valued None for None.
+        for i, v in enumerate(return_values):
+            if isinstance(v, ast.Constant) and v.value is None:
+                return_values[i] = None
+        assert len(return_values), f"Function does not return anything: {what}"
+        for r in return_values:
+            is_name_call = isinstance(r, ast.Call) and isinstance(r.func, ast.Name)
+            assert (
+                is_name_call and r.func.id == "_prep_out_files"
+            ), f"Function does not return _prep_out_files(): {what}"