Adding RMSprop support and allowing warm restarts in train.py
datumbox committed Dec 20, 2020
1 parent 25f8b26 commit 385e077
Showing 1 changed file with 27 additions and 5 deletions: references/classification/train.py
@@ -173,8 +173,15 @@ def main(args):
 
     criterion = nn.CrossEntropyLoss()
 
-    optimizer = torch.optim.SGD(
-        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    opt_name = args.opt.lower()
+    if opt_name == 'sgd':
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    elif opt_name == 'rmsprop':
+        optimizer = torch.optim.RMSprop(
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    else:
+        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
 
     if args.apex:
         model, optimizer = amp.initialize(model, optimizer,
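
Worth noting about the dispatch above: torch.optim.RMSprop accepts the same lr, momentum and weight_decay keyword arguments as SGD, so the two branches can share an argument list, while RMSprop's remaining knobs (alpha, eps) stay at their defaults. A minimal sketch that exercises both branches on a toy model; the toy module, batch, and hyperparameter values are illustrative, not part of the commit:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)  # toy stand-in for the classification model
    for opt_name in ('sgd', 'rmsprop'):
        if opt_name == 'sgd':
            optimizer = torch.optim.SGD(
                model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        elif opt_name == 'rmsprop':
            optimizer = torch.optim.RMSprop(
                model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        # One dummy update to confirm both optimizers train the same model.
        loss = nn.CrossEntropyLoss()(model(torch.randn(4, 10)),
                                     torch.tensor([0, 1, 0, 1]))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()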
@@ -191,9 +198,11 @@ def main(args):
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
-        optimizer.load_state_dict(checkpoint['optimizer'])
-        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
-        args.start_epoch = checkpoint['epoch'] + 1
+        if not args.no_resume_opt:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+        if not args.no_resume_sched:
+            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        args.start_epoch = checkpoint['epoch'] + 1
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
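
The two new guards are what make a warm restart possible: the model weights always come back from the checkpoint, but the optimizer and LR scheduler can start from fresh state. A self-contained sketch of that round trip, using the same checkpoint keys the resume path above reads ('model', 'optimizer', 'lr_scheduler', 'epoch'); the toy module and StepLR settings are illustrative:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # Checkpoint dict with the same keys the resume path reads.
    checkpoint = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'epoch': 10}

    # Warm restart: both skip flags set, so only the weights are restored.
    no_resume_opt, no_resume_sched = True, True
    model.load_state_dict(checkpoint['model'])
    if not no_resume_opt:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if not no_resume_sched:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    start_epoch = checkpoint['epoch'] + 1  # training resumes at epoch 11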
@@ -238,6 +247,7 @@ def parse_args():
                         help='number of total epochs to run')
     parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                         help='number of data loading workers (default: 16)')
+    parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
     parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
     parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                         help='momentum')
@@ -275,6 +285,18 @@ def parse_args():
         help="Use pre-trained models from the modelzoo",
         action="store_true",
     )
+    parser.add_argument(
+        "--no-resume-opt",
+        dest="no_resume_opt",
+        help="When resuming from checkpoint it ignores the optimizer state",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--no-resume-sched",
+        dest="no_resume_sched",
+        help="When resuming from checkpoint it ignores the scheduler state",
+        action="store_true",
+    )
 
     # Mixed precision training parameters
     parser.add_argument('--apex', action='store_true',
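
Taken together, the three new flags let a run be warm-restarted with a different optimizer. For example, python train.py --opt rmsprop --resume checkpoint.pth --no-resume-opt --no-resume-sched would reload only the model weights and continue training with fresh RMSprop and scheduler state (the checkpoint path here is illustrative, and any distributed-launch arguments are omitted).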
