From a49170e4b8907f90472680b723a8511bbfeeb206 Mon Sep 17 00:00:00 2001
From: Corentin Dancette
Date: Wed, 1 Jul 2020 00:51:32 +0200
Subject: [PATCH] Add notebook tools

---
 bootstrap/run.py   |  7 ++++---
 bootstrap/tools.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 71 insertions(+), 3 deletions(-)
 create mode 100644 bootstrap/tools.py

diff --git a/bootstrap/run.py b/bootstrap/run.py
index c181de7..94eec50 100755
--- a/bootstrap/run.py
+++ b/bootstrap/run.py
@@ -53,7 +53,7 @@ def init_logs_options_files(exp_dir, resume=None):
         Logger(exp_dir, name=logs_name)
 
 
-def run(path_opts=None):
+def run(path_opts=None, train_engine=True, eval_engine=True):
     # first call to Options() load the options yaml file from --path_opts command line argument if path_opts=None
     Options(path_opts)
 
@@ -106,19 +106,20 @@
 
         # if no training split, evaluate the model on the evaluation split
         # (example: $ python main.py --dataset.train_split --dataset.eval_split test)
-        if not Options()['dataset']['train_split']:
+        if eval_engine and not Options()['dataset']['train_split']:
             engine.eval()
 
         # optimize the model on the training split for several epochs
         # (example: $ python main.py --dataset.train_split train)
         # if evaluation split, evaluate the model after each epochs
         # (example: $ python main.py --dataset.train_split train --dataset.eval_split val)
-        if Options()['dataset']['train_split']:
+        if train_engine and Options()['dataset']['train_split']:
             engine.train()
 
     finally:
         # write profiling results, if enabled
         process_profiler(profiler)
+    return engine
 
 
 def activate_debugger():
diff --git a/bootstrap/tools.py b/bootstrap/tools.py
new file mode 100644
index 0000000..7887c7a
--- /dev/null
+++ b/bootstrap/tools.py
@@ -0,0 +1,67 @@
+import os
+import sys
+import torch
+from bootstrap.lib.logger import Logger
+from bootstrap.lib.options import Options
+from bootstrap.run import run
+
+
+def reset_instance():
+    Options._Options__instance = None
+    Options.__instance = None
+    Logger._Logger__instance = None  # NOTE(review): was "_Loger_instance", a typo matching no mangled name; "_Logger__instance" mirrors the Options reset above -- confirm against bootstrap.lib.logger
+    Logger.perf_memory = {}
+    sys.argv = [sys.argv[0]]  # reset command line args
+
+
+def get_engine(
+    path_experiment, weights="best_eval_epoch.accuracy_top1", logs_name="tools",
+):
+    reset_instance()
+    path_yaml = os.path.join(path_experiment, "options.yaml")
+    opt = Options(path_yaml)
+    if weights is not None:
+        opt["exp.resume"] = weights
+    opt["exp.dir"] = path_experiment
+    opt["misc.logs_name"] = logs_name
+    engine = run(train_engine=False, eval_engine=False)
+    return engine
+
+
+def item_to_batch(engine, split, item, prepare_batch=True):
+    batch = engine.dataset[split].collate_fn([item])
+    if prepare_batch:
+        batch = engine.model.prepare_batch(batch)
+    return batch
+
+
+def apply_item(engine, item, split="eval"):
+    # item = engine.dataset[split][idx]
+    engine.model.eval()
+    batch = item_to_batch(engine, split, item)
+    with torch.no_grad():
+        out = engine.model.network(batch)
+    return out
+
+
+def load_model_state(engine, path):
+    """
+    engine: bootstrap Engine
+    path: path to model weights
+    """
+    model_state = torch.load(path)
+    engine.model.load_state_dict(model_state)
+
+
+def load_epoch(
+    engine, epoch, exp_dir,
+):
+    path = os.path.join(exp_dir, f"ckpt_epoch_{epoch}_model.pth.tar")
+    print(path)
+    load_model_state(engine, path)
+
+
+def load_last(engine, exp_dir):
+    path = os.path.join(exp_dir, "ckpt_last_model.pth.tar")
+    load_model_state(engine, path)
+