实验五RNN4.py
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from matplotlib import pyplot as plt

device = torch.device('cpu')  # run everything on the CPU
class RNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.RNN(
            input_size=28,    # each row of a 28x28 image is one time step
            hidden_size=128,
            num_layers=1,
            batch_first=True,
        )
        # Use nn.ModuleList so the 28 per-step linear layers are registered as
        # submodules; a plain Python list would hide them from .to(device) and
        # from the optimizer.
        self.hidden2one_list = nn.ModuleList([nn.Linear(128, 1) for _ in range(28)])
        self.Out2Class = nn.Linear(28, 10)

    def forward(self, input):
        output, hn = self.rnn(input, None)
        hidden2one_res = []
        for i in range(28):
            tmp_res = self.hidden2one_list[i](output[:, i, :])
            # print(tmp_res.shape)
            # Append the tensor itself (not .data) so gradients flow back
            # through these per-step layers during training.
            hidden2one_res.append(tmp_res)
        hidden2one_res = torch.cat(hidden2one_res, dim=1)  # alternatively, squeeze(1) each element and use torch.stack
        # print(hidden2one_res.shape)  # torch.Size([64, 28])
        res = self.Out2Class(hidden2one_res)
        return res
model = RNN()
model = model.to(device)
print(model)
model = model.train()
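# A quick shape sanity check (optional, kept commented out like the other
# debug snippets in this file): each 28x28 MNIST image is fed to the RNN as a
# sequence of 28 rows of 28 pixels, so the model expects input of shape
# (batch, 28, 28) and should return (batch, 10) class scores.
# dummy = torch.randn(4, 28, 28).to(device)
# print(model(dummy).shape)  # expected: torch.Size([4, 10])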
img_transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.5], std=[0.5])])
dataset_train = datasets.MNIST(root='./data', transform=img_transform, train=True, download=True)
dataset_test = datasets.MNIST(root='./data', transform=img_transform, train=False, download=True)
train_loader = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=dataset_test, batch_size=64, shuffle=False)
# images,label = next(iter(train_loader))
# print(images.shape)
# print(label.shape)
# images_example = torchvision.utils.make_grid(images)
# images_example = images_example.numpy().transpose(1,2,0)
# mean = [0.5,0.5,0.5]
# std = [0.5,0.5,0.5]
# images_example = images_example*std + mean
# plt.imshow(images_example)
# plt.show()
def Get_ACC():
    correct = 0
    total_num = len(dataset_test)
    with torch.no_grad():  # no gradients are needed for evaluation
        for item in test_loader:
            batch_imgs, batch_labels = item
            batch_imgs = batch_imgs.squeeze(1)  # (N, 1, 28, 28) -> (N, 28, 28): rows as time steps
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)
            out = model(batch_imgs)
            _, pred = torch.max(out.data, 1)
            correct += torch.sum(pred == batch_labels)
            # print(pred)
            # print(batch_labels)
    correct = correct.item()
    acc = correct / total_num
    print('correct={},Test ACC:{:.5}'.format(correct, acc))
optimizer = torch.optim.Adam(model.parameters())
loss_f = nn.CrossEntropyLoss()
Get_ACC()
for epoch in range(5):
    print('epoch:{}'.format(epoch))
    cnt = 0
    for item in train_loader:
        batch_imgs, batch_labels = item
        batch_imgs = batch_imgs.squeeze(1)  # (N, 1, 28, 28) -> (N, 28, 28)
        # print(batch_imgs.shape)
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)
        out = model(batch_imgs)
        # print(out.shape)
        loss = loss_f(out, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if cnt % 100 == 0:
            print_loss = loss.item()
            print('epoch:{},cnt:{},loss:{}'.format(epoch, cnt, print_loss))
        cnt += 1
    Get_ACC()
torch.save(model, 'model')  # saves the full model object; torch.save(model.state_dict(), ...) is the more portable alternative
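# A minimal sketch of reloading the saved model for inference, assuming the
# 'model' file written above exists and the RNN class definition is available
# (torch.load of a full model object needs the class to be importable):
# reloaded = torch.load('model')
# reloaded.eval()
# with torch.no_grad():
#     imgs, labels = next(iter(test_loader))
#     preds = reloaded(imgs.squeeze(1).to(device)).argmax(dim=1)
#     print('pred:', preds[:10].tolist(), 'true:', labels[:10].tolist())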