-
Notifications
You must be signed in to change notification settings - Fork 5
/
train_sr.py
116 lines (93 loc) · 4.83 KB
/
train_sr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import os
import copy
import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from sr_models.model import RDN, VGGLoss
from sr_models.datasets import TrainDataset, EvalDataset
from sr_models.utils import AverageMeter, calc_psnr, convert_rgb_to_y, denormalize
def parse_args(argv=None):
    """Parse and validate command-line options for RDN super-resolution training.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options.

    Exits via ``parser.error`` (SystemExit) when ``--lr-decay-epoch`` or
    ``--num-save`` is not positive, since both are used as divisors in the
    training loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, required=True)
    # NOTE: currently unused -- evaluation is disabled in this script.
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--weights-file', type=str)
    parser.add_argument('--num-features', type=int, default=64)
    parser.add_argument('--growth-rate', type=int, default=64)
    parser.add_argument('--num-blocks', type=int, default=16)
    parser.add_argument('--num-layers', type=int, default=8)
    parser.add_argument('--scale', type=int, default=4)
    parser.add_argument('--patch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--lr-decay', type=float, default=0.5)
    parser.add_argument('--lr-decay-epoch', type=int, default=200)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-epochs', type=int, default=800)
    parser.add_argument('--num-save', type=int, default=100)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--gpu-id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--vgg-lambda', type=float, default=0.2)
    parser.add_argument('--augment', action='store_true', help='whether applying jpeg and gaussian noising augmentation in training a sr model')
    parser.add_argument('--completion', action='store_true', help='completion')
    parser.add_argument('--colorization', action='store_true', help='colorization')
    args = parser.parse_args(argv)

    # Both values are used as divisors/moduli later; fail fast instead of
    # crashing with ZeroDivisionError after the model and data are built.
    if args.lr_decay_epoch <= 0:
        parser.error('--lr-decay-epoch must be a positive integer')
    if args.num_save <= 0:
        parser.error('--num-save must be a positive integer')
    return args


def lr_for_epoch(base_lr, decay, decay_epoch, epoch):
    """Return the step-decayed learning rate for a 0-based ``epoch``.

    The base rate is multiplied by ``decay`` once for every completed
    ``decay_epoch`` epochs: ``base_lr * decay ** (epoch // decay_epoch)``.
    """
    return base_lr * (decay ** (epoch // decay_epoch))


def load_weights(model, weights_file):
    """Copy tensors from the checkpoint at ``weights_file`` into ``model`` in place.

    The checkpoint is loaded onto CPU storage first; ``copy_`` then moves each
    tensor into the model's existing parameter/buffer on its own device.

    Raises:
        KeyError: if the checkpoint contains a key the model does not have
            (entries seen before that key have already been copied).
    """
    state_dict = model.state_dict()
    checkpoint = torch.load(weights_file, map_location=lambda storage, loc: storage)
    for name, param in checkpoint.items():
        if name not in state_dict:
            raise KeyError(name)
        state_dict[name].copy_(param)


def main():
    """Train an RDN super-resolution model, checkpointing every ``--num-save`` epochs.

    The loss is L1 plus ``--vgg-lambda`` times the VGG loss. Checkpoints are
    written to ``<outputs-dir>/x<scale>/epoch_<epoch>.pth``, and the final epoch
    is always saved.
    """
    args = parse_args()

    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))
    os.makedirs(args.outputs_dir, exist_ok=True)

    cudnn.benchmark = True
    device = torch.device('cuda:%d' % args.gpu_id if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(args.seed)

    model = RDN(scale_factor=args.scale,
                num_channels=3,
                num_features=args.num_features,
                growth_rate=args.growth_rate,
                num_blocks=args.num_blocks,
                num_layers=args.num_layers).to(device)

    if args.weights_file is not None:
        load_weights(model, args.weights_file)

    criterion = nn.L1Loss()
    criterion_vgg = VGGLoss(args.gpu_id)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    train_dataset = TrainDataset(args.train_file, patch_size=args.patch_size, scale=args.scale,
                                 aug=args.augment, colorization=args.colorization,
                                 completion=args.completion)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    for epoch in range(args.num_epochs):
        lr = lr_for_epoch(args.lr, args.lr_decay, args.lr_decay_epoch, epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        model.train()
        epoch_losses = AverageMeter()

        # drop_last is not set on the DataLoader, so the final partial batch is
        # still processed; the progress bar must count every sample.
        with tqdm(total=len(train_dataset), ncols=80) as t:
            t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

            for inputs, labels in train_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)
                loss = criterion(preds, labels) + criterion_vgg(preds, labels) * args.vgg_lambda
                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        # Also save after the last epoch so the final weights are never lost
        # when num_epochs is not a multiple of num_save.
        is_last_epoch = epoch + 1 == args.num_epochs
        if (epoch + 1) % args.num_save == 0 or is_last_epoch:
            torch.save(model.state_dict(),
                       os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))


if __name__ == '__main__':
    main()