From c5616391db32b61c0c66ffa14fd4889953ae73c0 Mon Sep 17 00:00:00 2001 From: Lukas Hermann Date: Thu, 7 Dec 2023 17:56:22 +0100 Subject: [PATCH] fix small error in eval script at checkpoint loading --- .../calvin_agent/evaluation/evaluate_policy.py | 8 +++++++- calvin_models/calvin_agent/utils/utils.py | 15 +++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/calvin_models/calvin_agent/evaluation/evaluate_policy.py b/calvin_models/calvin_agent/evaluation/evaluate_policy.py index 19dae80..1f30e47 100644 --- a/calvin_models/calvin_agent/evaluation/evaluate_policy.py +++ b/calvin_models/calvin_agent/evaluation/evaluate_policy.py @@ -39,6 +39,12 @@ NUM_SEQUENCES = 1000 +def get_epoch(checkpoint): + if "=" not in checkpoint.stem: + return "0" + checkpoint.stem.split("=")[1] + + def make_env(dataset_path): val_folder = Path(dataset_path) / "validation" env = get_env(val_folder, show_gui=False) @@ -235,7 +241,7 @@ def main(): env = None for checkpoint in checkpoints: - epoch = checkpoint.stem.split("=")[1] + epoch = get_epoch(checkpoint) model, env, _ = get_default_model_and_env( args.train_folder, args.dataset_path, diff --git a/calvin_models/calvin_agent/utils/utils.py b/calvin_models/calvin_agent/utils/utils.py index 629e204..4f348b7 100644 --- a/calvin_models/calvin_agent/utils/utils.py +++ b/calvin_models/calvin_agent/utils/utils.py @@ -49,12 +49,15 @@ def get_checkpoints_for_epochs(experiment_folder: Path, epochs: Union[List, str] def get_all_checkpoints(experiment_folder: Path) -> List: - if experiment_folder.is_dir(): - checkpoint_folder = experiment_folder / "saved_models" - if checkpoint_folder.is_dir(): - checkpoints = sorted(Path(checkpoint_folder).iterdir(), key=lambda chk: chk.stat().st_mtime) - if len(checkpoints): - return [chk for chk in checkpoints if chk.suffix == ".ckpt"] + if not experiment_folder.is_dir(): + return [] + checkpoint_folder = experiment_folder / "saved_models" + if checkpoint_folder.is_dir(): + return get_all_checkpoints(checkpoint_folder) + + checkpoints = sorted(Path(experiment_folder).iterdir(), key=lambda chk: chk.stat().st_mtime) + if len(checkpoints): + return [chk for chk in checkpoints if chk.suffix == ".ckpt"] return []