-
Notifications
You must be signed in to change notification settings - Fork 100
/
Copy pathcheckpoint.py
94 lines (80 loc) · 2.99 KB
/
checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import torch
import datetime
import numpy as np
import imageio
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
class Checkpoint:
    """Bookkeeping for one experiment directory.

    On construction the directory ``opt.save`` (or a timestamped default) is
    created together with its ``model`` and ``results`` sub-directories, the
    options are dumped to ``config.txt`` and ``log.txt`` is opened.  The
    instance then accumulates a tensor of logged values, writes text log
    lines, plots the logged values and stores result images.

    Attributes set by ``__init__``:
        opt: the option namespace passed in.
        ok: flag initialised to ``True``; not consulted inside this class.
        log: tensor of logged values, empty until ``add_log`` is called.
        dir: path of the experiment directory.
        log_file: open handle on ``<dir>/log.txt``; release with ``done()``.
    """

    def __init__(self, opt):
        """Create the directory layout, dump the options, open the log.

        Args:
            opt: option namespace.  ``opt.save`` is the target directory; the
                value ``'.'`` selects ``../experiment/EXP/<timestamp>`` and
                that path is written back to ``opt.save``.  Other methods
                also read ``opt.data_test``, ``opt.scale`` and
                ``opt.rgb_range``.
        """
        self.opt = opt
        self.ok = True
        self.log = torch.Tensor()

        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
        if opt.save == '.':
            opt.save = '../experiment/EXP/' + now
        self.dir = opt.save

        # exist_ok avoids the check-then-create race of an explicit
        # os.path.exists() test.
        os.makedirs(self.dir, exist_ok=True)
        os.makedirs(os.path.join(self.dir, 'model'), exist_ok=True)
        os.makedirs(os.path.join(self.dir, 'results'), exist_ok=True)

        log_path = os.path.join(self.dir, 'log.txt')
        # The mode is decided once, from the presence of log.txt, and is
        # used for config.txt as well: a reused directory is appended to.
        open_type = 'a' if os.path.exists(log_path) else 'w'

        # Write the config first so that a failure here does not leave the
        # long-lived log handle open.
        with open(os.path.join(self.dir, 'config.txt'), open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(opt):
                f.write('{}: {}\n'.format(arg, getattr(opt, arg)))
            f.write('\n')

        self.log_file = open(log_path, open_type)

    def save(self, trainer, epoch, is_best=False):
        """Persist model, loss, plots, the value log and optimizer state.

        Args:
            trainer: object exposing ``model.save``, ``loss.save``,
                ``loss.plot_loss``, ``optimizer.state_dict`` and an indexable
                ``dual_optimizers`` collection.
            epoch: epoch number forwarded to the plotting routines.
            is_best: forwarded to ``trainer.model.save``.
        """
        trainer.model.save(self.dir, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
        torch.save(
            trainer.optimizer.state_dict(),
            os.path.join(self.dir, 'optimizer.pt')
        )

        # The dual optimizers are stored as objects keyed by position, not
        # as state dicts; the on-disk format is kept as it was.  Positional
        # indexing (rather than iteration) is retained on purpose so any
        # indexable container keeps working.
        dual_optimizers = {
            i: trainer.dual_optimizers[i]
            for i in range(len(trainer.dual_optimizers))
        }
        torch.save(
            dual_optimizers,
            os.path.join(self.dir, 'dual_optimizers.pt')
        )

    def add_log(self, log):
        """Append ``log`` to the accumulated tensor via ``torch.cat``."""
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        """Print ``log`` and append it, newline-terminated, to ``log.txt``.

        Args:
            log: text line (without trailing newline).
            refresh: when true, the file is closed and reopened in append
                mode so everything written so far reaches the file.
        """
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(os.path.join(self.dir, 'log.txt'), 'a')

    def done(self):
        """Close the text log file."""
        self.log_file.close()

    def plot_psnr(self, epoch):
        """Plot the logged values over epochs 1..``epoch`` into a PDF.

        The figure is written to ``<dir>/test_<data_test>.pdf``.  Only the
        first entry of ``opt.scale`` is plotted, taken from column 0 of the
        log.  Nothing is plotted while the log is still empty, so that
        ``save`` can still write the remaining checkpoint files.

        Args:
            epoch: number of epochs; the x axis is ``1..epoch`` and is
                expected to match the number of logged rows.
        """
        if self.log.numel() == 0:
            # The initial log is a 1-D empty tensor; indexing a column of
            # it would raise.
            return

        axis = np.linspace(1, epoch, epoch)
        label = 'SR on {}'.format(self.opt.data_test)
        fig = plt.figure()
        plt.title(label)
        for idx_scale, scale in enumerate([self.opt.scale[0]]):
            plt.plot(
                axis,
                self.log[:, idx_scale].numpy(),
                label='Scale {}'.format(scale)
            )
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('PSNR')
        plt.grid(True)
        plt.savefig('{}/test_{}.pdf'.format(self.dir, self.opt.data_test))
        plt.close(fig)

    @staticmethod
    def _to_uint8_image(image, rgb_range):
        """Scale ``image`` to 0..255 and return it as a uint8 numpy array.

        The first axis is moved last (``permute(1, 2, 0)``), i.e. the input
        is assumed to be channel-first.  Values are rounded and clamped
        before the cast: a bare ``byte()`` truncates (254.9997 -> 254) and
        does not saturate out-of-range values.
        """
        normalized = image.detach().mul(255 / rgb_range).round().clamp(0, 255)
        return normalized.byte().permute(1, 2, 0).cpu().numpy()

    def save_results_nopostfix(self, filename, sr, scale):
        """Write ``sr[0]`` as ``<dir>/results/<data_test>/x<scale>/<filename>.png``.

        Args:
            filename: image name without extension.
            sr: tensor whose first element is the image to store; values
                are expected in ``0..opt.rgb_range``.
            scale: scale factor, used only to build the output directory.
        """
        apath = '{}/results/{}/x{}'.format(self.dir, self.opt.data_test, scale)
        os.makedirs(apath, exist_ok=True)
        filename = os.path.join(apath, filename)

        ndarr = self._to_uint8_image(sr[0], self.opt.rgb_range)
        imageio.imwrite('{}.png'.format(filename), ndarr)