From 3b7c1bef8d7deb966910304ab842805c729a2b3c Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Sun, 24 Jan 2021 14:34:25 +0000
Subject: [PATCH] Load variables when --resume /path/to/checkpoint --test-only

---
 references/segmentation/train.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index e82e5bda651..5e5e5615e19 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -133,11 +133,6 @@ def main(args):
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
 
-    if args.test_only:
-        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
-        print(confmat)
-        return
-
     params_to_optimize = [
         {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
         {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
@@ -155,10 +150,16 @@
 
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
-        model_without_ddp.load_state_dict(checkpoint['model'])
-        optimizer.load_state_dict(checkpoint['optimizer'])
-        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
-        args.start_epoch = checkpoint['epoch'] + 1
+        model_without_ddp.load_state_dict(checkpoint['model'], strict=not args.test_only)
+        if not args.test_only:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+            args.start_epoch = checkpoint['epoch'] + 1
+
+    if args.test_only:
+        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
+        print(confmat)
+        return
 
     start_time = time.time()
     for epoch in range(args.start_epoch, args.epochs):
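
Note on the strict=not args.test_only change: with strict=True,
nn.Module.load_state_dict raises if the checkpoint keys do not match the
model's state dict exactly, so relaxing it only for --test-only lets an
evaluation run load checkpoints whose keys differ from the freshly built
model. Below is a minimal standalone sketch of that behaviour; the toy
model and the 'aux.weight' key are illustrative placeholders, not taken
from train.py:

    import torch
    import torch.nn as nn

    # Toy stand-in for the segmentation network (hypothetical, not train.py's model).
    model = nn.Sequential(nn.Linear(4, 4))

    # Checkpoint carrying one key the model lacks, e.g. a head saved at train time.
    state = {'0.weight': torch.zeros(4, 4), '0.bias': torch.zeros(4),
             'aux.weight': torch.ones(1)}

    # strict=True (the resume-for-training path) raises on the unexpected key.
    try:
        model.load_state_dict(state, strict=True)
    except RuntimeError as err:
        print('strict load failed:', err)

    # strict=False (the --test-only path) loads matching keys and reports the rest.
    result = model.load_state_dict(state, strict=False)
    print('unexpected keys:', result.unexpected_keys)  # ['aux.weight']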