MAINT: Ensure input changes cause output changes
larsoner committed Mar 26, 2024
1 parent ed5939a commit 9433293
Showing 10 changed files with 212 additions and 123 deletions.
4 changes: 2 additions & 2 deletions docs/source/v1.9.md.inc
@@ -14,9 +14,9 @@

[//]: # (- Whatever (#000 by @whoever))

[//]: # (### :bug: Bug fixes)
### :bug: Bug fixes

[//]: # (- Whatever (#000 by @whoever))
- Fix incorrect inclusion of autobad channels when disabling `find_bad_channels_meg`, for example after previously enabling it (#902 by @larsoner)

### :medical_symbol: Code health and infrastructure

9 changes: 4 additions & 5 deletions mne_bids_pipeline/_import_data.py
@@ -816,12 +816,11 @@ def _import_data_kwargs(*, config: SimpleNamespace, subject: str) -> dict:
)


def _get_run_type(
run: Optional[str],
task: Optional[str],
) -> str:
def _read_raw_msg(
bids_path_in: BIDSPath, run: Optional[str], task: Optional[str]
) -> tuple[str, str]:
if run is None and task in ("noise", "rest"):
run_type = dict(rest="resting-state", noise="empty-room")[task]
else:
run_type = "experimental"
return run_type
return f"Reading {run_type} recording: {bids_path_in.basename}", run_type
30 changes: 27 additions & 3 deletions mne_bids_pipeline/_run.py
@@ -23,8 +23,10 @@


def failsafe_run(
*,
get_input_fnames: Optional[Callable] = None,
get_output_fnames: Optional[Callable] = None,
require_output: bool = True,
) -> Callable:
def failsafe_run_decorator(func):
@functools.wraps(func) # Preserve "identity" of original function
@@ -36,6 +38,8 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs):
exec_params=exec_params,
get_input_fnames=get_input_fnames,
get_output_fnames=get_output_fnames,
require_output=require_output,
func_name=f"{__mne_bids_pipeline_step__}::{func.__name__}",
)
t0 = time.time()
log_info = pd.concat(
@@ -116,7 +120,15 @@ def hash_file_path(path: pathlib.Path) -> str:


class ConditionalStepMemory:
def __init__(self, *, exec_params, get_input_fnames, get_output_fnames):
def __init__(
self,
*,
exec_params: SimpleNamespace,
get_input_fnames: Optional[Callable],
get_output_fnames: Optional[Callable],
require_output: bool,
func_name: str,
):
memory_location = exec_params.memory_location
if memory_location is True:
use_location = exec_params.deriv_root / exec_params.memory_subdir
@@ -134,6 +146,8 @@ def __init__(self, *, exec_params, get_input_fnames, get_output_fnames):
self.get_input_fnames = get_input_fnames
self.get_output_fnames = get_output_fnames
self.memory_file_method = exec_params.memory_file_method
self.require_output = require_output
self.func_name = func_name

def cache(self, func):
def wrapper(*args, **kwargs):
@@ -262,9 +276,19 @@ def wrapper(*args, **kwargs):

# https://joblib.readthedocs.io/en/latest/memory.html#joblib.memory.MemorizedFunc.call # noqa: E501
if force_run or unknown_inputs or bad_out_files:
memorized_func.call(*args, **kwargs)
out_files = memorized_func.call(*args, **kwargs)
else:
out_files = memorized_func(*args, **kwargs)
if self.require_output:
assert isinstance(out_files, dict) and len(out_files), (
f"Internal error: step must return non-empty out_files dict, got "
f"{type(out_files).__name__} for:\n{self.func_name}"
)
else:
memorized_func(*args, **kwargs)
assert out_files is None, (
f"Internal error: step must return None, got {type(out_files)} "
f"for:\n{self.func_name}"
)

return wrapper

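The net effect of the `_run.py` changes: every cached step must now hand its output files back through the wrapper (real steps do this via `_prep_out_files`), so joblib can hash them and re-run downstream steps when they change. A hedged sketch of the contract, with a hypothetical step body and paths:

from pathlib import Path
from types import SimpleNamespace


def example_step(*, cfg: SimpleNamespace, exec_params: SimpleNamespace) -> dict:
    # Hypothetical step: write one output file and report it.
    out_files = dict()
    out_files["report"] = Path(cfg.deriv_root) / "report.html"
    out_files["report"].parent.mkdir(parents=True, exist_ok=True)
    out_files["report"].write_text("<html></html>")
    # Returning a non-empty dict satisfies require_output=True (the default);
    # a step decorated with require_output=False must return None instead.
    return out_files


cfg = SimpleNamespace(deriv_root="derivatives")
print(example_step(cfg=cfg, exec_params=SimpleNamespace()))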
20 changes: 10 additions & 10 deletions mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py
@@ -11,16 +11,14 @@

from ..._config_utils import _bids_kwargs, get_sessions, get_subjects
from ..._logging import gen_log_kwargs, logger
from ..._run import failsafe_run
from ..._run import _prep_out_files, failsafe_run


def init_dataset(cfg) -> None:
@failsafe_run()
def init_dataset(cfg: SimpleNamespace, exec_params: SimpleNamespace) -> dict:
"""Prepare the pipeline directory in /derivatives."""
fname_json = cfg.deriv_root / "dataset_description.json"
if fname_json.is_file():
msg = "Output directories already exist …"
logger.info(**gen_log_kwargs(message=msg, emoji="✅"))
return
out_files = dict()
out_files["json"] = cfg.deriv_root / "dataset_description.json"
logger.info(**gen_log_kwargs(message="Initializing output directories."))

cfg.deriv_root.mkdir(exist_ok=True, parents=True)
@@ -38,10 +36,12 @@ def init_dataset(cfg) -> None:
"URL": "n/a",
}

_write_json(fname_json, ds_json, overwrite=True)
_write_json(out_files["json"], ds_json, overwrite=True)
return _prep_out_files(
exec_params=exec_params, out_files=out_files, bids_only=False
)


@failsafe_run()
def init_subject_dirs(
*,
cfg: SimpleNamespace,
@@ -73,7 +73,7 @@ def get_config(

def main(*, config):
"""Initialize the output directories."""
init_dataset(cfg=get_config(config=config))
init_dataset(cfg=get_config(config=config), exec_params=config.exec_params)
# Don't bother with parallelization here as I/O operations are generally
# not well parallelized (and this should be very fast anyway)
for subject in get_subjects(config):
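In the same spirit, `init_dataset` no longer returns early when `dataset_description.json` already exists; it always rewrites the file and reports it as an output. A simplified stand-alone illustration (the trimmed JSON payload and path are assumptions):

import json
from pathlib import Path


def init_dataset_sketch(deriv_root: Path) -> dict:
    deriv_root.mkdir(exist_ok=True, parents=True)
    out_files = dict(json=deriv_root / "dataset_description.json")
    ds_json = {"Name": "my_pipeline_output", "DatasetType": "derivative"}
    # No early return when the file exists: rewriting it on every run is what
    # lets the caching layer detect when its content should change.
    out_files["json"].write_text(json.dumps(ds_json, indent=2))
    return out_files


print(init_dataset_sketch(Path("derivatives/my_pipeline")))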
186 changes: 95 additions & 91 deletions mne_bids_pipeline/steps/preprocessing/_01_data_quality.py
@@ -5,7 +5,6 @@

import mne
import pandas as pd
from mne_bids import BIDSPath

from ..._config_utils import (
_do_mf_autobad,
@@ -21,6 +20,7 @@
_get_mf_reference_run_path,
_get_run_rest_noise_path,
_import_data_kwargs,
_read_raw_msg,
import_er_data,
import_experimental_data,
)
@@ -79,25 +79,96 @@ def assess_data_quality(
out_files = dict()
key = f"raw_task-{task}_run-{run}"
bids_path_in = in_files.pop(key)
if key == "raw_task-noise_run-None":
bids_path_ref_in = in_files.pop("raw_ref_run")
else:
bids_path_ref_in = None
msg, _ = _read_raw_msg(bids_path_in=bids_path_in, run=run, task=task)
logger.info(**gen_log_kwargs(message=msg))
if run is None and task == "noise":
raw = import_er_data(
cfg=cfg,
bids_path_er_in=bids_path_in,
bids_path_er_bads_in=None,
bids_path_ref_in=bids_path_ref_in,
bids_path_ref_bads_in=None,
prepare_maxwell_filter=True,
)
else:
data_is_rest = run is None and task == "rest"
raw = import_experimental_data(
bids_path_in=bids_path_in,
bids_path_bads_in=None,
cfg=cfg,
data_is_rest=data_is_rest,
)
preexisting_bads = set(raw.info["bads"])

if _do_mf_autobad(cfg=cfg):
if key == "raw_task-noise_run-None":
bids_path_ref_in = in_files.pop("raw_ref_run")
else:
bids_path_ref_in = None
auto_scores = _find_bads_maxwell(
(
auto_noisy_chs,
auto_flat_chs,
auto_scores,
) = _find_bads_maxwell(
cfg=cfg,
exec_params=exec_params,
bids_path_in=bids_path_in,
bids_path_ref_in=bids_path_ref_in,
raw=raw,
subject=subject,
session=session,
run=run,
task=task,
out_files=out_files,
)
bads = sorted(set(raw.info["bads"] + auto_noisy_chs + auto_flat_chs))
msg = f"Found {len(bads)} channel{_pl(bads)} as bad."
raw.info["bads"] = bads
del bads
logger.info(**gen_log_kwargs(message=msg))
else:
auto_scores = None
del key
auto_scores = auto_noisy_chs = auto_flat_chs = None
del key, raw

# Always output the scores and bads TSV
out_files["auto_scores"] = bids_path_in.copy().update(
suffix="scores",
extension=".json",
root=cfg.deriv_root,
split=None,
check=False,
session=session,
subject=subject,
)
_write_json(out_files["auto_scores"], auto_scores)

# Write the bad channels to disk.
out_files["bads_tsv"] = _bads_path(
cfg=cfg,
bids_path_in=bids_path_in,
subject=subject,
session=session,
)
bads_for_tsv = []
reasons = []

if auto_flat_chs:
bads_for_tsv.extend(auto_flat_chs)
reasons.extend(["auto-flat"] * len(auto_flat_chs))
preexisting_bads -= set(auto_flat_chs)

if auto_noisy_chs is not None:
bads_for_tsv.extend(auto_noisy_chs)
reasons.extend(["auto-noisy"] * len(auto_noisy_chs))
preexisting_bads -= set(auto_noisy_chs)

preexisting_bads = sorted(preexisting_bads)
if preexisting_bads:
bads_for_tsv.extend(preexisting_bads)
reasons.extend(
["pre-existing (before MNE-BIDS-pipeline was run)"] * len(preexisting_bads)
)

tsv_data = pd.DataFrame(dict(name=bads_for_tsv, reason=reasons))
tsv_data = tsv_data.sort_values(by="name")
tsv_data.to_csv(out_files["bads_tsv"], sep="\t", index=False)

# Report
with _open_report(
@@ -145,45 +216,25 @@ def _find_bads_maxwell(
*,
cfg: SimpleNamespace,
exec_params: SimpleNamespace,
bids_path_in: BIDSPath,
bids_path_ref_in: Optional[BIDSPath],
raw: mne.io.Raw,
subject: str,
session: Optional[str],
run: Optional[str],
task: Optional[str],
out_files: dict,
):
if cfg.find_flat_channels_meg and not cfg.find_noisy_channels_meg:
msg = "Finding flat channels."
elif cfg.find_noisy_channels_meg and not cfg.find_flat_channels_meg:
msg = "Finding noisy channels using Maxwell filtering."
if cfg.find_flat_channels_meg:
if cfg.find_noisy_channels_meg:
msg = "Finding flat channels and noisy channels using Maxwell filtering."
else:
msg = "Finding flat channels."
else:
msg = "Finding flat channels and noisy channels using Maxwell filtering."
assert cfg.find_noisy_channels_meg
msg = "Finding noisy channels using Maxwell filtering."
logger.info(**gen_log_kwargs(message=msg))

if run is None and task == "noise":
raw = import_er_data(
cfg=cfg,
bids_path_er_in=bids_path_in,
bids_path_er_bads_in=None,
bids_path_ref_in=bids_path_ref_in,
bids_path_ref_bads_in=None,
prepare_maxwell_filter=True,
)
else:
data_is_rest = run is None and task == "rest"
raw = import_experimental_data(
bids_path_in=bids_path_in,
bids_path_bads_in=None,
cfg=cfg,
data_is_rest=data_is_rest,
)

# Filter the data manually before passing it to find_bad_channels_maxwell()
# This reduces memory usage, as we can control the number of jobs used
# during filtering.
preexisting_bads = raw.info["bads"].copy()
bads = preexisting_bads.copy()
raw_filt = raw.copy().filter(l_freq=None, h_freq=40, n_jobs=1)
(
auto_noisy_chs,
@@ -209,7 +260,8 @@
else:
msg = "Found no flat channels."
logger.info(**gen_log_kwargs(message=msg))
bads.extend(auto_flat_chs)
else:
auto_flat_chs = []

if cfg.find_noisy_channels_meg:
if auto_noisy_chs:
@@ -222,56 +274,8 @@
msg = "Found no noisy channels."

logger.info(**gen_log_kwargs(message=msg))
bads.extend(auto_noisy_chs)

bads = sorted(set(bads))
msg = f"Found {len(bads)} channel{_pl(bads)} as bad."
raw.info["bads"] = bads
del bads
logger.info(**gen_log_kwargs(message=msg))

if cfg.find_noisy_channels_meg:
out_files["auto_scores"] = bids_path_in.copy().update(
suffix="scores",
extension=".json",
root=cfg.deriv_root,
split=None,
check=False,
session=session,
subject=subject,
)
_write_json(out_files["auto_scores"], auto_scores)

# Write the bad channels to disk.
out_files["bads_tsv"] = _bads_path(
cfg=cfg,
bids_path_in=bids_path_in,
subject=subject,
session=session,
)
bads_for_tsv = []
reasons = []

if cfg.find_flat_channels_meg:
bads_for_tsv.extend(auto_flat_chs)
reasons.extend(["auto-flat"] * len(auto_flat_chs))
preexisting_bads = set(preexisting_bads) - set(auto_flat_chs)

if cfg.find_noisy_channels_meg:
bads_for_tsv.extend(auto_noisy_chs)
reasons.extend(["auto-noisy"] * len(auto_noisy_chs))
preexisting_bads = set(preexisting_bads) - set(auto_noisy_chs)

preexisting_bads = list(preexisting_bads)
if preexisting_bads:
bads_for_tsv.extend(preexisting_bads)
reasons.extend(
["pre-existing (before MNE-BIDS-pipeline was run)"] * len(preexisting_bads)
)

tsv_data = pd.DataFrame(dict(name=bads_for_tsv, reason=reasons))
tsv_data = tsv_data.sort_values(by="name")
tsv_data.to_csv(out_files["bads_tsv"], sep="\t", index=False)
else:
auto_noisy_chs = []

# Interaction
if exec_params.interactive and cfg.find_noisy_channels_meg:
@@ -280,7 +284,7 @@
plot_auto_scores(auto_scores, ch_types=cfg.ch_types)
plt.show()

return auto_scores
return auto_noisy_chs, auto_flat_chs, auto_scores


def get_config(
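Stepping back from the diff: raw import and bad-channel bookkeeping moved out of `_find_bads_maxwell` into `assess_data_quality`, which now always writes the scores JSON and the bads TSV. A condensed, runnable sketch of the TSV bookkeeping with made-up channel names (the Maxwell-filter detection itself is only summarized in a comment):

import pandas as pd

# Upstream, mne.preprocessing.find_bad_channels_maxwell(..., return_scores=True)
# returns (auto_noisy_chs, auto_flat_chs, auto_scores); fake results here:
auto_flat_chs = ["MEG 2443"]
auto_noisy_chs = ["MEG 0112"]
preexisting_bads = {"EEG 053"}

bads_for_tsv = list(auto_flat_chs) + list(auto_noisy_chs)
reasons = ["auto-flat"] * len(auto_flat_chs) + ["auto-noisy"] * len(auto_noisy_chs)
preexisting_bads -= set(auto_flat_chs) | set(auto_noisy_chs)
for ch in sorted(preexisting_bads):
    bads_for_tsv.append(ch)
    reasons.append("pre-existing (before MNE-BIDS-pipeline was run)")

tsv_data = pd.DataFrame(dict(name=bads_for_tsv, reason=reasons))
tsv_data = tsv_data.sort_values(by="name")
tsv_data.to_csv("bads.tsv", sep="\t", index=False)  # output path is hypothetical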