Skip to content

Commit

Permalink
Load variables when --resume /path/to/checkpoint --test-only (#3285)
Browse files Browse the repository at this point in the history
Summary: Co-authored-by: Francisco Massa <[email protected]>

Reviewed By: datumbox

Differential Revision: D26156373

fbshipit-source-id: 83f22c90477ca2da8db176d2455a70ca302d17d1
  • Loading branch information
vincentqb authored and facebook-github-bot committed Feb 1, 2021
1 parent 4f57ba2 commit 2174ba4
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions references/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,6 @@ def main(args):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module

if args.test_only:
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
print(confmat)
return

params_to_optimize = [
{"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
{"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
Expand All @@ -155,10 +150,16 @@ def main(args):

if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
model_without_ddp.load_state_dict(checkpoint['model'], strict=not args.test_only)
if not args.test_only:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1

if args.test_only:
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
print(confmat)
return

start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
Expand Down

0 comments on commit 2174ba4

Please sign in to comment.