-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathplot.py
134 lines (113 loc) · 4.36 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import re
import argparse
import os
import matplotlib
def parse_log(log, pattern):
    """Scan a log file and yield the first capture group of every match.

    Args:
        log: path of the log file to read.
        pattern: regular expression with at least one capturing group; the
            text captured by the first group is what gets yielded.

    Yields:
        ``(line_index, captured_text)`` tuples, where ``line_index`` is the
        0-based position of the matching line within the whole file (all
        lines are counted, not only matching ones). Lines containing the
        literal text ``(0%)`` are skipped even when they match.
    """
    regex = re.compile(pattern)  # compile once instead of once per line
    with open(log, 'r') as log_file:
        for line_index, line in enumerate(log_file):
            # NOTE(review): '(0%)' presumably marks the first progress report
            # of an epoch; such lines are excluded from every plot. Confirm
            # against the training script's log format.
            if '(0%)' in line:
                continue
            match = regex.search(line)
            if match:
                # yield the text delimited by the first (...) of the pattern
                yield line_index, match.group(1)
def plot_train_loss(args):
    """Plot the training loss found in ``args.log_file`` as ``train_loss.png``.

    Reads from ``args``:
        log_file: log to parse for ``Train loss: <value>`` lines.
        img_dir: directory where ``train_loss.png`` is written.
        y_min / y_max: y-axis bounds; 0 leaves that bound at the default.
        no_show: when false, the figure is also shown interactively.

    A log without any matching line produces an empty plot instead of
    raising ``ValueError`` when the (empty) points are unpacked.
    """
    # Imported here so the function also works when this module is imported
    # rather than run as a script. If the script already selected a backend,
    # this simply returns the cached pyplot module.
    import matplotlib.pyplot as plt

    points = [(10 * n, float(value))
              for n, value in parse_log(args.log_file, r'Train loss: (.*)')]
    # NOTE(review): n is the 0-based line index in the log file (see
    # parse_log), not a count of train-loss entries; using 10 * n as the
    # iteration number presumably relies on the log's layout. Confirm.
    iterations = [it for it, _ in points]
    losses = [loss for _, loss in points]

    plt.clf()
    plt.plot(iterations, losses)
    plt.title('Train Loss')
    plt.xlabel('Iteration')
    if args.y_max != 0:
        plt.ylim(top=args.y_max)
    if args.y_min != 0:
        plt.ylim(bottom=args.y_min)
    plt.grid()
    plt.savefig(os.path.join(args.img_dir, 'train_loss.png'))
    if not args.no_show:
        plt.show()
def plot_test_loss(args):
    """Plot the test loss found in ``args.log_file`` as ``test_loss.png``.

    One point is drawn per matching ``Test loss = <value>`` line, plotted
    against its position in that sequence (the x axis is labelled 'Epoch').

    Reads from ``args``: ``log_file``, ``img_dir`` (output directory),
    ``y_min`` / ``y_max`` (0 leaves that bound at the default) and
    ``no_show`` (when false, the figure is also shown interactively).
    """
    # Local import: the function must not depend on the ``plt`` global that
    # is only bound when this file runs as a script.
    import matplotlib.pyplot as plt

    losses = [float(value)
              for _, value in parse_log(args.log_file, r'Test loss = (.*)')]

    plt.clf()
    plt.plot(losses)
    plt.title('Test Loss')
    if args.y_max != 0:
        plt.ylim(top=args.y_max)
    if args.y_min != 0:
        plt.ylim(bottom=args.y_min)
    plt.xlabel('Epoch')
    plt.grid()
    plt.savefig(os.path.join(args.img_dir, 'test_loss.png'))
    if not args.no_show:
        plt.show()
def plot_accuracy(args):
    """Plot total and per-category accuracy from ``args.log_file``.

    The total curve comes from ``... Accuracy = <x.y>%`` lines. One extra
    curve is drawn per question category listed in ``details``, from
    ``<category> -- acc: <x.y>%`` lines. The figure is saved as
    ``accuracy.png`` in ``args.img_dir`` and shown unless ``args.no_show``.
    """
    # Local import: the function must not depend on the ``plt`` global that
    # is only bound when this file runs as a script.
    import matplotlib.pyplot as plt

    accuracy = [float(value) for _, value in
                parse_log(args.log_file, r'.* Accuracy = (\d+\.\d+)%')]
    details = ['exist', 'number', 'material', 'size', 'shape', 'color']
    # Raw string: '\d' in a plain literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    accs = {k: [float(value) for _, value in
                parse_log(args.log_file,
                          r'{} -- acc: (\d+\.\d+)%'.format(re.escape(k)))]
            for k in details}

    plt.clf()
    for k, v in accs.items():
        plt.plot(v, label=k)
    plt.plot(accuracy, linewidth=2, label='total')
    plt.legend(loc='best')
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('%')
    plt.grid()
    plt.savefig(os.path.join(args.img_dir, 'accuracy.png'))
    if not args.no_show:
        plt.show()
def plot_invalids(args):
    """Plot the overall invalid-answer rate found in ``args.log_file``.

    Values come from ``... Invalids = <x.y>%`` lines, one point per matching
    line. The figure is saved as ``invalids.png`` in ``args.img_dir`` and
    shown unless ``args.no_show``.
    """
    # Local import: the function must not depend on the ``plt`` global that
    # is only bound when this file runs as a script.
    import matplotlib.pyplot as plt

    invalids = [float(value) for _, value in
                parse_log(args.log_file, r'.* Invalids = (\d+\.\d+)%')]

    plt.clf()
    plt.plot(invalids, linewidth=2, label='total')
    plt.legend(loc='best')
    plt.title('Invalid rate')
    plt.xlabel('Epoch')
    plt.ylabel('%')
    plt.grid()
    plt.savefig(os.path.join(args.img_dir, 'invalids.png'))
    if not args.no_show:
        plt.show()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Plot RN training logs')
    parser.add_argument('log_file', type=str, help='Log file to plot')
    parser.add_argument('-trl', '--train-loss', action='store_true', help='Show training loss plot')
    parser.add_argument('-tsl', '--test-loss', action='store_true', help='Show test loss plot')
    parser.add_argument('-a', '--accuracy', action='store_true', help='Show accuracy plot')
    parser.add_argument('-i', '--invalids', action='store_true', help='Show invalid rate plot')
    parser.add_argument('--no-show', action='store_true', help='Do not show figures, store only on file')
    parser.add_argument('--y-max', type=float, default=0,
                        help='upper bound for y axis of loss plots (0 to leave default)')
    parser.add_argument('--y-min', type=float, default=0,
                        help='lower bound for y axis of loss plots (0 to leave default)')
    # Output directory used to be hard-coded; the default keeps that location.
    parser.add_argument('--img-dir', type=str, default='imgs/',
                        help='directory where plot images are saved (default: %(default)s)')
    args = parser.parse_args()
    img_dir = args.img_dir

    # The non-interactive backend must be selected before pyplot is imported
    # for the first time. Binding ``plt`` at module level here also keeps any
    # function that relies on a global ``plt`` working.
    if args.no_show:
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    if not os.path.exists(img_dir):
        os.makedirs(img_dir)

    # Each selected plot is produced in turn; with no flag, nothing is plotted.
    if args.train_loss:
        plot_train_loss(args)
    if args.test_loss:
        plot_test_loss(args)
    if args.accuracy:
        plot_accuracy(args)
    if args.invalids:
        plot_invalids(args)