-
Notifications
You must be signed in to change notification settings - Fork 45
/
draw_graph_disc.py
61 lines (46 loc) · 1.52 KB
/
draw_graph_disc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import csv
import numpy as np
import argparse
from matplotlib import pyplot as plt
def smooth(arr, n):
    """Average ``arr`` over consecutive, non-overlapping windows of ``n`` items.

    Items at the tail that do not fill a complete window are dropped.

    Args:
        arr: Sliceable sequence of numbers (list or 1-D array-like).
        n: Window size; ``len(arr) % n`` must be computable (so ``n != 0``).

    Returns:
        1-D numpy array holding ``len(arr) // n`` window means.
    """
    # Length of the longest prefix that splits evenly into windows of n.
    usable = len(arr) - len(arr) % n
    windows = np.reshape(arr[:usable], (-1, n))
    return np.mean(windows, axis=1)
def drawall(name, x, metrics, n=100, begin=0):
    """Smooth one run's curve and add it to the current matplotlib figure.

    Args:
        name: Legend label for the plotted line.
        x: X-axis values (e.g. episode indices); expected to line up
            element-for-element with ``metrics[0]``.
        metrics: Non-empty list of y-series. Only ``metrics[0]`` is plotted.
            The list is NOT modified (earlier versions overwrote its entries
            with smoothed arrays).
        n: Smoothing window size passed to ``smooth``.
        begin: When positive, keep only the last ``begin`` points before
            smoothing; 0 keeps every point.

    Raises:
        IndexError: if ``metrics`` is empty.
    """
    # x[-0:] == x[0:], so begin=0 keeps the whole series.
    xs = smooth(x[-begin:], n)
    # Only the first metric is drawn, so only it is smoothed; results go into
    # locals instead of being written back into the caller's list.
    ys = smooth(metrics[0][-begin:], n)
    plt.plot(xs, ys, label=name, linewidth=3)
def read_best_record(filename, column):
    """Return ``float(row[column])`` for every non-blank row of a CSV file.

    Blank lines (which ``csv.reader`` yields as empty rows) are skipped
    instead of raising ``IndexError``.

    Args:
        filename: Path of the CSV stats file.
        column: Zero-based index of the field to read from each row.

    Returns:
        List of floats, one per non-blank row, in file order.

    Raises:
        OSError: if the file cannot be opened.
        IndexError: if a non-blank row has fewer than ``column + 1`` fields.
        ValueError: if the field does not parse as a float.
    """
    values = []
    # newline='' is what the csv module requires for correct line handling.
    with open(filename, 'r', newline='') as f:
        for row in csv.reader(f):
            if not row:
                continue
            values.append(float(row[column]))
    return values


def main():
    """Plot each agent's smoothed curve from ./save_stat/ into save_graph/result.png."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--n', type=int, default=1,
                        help='smoothing window (points averaged per plotted value)')
    parser.add_argument('--begin', type=int, default=0,
                        help='if positive, plot only the last BEGIN episodes (0 = all)')
    parser.add_argument('--names', nargs='+', default=['ra2c', 'rdqn'],
                        help='agents to plot; each reads ./save_stat/<name>_stat.csv '
                             '(e.g. also ppo, rddpg_per, td3_per)')
    args = parser.parse_args()

    out_dir = 'save_graph/'
    os.makedirs(out_dir, exist_ok=True)

    plt.figure(figsize=(15, 5))
    plt.xlabel('Episode')
    plt.ylabel('Best record')
    plt.title('Result')

    for name in args.names:
        filename = './save_stat/' + name + '_stat.csv'
        # PPO stat files are read from column 0; every other agent's from column 3.
        column = 0 if 'ppo' in name else 3
        best_y = read_best_record(filename, column)
        episodes = list(range(len(best_y)))
        drawall(name, episodes, [best_y], args.n, args.begin)

    plt.legend()
    plt.savefig(os.path.join(out_dir, 'result.png'))


if __name__ == '__main__':
    main()