Skip to content

Commit

Permalink
Add notebook tools
Browse files Browse the repository at this point in the history
  • Loading branch information
cdancette committed Jul 17, 2020
1 parent 6482bf8 commit 1051a5c
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 3 deletions.
7 changes: 4 additions & 3 deletions bootstrap/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def init_logs_options_files(exp_dir, resume=None):
Logger(exp_dir, name=logs_name)


def run(path_opts=None):
def run(path_opts=None, train_engine=True, eval_engine=True):
# first call to Options() load the options yaml file from --path_opts command line argument if path_opts=None
Options(path_opts)

Expand Down Expand Up @@ -106,14 +106,14 @@ def run(path_opts=None):

# if no training split, evaluate the model on the evaluation split
# (example: $ python main.py --dataset.train_split --dataset.eval_split test)
if not Options()['dataset']['train_split']:
if eval_engine and not Options()['dataset']['train_split']:
engine.eval()

# optimize the model on the training split for several epochs
# (example: $ python main.py --dataset.train_split train)
# if evaluation split, evaluate the model after each epochs
# (example: $ python main.py --dataset.train_split train --dataset.eval_split val)
if Options()['dataset']['train_split']:
if train_engine and Options()['dataset']['train_split']:
engine.train()

if hasattr(engine.view, 'current_thread') and engine.view.current_thread.is_alive():
Expand All @@ -123,6 +123,7 @@ def run(path_opts=None):
finally:
# write profiling results, if enabled
process_profiler(profiler)
return engine


def activate_debugger():
Expand Down
67 changes: 67 additions & 0 deletions bootstrap/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import sys
import torch
from bootstrap.lib.logger import Logger
from bootstrap.lib.options import Options
from bootstrap.run import run


def reset_instance():
Options._Options__instance = None
Options.__instance = None
Logger._Loger_instance = None
Logger.perf_memory = {}
sys.argv = [sys.argv[0]] # reset command line args


def get_engine(
path_experiment, weights="best_eval_epoch.accuracy_top1", logs_name="tools",
):
reset_instance()
path_yaml = os.path.join(path_experiment, "options.yaml")
opt = Options(path_yaml)
if weights is not None:
opt["exp.resume"] = weights
opt["exp.dir"] = path_experiment
opt["misc.logs_name"] = logs_name
engine = run(train_engine=False, eval_engine=False)
return engine


def item_to_batch(engine, split, item, prepare_batch=True):
batch = engine.dataset[split].collate_fn([item])
if prepare_batch:
batch = engine.model.prepare_batch(batch)
return batch


def apply_item(engine, item, split="eval"):
# item = engine.dataset[split][idx]
engine.model.eval()
batch = item_to_batch(engine, split, item)
with torch.no_grad():
out = engine.model.network(batch)
return out


def load_model_state(engine, path):
"""
engine: bootstran Engine
path: path to model weights
"""
model_state = torch.load(path)
engine.model.load_state_dict(model_state)


def load_epoch(
engine, epoch, exp_dir,
):
path = os.path.join(exp_dir, f"ckpt_epoch_{epoch}_model.pth.tar")
print(path)
load_model_state(engine, path)


def load_last(engine, exp_dir):
path = os.path.join(exp_dir, "ckpt_last_model.pth.tar")
load_model_state(engine, path)

0 comments on commit 1051a5c

Please sign in to comment.