diff --git a/train.py b/train.py index 8754970..4625cc2 100644 --- a/train.py +++ b/train.py @@ -207,7 +207,7 @@ def main(): for param_group in optimizer.param_groups: param_group['lr'] = lr - if epoch % 5 == 0: + if epoch % 2 == 0: save_path = os.path.join(checkpoints_dir, '{}_{:03d}_{:04f}.pth.tar'.format(model.__class__.__name__, epoch, float(train_loss))) torch.save({