-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_model.py
173 lines (150 loc) · 6.38 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# -*- coding: utf-8 -*-
# @Time : 2021/6/17 20:29
# @Author : dejahu
# @Email : [email protected]
# @File : test_model.py
# @Software: PyCharm
# @Brief : 模型测试代码,测试会生成热力图,热力图会保存在results目录下
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# Route sans-serif text to SimHei, presumably so Chinese class names render in the
# plots (requires SimHei to be installed on the machine).
plt.rcParams.update({
    'font.family': ['sans-serif'],
    'font.sans-serif': ['SimHei'],
})
def data_load(data_dir, test_data_dir, img_height, img_width, batch_size):
    """Load the training and test image datasets from class-per-subdirectory folders.

    Args:
        data_dir: root folder of the training images, one subdirectory per class.
        test_data_dir: root folder of the test images, with the same class subdirectories.
        img_height: height every image is resized to.
        img_width: width every image is resized to.
        batch_size: number of images per batch.

    Returns:
        (train_ds, val_ds, class_names): the training dataset, the test dataset
        (used as the evaluation set by callers) and the class names inferred from
        the training folder. Labels are one-hot encoded (label_mode='categorical').
    """
    # Training set.
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        label_mode='categorical',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    class_names = train_ds.class_names
    # Test set. Passing the training class names pins the test label indices to the
    # training order. Without it the test labels are inferred independently, so a test
    # folder with differently named classes would be silently mislabeled instead of
    # rejected with a ValueError.
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        test_data_dir,
        label_mode='categorical',
        class_names=class_names,
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    return train_ds, val_ds, class_names
def _confusion_counts(real_labels, pred_labels, num_classes):
    """Return a (num_classes, num_classes) float matrix counting (actual, predicted) pairs.

    The row index is the actual class and the column index is the predicted class.
    """
    counts = np.zeros((num_classes, num_classes))
    rows = np.asarray(real_labels, dtype=int)
    cols = np.asarray(pred_labels, dtype=int)
    # np.add.at accumulates correctly even when the same (row, col) pair repeats.
    np.add.at(counts, (rows, cols), 1)
    return counts


def _row_normalize(counts):
    """Divide each row by its total; rows with no samples become all zeros instead of NaN."""
    row_sums = counts.sum(axis=1, keepdims=True)
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums != 0)


def _evaluate_and_plot(model_path, model_name, save_name):
    """Evaluate a saved Keras model on the test set and save its confusion heatmap.

    Prints the test accuracy, the raw confusion counts and the row-normalised
    matrix, then writes the heatmap image to `save_name` via show_heatmaps.

    Args:
        model_path: path of the saved model file to load.
        model_name: name used in the printed accuracy line.
        save_name: output path of the heatmap image.
    """
    # todo: change these to the paths of your own dataset
    _, test_ds, class_names = data_load("../data/vegetable_fruit/image_data",
                                        "../data/vegetable_fruit/test_image_data", 224, 224, 16)
    model = tf.keras.models.load_model(model_path)
    _, accuracy = model.evaluate(test_ds)
    print(model_name + ' test accuracy :', accuracy)

    # Run inference batch by batch, keeping every prediction paired with its true label.
    real_labels = []
    pred_labels = []
    for batch_images, batch_labels in test_ds:
        # Calling the model directly is cheaper than predict() for one batch at a time.
        batch_probs = model(batch_images, training=False).numpy()
        real_labels.extend(np.argmax(batch_labels.numpy(), axis=1))
        pred_labels.extend(np.argmax(batch_probs, axis=1))

    heat_maps = _confusion_counts(real_labels, pred_labels, len(class_names))
    print(heat_maps)
    print()
    heat_maps_float = _row_normalize(heat_maps)
    print(heat_maps_float)
    show_heatmaps(title="heatmap", x_labels=class_names, y_labels=class_names,
                  harvest=heat_maps_float, save_name=save_name)


def test_mobilenet():
    """Evaluate the MobileNet model and save its heatmap to results/heatmap_mobilenet.png."""
    # todo: change to the file name of your own model
    _evaluate_and_plot("models/mobilenet_fv.h5", "Mobilenet", "results/heatmap_mobilenet.png")


def test_cnn():
    """Evaluate the CNN model and save its heatmap to results/heatmap_cnn.png."""
    # todo: change to the file name of your own model
    _evaluate_and_plot("models/cnn_fv.h5", "CNN", "results/heatmap_cnn.png")
def show_heatmaps(title, x_labels, y_labels, harvest, save_name):
    """Draw `harvest` as an annotated heatmap and save it as an image.

    Args:
        title: figure title.
        x_labels: labels for the ROWS of `harvest`. NOTE: despite the name, these are
            drawn on the y axis ("Actual label"); the name is kept for caller
            compatibility.
        y_labels: labels for the COLUMNS of `harvest`, drawn on the x axis
            ("Predict label").
        harvest: 2-D array indexed as harvest[row, col], with len(x_labels) rows and
            len(y_labels) columns.
        save_name: output image path; its parent directory is created if missing.
    """
    fig, ax = plt.subplots()
    try:
        # Colour map reference: https://blog.csdn.net/ztf312/article/details/102474190
        im = ax.imshow(harvest, cmap="OrRd")
        # Show every tick and label it with the matching entry.
        ax.set_xticks(np.arange(len(y_labels)))
        ax.set_yticks(np.arange(len(x_labels)))
        ax.set_xticklabels(y_labels)
        ax.set_yticklabels(x_labels)
        # The x tick labels can be long, so rotate them to keep them readable.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                 rotation_mode="anchor")
        # Write each cell's value, rounded to 2 decimals, at the cell centre.
        for i in range(len(x_labels)):
            for j in range(len(y_labels)):
                ax.text(j, i, round(harvest[i, j], 2),
                        ha="center", va="center", color="black")
        ax.set_xlabel("Predict label")
        ax.set_ylabel("Actual label")
        ax.set_title(title)
        # Add the colour bar before tight_layout so the layout accounts for it.
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        out_dir = os.path.dirname(save_name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_name, dpi=100)
    finally:
        # Release the figure; otherwise every call leaves one open in pyplot's registry.
        plt.close(fig)
if __name__ == '__main__':
    # Run the MobileNet evaluation first, then the CNN evaluation.
    for run_evaluation in (test_mobilenet, test_cnn):
        run_evaluation()