main.py
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from data_loader import Autoencoder_dataset
from model import Autoencoder
import os
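
# NOTE: Autoencoder_dataset and Autoencoder are the custom dataset and model
# classes imported above from this repository's data_loader.py and model.py;
# `root` below is a placeholder and should point to your local image folder.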
root = 'path to your image dataset'

def img_denorm(img, mean, std):
    """Undo the dataset normalization so a reconstruction can be saved as an image."""
    # for ImageNet the mean and std are:
    # mean = [0.485, 0.456, 0.406]
    # std = [0.229, 0.224, 0.225]
    mean = torch.as_tensor(mean)
    std = torch.as_tensor(std)
    denormalize = transforms.Normalize((-1 * mean / std), (1.0 / std))
    res = denormalize(img)
    # The image needs to be clamped since denormalization will map some
    # values below 0 and above 1.
    res = torch.clamp(res, 0, 1)
    res = res.view(res.size(0), 3, 576, 288)
    return res

def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 2 every 30 epochs."""
    lr = learning_rate * (0.5 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Save the training model."""
    torch.save(state, filename)

# setting hyperparameters
batch_size = 128
num_epochs = 150
learning_rate = 1e-4

if not os.path.exists('./decoded_images'):
    os.mkdir('./decoded_images')

def main():
    trainset = Autoencoder_dataset(True, root, transforms=transforms.Compose([
        transforms.Resize((576, 288)),  # torchvision's Resize to (height, width)
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]))
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

    valset = Autoencoder_dataset(False, root, transforms=transforms.Compose([
        transforms.Resize((576, 288)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]))
    val_loader = DataLoader(valset, batch_size=batch_size)

    model = Autoencoder().cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 weight_decay=1e-5)

    for epoch in range(num_epochs):
        adjust_learning_rate(optimizer, epoch)

        # training pass
        model.train()
        for data in train_loader:
            img, _ = data
            img = img.cuda()
            output = model(img)
            loss = criterion(output, img)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('epoch [{}/{}], train loss: {:.4f}'
              .format(epoch + 1, num_epochs, loss.item()))

        # validation pass
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data in val_loader:
                img_val, _ = data
                img_val = img_val.cuda()
                output_val = model(img_val)
                val_loss += criterion(output_val, img_val).item()
        val_loss /= len(val_loader)
        print('epoch [{}/{}], val loss: {:.4f}'
              .format(epoch + 1, num_epochs, val_loss))

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
        }, filename=os.path.join('./', 'checkpoint_{}.tar'.format(epoch)))

        # save a batch of reconstructions every 25 epochs
        if epoch % 25 == 0:
            pic = img_denorm(output.cpu().data,
                             mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
            save_image(pic, './decoded_images/image_{}.png'.format(epoch))

if __name__ == '__main__':
    main()