-
Notifications
You must be signed in to change notification settings - Fork 2
/
draw_models.py
25 lines (22 loc) · 1014 Bytes
/
draw_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os

import numpy as np
from matplotlib import pyplot as plt
# Change this to the name of the model whose training curves should be plotted.
pic_name = 'MLP_CWRU'
dir_path = f'csv/{pic_name}'

# (CSV file stem, legend label, line colour) for each curve, in plotting order.
CURVES = (
    ('acc_train', 'train acc', 'red'),
    ('loss_train', 'train loss', 'green'),
    ('acc_val', 'val acc', 'blue'),
    ('loss_val', 'val loss', 'black'),
)


def load_curve(csv_path):
    """Load one logged metric curve from a comma-separated file.

    The first row is skipped as a header; column 1 is read as the step and
    column 2 as the metric value (column 0 is ignored).
    NOTE(review): this layout looks like a TensorBoard scalar export
    ("Wall time,Step,Value") -- confirm against whatever produced the CSVs.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A ``(steps, values)`` pair of 1-D float arrays of equal length.
    """
    # ndmin=2 keeps a single-row file as 1-D arrays instead of 0-d scalars.
    steps, values = np.loadtxt(
        csv_path,
        unpack=True,
        delimiter=',',
        skiprows=1,
        usecols=(1, 2),
        ndmin=2,
    )
    return steps, values


def plot_curves(name, csv_dir=None, out_dir='picture_model',
                ylim=(-0.1, 2.1), show=True):
    """Plot train/val accuracy and loss curves for one model and save the figure.

    Args:
        name: Model name; used as the plot title and the output file name.
        csv_dir: Directory holding acc_train.csv, acc_val.csv, loss_train.csv
            and loss_val.csv. Defaults to ``csv/<name>``.
        out_dir: Directory the figure is saved into (created if missing).
        ylim: Fixed (bottom, top) y-axis range shared by accuracy and loss.
        show: Whether to open an interactive window after saving.

    Returns:
        The matplotlib Figure that was drawn.
    """
    if csv_dir is None:
        csv_dir = f'csv/{name}'

    # Load everything first so a missing file fails before a figure is created.
    # Each series keeps its own steps: train and val logs need not share them.
    series = [
        (load_curve(f'{csv_dir}/{stem}.csv'), label, color)
        for stem, label, color in CURVES
    ]

    fig, ax = plt.subplots()
    ax.set_ylim(*ylim)
    ax.grid()
    for (steps, values), label, color in series:
        ax.plot(steps, values, label=label, color=color)
    ax.set_xlabel('epoch')
    ax.set_ylabel('acc-loss')
    ax.set_title(f'{name}')
    ax.legend()

    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_dir}/{name}')
    if show:
        plt.show()
    return fig


if __name__ == '__main__':
    plot_curves(pic_name, csv_dir=dir_path)