From 869b232258a2b1ee952cb5f09d8affa80b7d92b0 Mon Sep 17 00:00:00 2001
From: yizt
Date: Sun, 7 Jun 2020 17:03:03 +0800
Subject: [PATCH] Replace SGD with Adadelta
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 train.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/train.py b/train.py
index bd4e0f9..0343213 100644
--- a/train.py
+++ b/train.py
@@ -107,8 +107,9 @@ def train(args):
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
-    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
-    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    optimizer = optim.Adadelta(model.parameters(), weight_decay=args.weight_decay)
+    # lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
 
     model_without_ddp = model
     if args.distributed:
@@ -122,7 +123,7 @@ def train(args):
                                              'crnn.{}.{:03d}.pth'.format(args.direction, args.init_epoch)),
                                 map_location='cpu')
         optimizer.load_state_dict(checkpoint['optimizer'])
-        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
         model_without_ddp.load_state_dict(checkpoint['model'])
     # log
     writer = SummaryWriter(log_dir=cfg.log_dir) if utils.is_main_process() else None
@@ -140,14 +141,14 @@ def train(args):
         utils.add_scalar_on_master(writer, 'scalar/train_loss', loss, epoch + 1)
         utils.add_weight_history_on_master(writer, model_without_ddp, epoch + 1)
 
         # update lr
-        lr_scheduler.step(epoch)
+        # lr_scheduler.step(epoch)
         # save model
         if args.output_dir:
             checkpoint = {
                 'model': model_without_ddp.state_dict(),
                 'optimizer': optimizer.state_dict(),
-                'lr_scheduler': lr_scheduler.state_dict(),
+                # 'lr_scheduler': lr_scheduler.state_dict(),
                 'epoch': epoch + 1,
                 'args': args}
             utils.save_on_master(
@@ -167,8 +168,8 @@ def train(args):
     parser.add_argument("--init-epoch", type=int, default=0, help="init epoch")
     parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
     parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
-    parser.add_argument('--wd', '--weight-decay', default=1e-4, help='weight decay (default: 1e-4)',
-                        dest='weight_decay')
+    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                        help='weight decay (default: 1e-4)', dest='weight_decay')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
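
For context, a minimal sketch of what the hunk at line 107 changes at the call site. The `nn.Linear` stand-in is illustrative, not the CRNN from train.py; the point is that torch's Adadelta keeps running averages of squared gradients and squared updates, so its `lr` argument (default 1.0) only scales an already-adaptive per-parameter step, which is why the StepLR schedule is commented out rather than re-attached:

    import torch
    import torch.optim as optim

    model = torch.nn.Linear(10, 2)  # illustrative stand-in for the CRNN model

    # Before the patch: SGD takes an explicit lr, decayed externally by StepLR.
    sgd = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(sgd, step_size=30, gamma=0.1)

    # After the patch: Adadelta adapts its effective step size per parameter,
    # so no external schedule is attached to it.
    adadelta = optim.Adadelta(model.parameters(), weight_decay=1e-4)
    print(adadelta.defaults['lr'], adadelta.defaults['rho'])  # 1.0 0.9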
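
One caveat with the resume path in the second hunk: `optimizer.load_state_dict(checkpoint['optimizer'])` still runs unconditionally, so resuming from a checkpoint written before this patch would load SGD momentum buffers into Adadelta, which would then likely fail when `step()` looks for its `square_avg` state. A hypothetical guard; the helper name and the `'rho'` probe are assumptions, not code from train.py:

    import torch

    def load_optimizer_state(optimizer, checkpoint):
        # 'rho' exists only in Adadelta param groups, so its presence
        # distinguishes a post-patch checkpoint from an SGD-era one whose
        # momentum buffers Adadelta cannot consume.
        saved_groups = checkpoint['optimizer']['param_groups']
        if all('rho' in group for group in saved_groups):
            optimizer.load_state_dict(checkpoint['optimizer'])
            return True
        return False  # stale SGD state: keep the fresh Adadelta state instead

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.Adadelta(model.parameters())
    checkpoint = {'optimizer': optimizer.state_dict()}  # shaped like train.py's
    assert load_optimizer_state(optimizer, checkpoint)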
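
The argparse hunk fixes more than formatting: without `type=float`, the float default only applies when the flag is omitted, and a value given on the command line reaches `optim.Adadelta` as a string, which PyTorch's weight-decay handling rejects. A standalone demonstration of the corrected argument, not an excerpt from train.py:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        help='weight decay (default: 1e-4)', dest='weight_decay')

    print(parser.parse_args([]).weight_decay)                # 0.0001 (default)
    print(parser.parse_args(['--wd', '5e-4']).weight_decay)  # 0.0005, a float
    # Without type=float the second value would arrive as the string '5e-4'.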