-
Notifications
You must be signed in to change notification settings - Fork 83
/
logger.py
88 lines (74 loc) · 3.32 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import csv
import os.path
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import numpy as np
plt.switch_backend('agg')
class CsvLogger:
    """Record per-epoch training metrics to a CSV file and plot them.

    Every row passed to :meth:`write` is appended both to the CSV file at
    ``os.path.join(filepath, filename)`` and to the in-memory lists in
    ``self.data`` (one list per column); the ``plot_*`` methods read those
    lists.

    Args:
        filepath: Directory in which the CSV file, ``params.txt`` and the plot
            images are written.
        filename: Name of the CSV file inside ``filepath``.
        data: Optional iterable of previously logged rows (for example the
            rows of a ``csv.DictReader`` over an earlier results file). Each
            value is converted with ``int`` for the ``'epoch'`` column and
            ``float`` for every other column, then re-written through
            :meth:`write`.
    """

    def __init__(self, filepath='./', filename='results.csv', data=None):
        self.log_path = filepath
        self.log_name = filename
        self.csv_path = os.path.join(self.log_path, self.log_name)
        # Column order of the CSV file; rows passed to ``write`` must supply
        # exactly these keys.
        self.fieldsnames = ['epoch', 'val_error1', 'val_error5', 'val_loss',
                            'train_error1', 'train_error5', 'train_loss']
        # Mode 'w' truncates any existing file; history supplied via ``data``
        # is re-written below. newline='' lets the csv module own the line
        # terminator (otherwise rows end in '\r\r\n' on Windows).
        with open(self.csv_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=self.fieldsnames)
            writer.writeheader()
        self.data = {field: [] for field in self.fieldsnames}
        if data is not None:
            for row in data:
                parsed = {key: int(row[key]) if key == 'epoch' else float(row[key])
                          for key in row}
                self.write(parsed)

    def write(self, data):
        """Append one row of metrics to memory and to the CSV file.

        Args:
            data: Mapping with a value for every name in ``self.fieldsnames``.

        Raises:
            KeyError: if a column is missing from ``data``.
            ValueError: if ``data`` has keys that are not columns (raised by
                ``csv.DictWriter``).

        In both error cases neither ``self.data`` nor the file is modified.
        """
        # Look every column up before touching any state so that a missing
        # key cannot leave the in-memory lists partially appended.
        values = {k: data[k] for k in self.data}
        with open(self.csv_path, 'a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=self.fieldsnames)
            # DictWriter rejects unknown keys before writing anything.
            writer.writerow(data)
        for k, v in values.items():
            self.data[k].append(v)

    def save_params(self, args, params):
        """Overwrite ``params.txt`` with the joined ``args`` and ``params``.

        Args:
            args: Sequence of strings, written space-separated on one line.
            params: Any object; written on the next line via ``str.format``.
        """
        with open(os.path.join(self.log_path, 'params.txt'), 'w') as f:
            f.write('{}\n'.format(' '.join(args)))
            f.write('{}\n'.format(params))

    def write_text(self, text, print_t=True):
        """Append ``text`` as a line to ``params.txt`` and optionally print it."""
        with open(os.path.join(self.log_path, 'params.txt'), 'a') as f:
            f.write('{}\n'.format(text))
        if print_t:
            print(text)

    def plot_progress_errk(self, claimed_acc=None, title='MobileNetv2', k=1):
        """Plot train/validation top-``k`` error per epoch to ``top{k}.png``.

        Args:
            claimed_acc: Optional reference accuracy as a fraction; drawn as a
                dashed horizontal line at error ``1 - claimed_acc``.
            title: Model name used in the plot title.
            k: Selects the ``train_error{k}`` / ``val_error{k}`` columns; the
                logged columns exist for k=1 and k=5.
        """
        tr_err = self.data['train_error{}'.format(k)]
        val_err = self.data['val_error{}'.format(k)]
        n_epochs = len(tr_err)
        fig, ax = plt.subplots(figsize=(9, 8), dpi=300)
        try:
            ax.plot(tr_err, label='Training error')
            ax.plot(val_err, label='Validation error')
            if claimed_acc is not None:
                ax.plot((0, n_epochs), (1 - claimed_acc, 1 - claimed_acc), 'k--',
                        label='Claimed validation error ({:.2f}%)'.format(100. * (1 - claimed_acc)))
            # np.min raises on an empty list, so skip the line with no history.
            if val_err:
                best = np.min(val_err)
                ax.plot((0, n_epochs), (best, best), 'r--',
                        label='Best validation error ({:.2f}%)'.format(100. * best))
            ax.set_title('Top-{} error for {}'.format(k, title))
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Error')
            ax.legend()
            ax.set_xlim(0, n_epochs + 1)
            fig.savefig(os.path.join(self.log_path, 'top{}.png'.format(k)))
        finally:
            # Release the 300-dpi figure even if saving fails.
            plt.close(fig)

    def plot_progress_loss(self, title='MobileNetv2'):
        """Plot train/validation loss per epoch to ``loss.png``."""
        n_epochs = len(self.data['train_loss'])
        fig, ax = plt.subplots(figsize=(9, 8), dpi=300)
        try:
            ax.plot(self.data['train_loss'], label='Training')
            ax.plot(self.data['val_loss'], label='Validation')
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.legend()
            ax.set_xlim(0, n_epochs + 1)
            fig.savefig(os.path.join(self.log_path, 'loss.png'))
        finally:
            plt.close(fig)

    def plot_progress(self, claimed_acc1=None, claimed_acc5=None, title='MobileNetv2'):
        """Write the top-1, top-5 and loss plots into ``self.log_path``.

        Args:
            claimed_acc1: Optional reference top-1 accuracy (fraction).
            claimed_acc5: Optional reference top-5 accuracy (fraction).
            title: Model name used in the plot titles.
        """
        self.plot_progress_errk(claimed_acc1, title, 1)
        self.plot_progress_errk(claimed_acc5, title, 5)
        self.plot_progress_loss(title)
        # Kept from the original: this also closes any other open pyplot
        # figures, not only the ones created above.
        plt.close('all')