From 2174ba481bbb9bd54e7ba45fc920e2bf4afcbc2d Mon Sep 17 00:00:00 2001
From: Vincent Quenneville-Belair
Date: Mon, 1 Feb 2021 12:16:52 -0800
Subject: [PATCH] Load variables when --resume /path/to/checkpoint --test-only
 (#3285)

Summary:
Co-authored-by: Francisco Massa

Reviewed By: datumbox

Differential Revision: D26156373

fbshipit-source-id: 83f22c90477ca2da8db176d2455a70ca302d17d1
---
 references/segmentation/train.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index e82e5bda651..5e5e5615e19 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -133,11 +133,6 @@ def main(args):
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
 
-    if args.test_only:
-        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
-        print(confmat)
-        return
-
     params_to_optimize = [
         {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
         {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
@@ -155,10 +150,16 @@ def main(args):
 
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
-        model_without_ddp.load_state_dict(checkpoint['model'])
-        optimizer.load_state_dict(checkpoint['optimizer'])
-        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
-        args.start_epoch = checkpoint['epoch'] + 1
+        model_without_ddp.load_state_dict(checkpoint['model'], strict=not args.test_only)
+        if not args.test_only:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+            args.start_epoch = checkpoint['epoch'] + 1
+
+    if args.test_only:
+        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
+        print(confmat)
+        return
 
     start_time = time.time()
     for epoch in range(args.start_epoch, args.epochs):
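
Moving the test-only evaluation below the resume block means --test-only now
evaluates the resumed weights instead of returning before the checkpoint is
loaded, and strict=not args.test_only lets evaluation proceed even when the
checkpoint's state dict does not match the model exactly. A minimal standalone
sketch of that loading behavior (illustrative names only, not part of the
patch):

    import torch

    model = torch.nn.Linear(4, 2)

    # Simulate a checkpoint carrying only model weights, with one key missing
    # (e.g. an older or trimmed-down checkpoint).
    state = model.state_dict()
    state.pop("bias")
    checkpoint = {"model": state}

    test_only = True  # corresponds to passing --test-only
    # strict=False tolerates the missing "bias" key; strict=True would raise
    # RuntimeError, which is what a real resume should do.
    result = model.load_state_dict(checkpoint["model"], strict=not test_only)
    print(result.missing_keys)  # ['bias'] -- tolerated because strict=False

A full resume (test_only False) keeps strict=True, so a mismatched checkpoint
fails loudly before training continues with the optimizer and scheduler state.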