Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix bug with cache invalidation #756

Merged
merged 2 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/v1.5.md.inc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@

- Fixed doc build errors and dependency specifications (#755 by @larsoner)

[//]: # (### :bug: Bug fixes)
### :bug: Bug fixes

[//]: # (- Whatever (#000 by @whoever))
- Fixed bug where cache would not invalidate properly based on output file changes and steps could be incorrectly skipped. All steps will automatically rerun to accommodate the new, safer caching scheme (#756 by @larsoner)
116 changes: 67 additions & 49 deletions mne_bids_pipeline/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sys
import traceback
import time
from typing import Callable, Optional, Dict, List
from typing import Callable, Optional, Dict, List, Literal, Union
from types import SimpleNamespace

from filelock import FileLock
Expand Down Expand Up @@ -163,20 +163,7 @@ def wrapper(*args, **kwargs):
# If this is ever true, we'll need to improve the logic below
assert not (unknown_inputs and force_run)

def hash_(k, v):
if isinstance(v, BIDSPath):
v = v.fpath
assert isinstance(
v, pathlib.Path
), f'Bad type {type(v)}: in_files["{k}"] = {v}'
assert v.exists(), f'missing in_files["{k}"] = {v}'
if self.memory_file_method == "mtime":
this_hash = v.lstat().st_mtime
else:
assert self.memory_file_method == "hash" # guaranteed
this_hash = hash_file_path(v)
return (str(v), this_hash)

hash_ = functools.partial(_path_to_str_hash, method=self.memory_file_method)
hashes = []
for k, v in in_files.items():
hashes.append(hash_(k, v))
Expand Down Expand Up @@ -211,9 +198,12 @@ def hash_(k, v):
memorized_func = self.memory.cache(func, ignore=self.ignore)
msg = emoji = None
short_circuit = False
subject = kwargs.get("subject", None)
session = kwargs.get("session", None)
run = kwargs.get("run", None)
# Used for logging automatically
subject = kwargs.get("subject", None) # noqa
session = kwargs.get("session", None) # noqa
run = kwargs.get("run", None) # noqa
task = kwargs.get("task", None) # noqa
bad_out_files = False
try:
done = memorized_func.check_call_in_cache(*args, **kwargs)
except Exception:
Expand All @@ -229,9 +219,25 @@ def hash_(k, v):
msg = "Computation forced despite existing cached result …"
emoji = "🔂"
else:
msg = "Computation unnecessary (cached) …"
emoji = "cache"
# When out_files is not None, we should check if the output files
# Check our output file hashes
out_files_hashes = memorized_func(*args, **kwargs)
for key, (fname, this_hash) in out_files_hashes.items():
fname = pathlib.Path(fname)
if not fname.exists():
msg = "Output file missing, will recompute …"
emoji = "🧩"
bad_out_files = True
break
got_hash = hash_(key, fname, kind="out")[1]
if this_hash != got_hash:
msg = "Output file hash mismatch, will recompute …"
emoji = "🚫"
bad_out_files = True
break
else:
msg = "Computation unnecessary (cached) …"
emoji = "cache"
# When out_files_expected is not None, we should check if the output files
# exist and stop if they do (e.g., in bem surface or coreg surface
# creation)
elif out_files is not None:
Expand All @@ -246,41 +252,19 @@ def hash_(k, v):
msg = "Computation unnecessary (output files exist) …"
emoji = "🔍"
short_circuit = True
del out_files

if msg is not None:
step = _short_step_path(pathlib.Path(inspect.getfile(func)))
logger.info(
**gen_log_kwargs(
message=msg,
subject=subject,
session=session,
run=run,
emoji=emoji,
step=step,
)
)
logger.info(**gen_log_kwargs(message=msg, emoji=emoji, step=step))
if short_circuit:
return

# https://joblib.readthedocs.io/en/latest/memory.html#joblib.memory.MemorizedFunc.call # noqa: E501
if force_run or unknown_inputs:
out_files, _ = memorized_func.call(*args, **kwargs)
if force_run or unknown_inputs or bad_out_files:
memorized_func.call(*args, **kwargs)
else:
out_files = memorized_func(*args, **kwargs)
assert isinstance(out_files, dict), type(out_files)
out_files_missing_msg = "\n".join(
f"- {key}={fname}"
for key, fname in out_files.items()
if not pathlib.Path(fname).exists()
)
if out_files_missing_msg:
raise ValueError(
"Missing at least one output file: \n"
+ out_files_missing_msg
+ "\n"
+ "This should not happen unless some files "
"have been manually moved or deleted. You "
"need to flush your cache to fix this."
)
memorized_func(*args, **kwargs)

return wrapper

Expand Down Expand Up @@ -381,3 +365,37 @@ def _get_step_path(

def _short_step_path(step_path: pathlib.Path) -> str:
return f"{step_path.parent.name}/{step_path.stem}"


def _prep_out_files(
    *,
    exec_params: SimpleNamespace,
    out_files: Dict[str, BIDSPath],
):
    """Convert each output path to a ``(str(path), hash)`` pair, in place.

    The resulting dict is what cached step functions return so the caching
    layer can later detect stale or missing outputs.
    """
    # Hash method ("mtime" or "hash") comes from the execution config.
    method = exec_params.memory_file_method
    for key in list(out_files):
        out_files[key] = _path_to_str_hash(
            key,
            pathlib.Path(out_files[key]),
            method=method,
            kind="out",
        )
    return out_files


def _path_to_str_hash(
    k: str,
    v: Union[BIDSPath, pathlib.Path],
    *,
    method: Literal["mtime", "hash"],
    kind: str = "in",
):
    """Return ``(str(path), hash)`` for an input/output file.

    ``hash`` is the file's mtime when ``method == "mtime"`` and a content
    hash when ``method == "hash"``. ``kind`` only affects error messages.
    """
    # BIDSPath instances carry the concrete filesystem path on .fpath.
    if isinstance(v, BIDSPath):
        v = v.fpath
    assert isinstance(v, pathlib.Path), f'Bad type {type(v)}: {kind}_files["{k}"] = {v}'
    assert v.exists(), f'missing {kind}_files["{k}"] = {v}'
    assert method in ("mtime", "hash")  # guaranteed by config validation
    if method == "hash":
        this_hash = hash_file_path(v)
    else:
        this_hash = v.lstat().st_mtime
    return (str(v), this_hash)
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from ..._logging import logger, gen_log_kwargs
from ..._parallel import parallel_func, get_parallel_backend
from ..._run import failsafe_run
from ..._run import failsafe_run, _prep_out_files

fs_bids_app = Path(__file__).parent / "contrib" / "run.py"

Expand Down Expand Up @@ -62,7 +62,7 @@ def make_coreg_surfaces(
overwrite=True,
)
out_files = get_output_fnames_coreg_surfaces(cfg=cfg, subject=subject)
return out_files
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def get_config(*, config, subject) -> SimpleNamespace:
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/init/_02_find_empty_room.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from ..._io import _empty_room_match_path, _write_json
from ..._logging import gen_log_kwargs, logger
from ..._run import _update_for_splits, failsafe_run, save_logs
from ..._run import _update_for_splits, failsafe_run, save_logs, _prep_out_files


def get_input_fnames_find_empty_room(
Expand Down Expand Up @@ -96,7 +96,7 @@ def find_empty_room(
out_files = dict()
out_files["empty_room_match"] = _empty_room_match_path(raw_path, cfg)
_write_json(out_files["empty_room_match"], dict(fname=fname))
return out_files
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def get_config(
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/preprocessing/_01_data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ..._logging import gen_log_kwargs, logger
from ..._parallel import parallel_func, get_parallel_backend
from ..._report import _open_report, _add_raw
from ..._run import failsafe_run, save_logs
from ..._run import failsafe_run, save_logs, _prep_out_files
from ..._viz import plot_auto_scores


Expand Down Expand Up @@ -140,7 +140,7 @@ def assess_data_quality(
plt.close(fig)

assert len(in_files) == 0, in_files.keys()
return out_files
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def _find_bads_maxwell(
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/preprocessing/_02_head_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..._logging import gen_log_kwargs, logger
from ..._parallel import parallel_func, get_parallel_backend
from ..._report import _open_report
from ..._run import failsafe_run, save_logs
from ..._run import failsafe_run, save_logs, _prep_out_files


def get_input_fnames_head_pos(
Expand Down Expand Up @@ -140,7 +140,7 @@ def run_head_pos(
plt.close(fig)
del bids_path_in
assert len(in_files) == 0, in_files.keys()
return out_files
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def get_config(
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/preprocessing/_03_maxfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from ..._logging import gen_log_kwargs, logger
from ..._parallel import parallel_func, get_parallel_backend
from ..._report import _open_report, _add_raw
from ..._run import failsafe_run, save_logs, _update_for_splits
from ..._run import failsafe_run, save_logs, _update_for_splits, _prep_out_files


def get_input_fnames_maxwell_filter(
Expand Down Expand Up @@ -355,7 +355,7 @@ def run_maxwell_filter(
)

assert len(in_files) == 0, in_files.keys()
return out_files
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def get_config(
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ..._logging import gen_log_kwargs, logger
from ..._parallel import parallel_func, get_parallel_backend
from ..._report import _open_report, _add_raw
from ..._run import failsafe_run, save_logs, _update_for_splits
from ..._run import failsafe_run, save_logs, _update_for_splits, _prep_out_files


def get_input_fnames_frequency_filter(
Expand Down Expand Up @@ -265,7 +265,7 @@ def filter_data(
)

assert len(in_files) == 0, in_files.keys()
return out_files
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def get_config(
Expand Down
3 changes: 2 additions & 1 deletion mne_bids_pipeline/steps/preprocessing/_05_make_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
save_logs,
_update_for_splits,
_sanitize_callable,
_prep_out_files,
)
from ..._parallel import parallel_func, get_parallel_backend

Expand Down Expand Up @@ -262,7 +263,7 @@ def run_epochs(
epochs.plot()
epochs.plot_image(combine="gfp", sigma=2.0, cmap="YlGnBu_r")
assert len(in_files) == 0, in_files.keys()
return out_files
return _prep_out_files(exec_params=exec_params, out_files=out_files)


# TODO: ideally we wouldn't need this anymore and could refactor the code above
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ..._parallel import parallel_func, get_parallel_backend
from ..._reject import _get_reject
from ..._report import _agg_backend
from ..._run import failsafe_run, _update_for_splits, save_logs
from ..._run import failsafe_run, _update_for_splits, save_logs, _prep_out_files


def filter_for_ica(
Expand Down Expand Up @@ -527,7 +527,7 @@ def run_ica(
)

assert len(in_files) == 0, in_files.keys()
return out_files
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def get_config(
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ..._parallel import parallel_func, get_parallel_backend
from ..._reject import _get_reject
from ..._report import _open_report
from ..._run import failsafe_run, _update_for_splits, save_logs
from ..._run import failsafe_run, _update_for_splits, save_logs, _prep_out_files


def get_input_fnames_run_ssp(
Expand Down Expand Up @@ -205,7 +205,7 @@ def run_ssp(
replace=True,
)
plt.close(fig)
return out_files
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def get_config(
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ..._parallel import parallel_func, get_parallel_backend
from ..._reject import _get_reject
from ..._report import _open_report, _agg_backend
from ..._run import failsafe_run, _update_for_splits, save_logs
from ..._run import failsafe_run, _update_for_splits, save_logs, _prep_out_files


def get_input_fnames_apply_ica(
Expand Down Expand Up @@ -172,7 +172,7 @@ def apply_ica(
replace=True,
)

return out_files
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def get_config(
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
_bids_kwargs,
)
from ..._logging import gen_log_kwargs, logger
from ..._run import failsafe_run, _update_for_splits, save_logs
from ..._run import failsafe_run, _update_for_splits, save_logs, _prep_out_files
from ..._parallel import parallel_func, get_parallel_backend


Expand Down Expand Up @@ -79,7 +79,7 @@ def apply_ssp(
)
_update_for_splits(out_files, "epochs")
assert len(in_files) == 0, in_files.keys()
return out_files
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def get_config(
Expand Down
4 changes: 2 additions & 2 deletions mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..._parallel import parallel_func, get_parallel_backend
from ..._reject import _get_reject
from ..._report import _open_report
from ..._run import failsafe_run, _update_for_splits, save_logs
from ..._run import failsafe_run, _update_for_splits, save_logs, _prep_out_files


def get_input_fnames_drop_ptp(
Expand Down Expand Up @@ -154,7 +154,7 @@ def drop_ptp(
drop_log_ignore=(),
replace=True,
)
return out_files
return _prep_out_files(exec_params=exec_params, out_files=out_files)


def get_config(
Expand Down
Loading