
Commit

Replace SGD with Adadelta
yizt committed Jun 7, 2020
1 parent 02ff207 commit 869b232
Showing 1 changed file with 8 additions and 7 deletions.
train.py: 15 changes (8 additions, 7 deletions)
@@ -107,8 +107,9 @@ def train(args):
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

-    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
-    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    optimizer = optim.Adadelta(model.parameters(), weight_decay=args.weight_decay)
+    # lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

     model_without_ddp = model
     if args.distributed:
@@ -122,7 +123,7 @@ def train(args):
                                              'crnn.{}.{:03d}.pth'.format(args.direction, args.init_epoch)),
                                 map_location='cpu')
         optimizer.load_state_dict(checkpoint['optimizer'])
-        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
         model_without_ddp.load_state_dict(checkpoint['model'])
     # log
     writer = SummaryWriter(log_dir=cfg.log_dir) if utils.is_main_process() else None
@@ -140,14 +141,14 @@ def train(args):
         utils.add_scalar_on_master(writer, 'scalar/train_loss', loss, epoch + 1)
         utils.add_weight_history_on_master(writer, model_without_ddp, epoch + 1)
         # update lr
-        lr_scheduler.step(epoch)
+        # lr_scheduler.step(epoch)

         # save model
         if args.output_dir:
             checkpoint = {
                 'model': model_without_ddp.state_dict(),
                 'optimizer': optimizer.state_dict(),
-                'lr_scheduler': lr_scheduler.state_dict(),
+                # 'lr_scheduler': lr_scheduler.state_dict(),
                 'epoch': epoch + 1,
                 'args': args}
             utils.save_on_master(
@@ -167,8 +168,8 @@ def train(args):
     parser.add_argument("--init-epoch", type=int, default=0, help="init epoch")
     parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
-    parser.add_argument('--wd', '--weight-decay', default=1e-4, help='weight decay (default: 1e-4)',
-                        dest='weight_decay')
+    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                        help='weight decay (default: 1e-4)', dest='weight_decay')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')

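Taken together, the first three hunks swap a hand-tuned SGD + StepLR setup for Adadelta, which adapts per-parameter step sizes from running averages of squared gradients and squared updates, so the explicit lr, momentum, and step schedule become unnecessary. A minimal standalone sketch of the before/after (the linear model, dummy data, and hyperparameter values are stand-ins, not repository code):

    import torch
    from torch import nn, optim

    model = nn.Linear(10, 2)  # stand-in for the CRNN model in train.py

    # Before this commit: SGD needs an explicit lr plus a decay schedule.
    optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # After this commit: Adadelta (defaults lr=1.0, rho=0.9) adapts its own
    # step sizes, so only weight_decay is passed and no scheduler is created.
    optimizer = optim.Adadelta(model.parameters(), weight_decay=1e-4)

    x, y = torch.randn(4, 10), torch.randn(4, 2)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # no lr_scheduler.step() afterwards

Adadelta's relative insensitivity to the initial learning rate is what lets the commit drop StepLR entirely rather than retune it.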
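
One caveat when resuming from a checkpoint written before this commit: its 'optimizer' entry holds SGD state (momentum buffers, and param groups without Adadelta's rho/eps keys), so loading it into the new Adadelta instance may not fail immediately but can raise at the first optimizer.step(). A hedged sketch of a guard for crossing the switch; load_optimizer_state is a hypothetical helper, and the try/except policy is an assumption, not something the commit adds:

    # Hypothetical guard; the repository loads the state dict unconditionally.
    def load_optimizer_state(optimizer, checkpoint):
        try:
            optimizer.load_state_dict(checkpoint['optimizer'])
        except (KeyError, ValueError):
            # State was saved for a different optimizer class (e.g. SGD):
            # fall back to fresh Adadelta state instead of crashing mid-resume.
            print('incompatible optimizer state in checkpoint, starting fresh')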

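The last hunk is a small but worthwhile correctness fix: without type=float, argparse applies the float only to the default, and a value passed on the command line as --wd 5e-4 arrives as the string '5e-4', which would break the weight_decay arithmetic inside the optimizer. A quick standalone illustration mirroring the new argument definition:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        help='weight decay (default: 1e-4)', dest='weight_decay')

    print(parser.parse_args([]).weight_decay)                  # 0.0001 (float default)
    print(parser.parse_args(['--wd', '5e-4']).weight_decay)    # 0.0005, parsed as float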