diff --git a/references/segmentation/train.py b/references/segmentation/train.py index e82e5bda651..5e5e5615e19 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -133,11 +133,6 @@ def main(args): model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) model_without_ddp = model.module - if args.test_only: - confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) - print(confmat) - return - params_to_optimize = [ {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]}, {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]}, @@ -155,10 +150,16 @@ def main(args): if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') - model_without_ddp.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - args.start_epoch = checkpoint['epoch'] + 1 + model_without_ddp.load_state_dict(checkpoint['model'], strict=not args.test_only) + if not args.test_only: + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + args.start_epoch = checkpoint['epoch'] + 1 + + if args.test_only: + confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) + print(confmat) + return start_time = time.time() for epoch in range(args.start_epoch, args.epochs):