-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathinflate_resnet.py
117 lines (98 loc) · 4.26 KB
/
inflate_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import copy
import json
from matplotlib import pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from src.i3res import I3ResNet
# Largest absolute difference allowed between the 2D and inflated 3D network
# outputs for the inflation to be considered faithful.
_MAX_ABS_ERROR = 0.0001
# ResNet depths this script knows how to inflate.
_RESNET_DEPTHS = (50, 101, 152)


def _build_resnet(resnet_nb):
    """Return a pretrained torchvision ResNet of depth ``resnet_nb``.

    Raises:
        ValueError: if ``resnet_nb`` is not one of 50, 101 or 152.
    """
    # Validate before touching torchvision so a bad value fails immediately.
    if resnet_nb not in _RESNET_DEPTHS:
        raise ValueError('resnet_nb should be in [50|101|152] but got {}'
                         .format(resnet_nb))
    # NOTE(review): `pretrained=True` is deprecated in newer torchvision
    # releases in favour of `weights=...`; kept to match the original call.
    builder = getattr(torchvision.models, 'resnet{}'.format(resnet_nb))
    return builder(pretrained=True)


def _print_top_k(scores, class_names, top_k):
    """Print the ``top_k`` highest entries of the 1-D ``scores`` with names."""
    # Clamp so topk never asks for more (or fewer than zero) entries.
    top_k = max(0, min(top_k, scores.shape[0]))
    top_val, top_idx = torch.topk(scores, top_k)
    print('Top {} classes and associated scores: '.format(top_k))
    for val, idx in zip(top_val, top_idx):
        print('[{}]: {}'.format(class_names[idx.item()], val))


def _show_image(image):
    """Display a channels-first image tensor, min-max rescaled to [0, 1]."""
    # transpose(1, 2, 0) assumes (C, H, W) input; imshow wants (H, W, C).
    img = image.numpy().transpose(1, 2, 0)
    img = img - img.min()
    value_range = img.max()
    # A constant image would otherwise divide by zero; it is shown as black.
    if value_range > 0:
        img = img / value_range
    plt.imshow(img)
    plt.show()


# To profile uncomment @profile and run `kernprof -lv inflate_resnet.py`
# @profile
def run_inflater(args):
    """Inflate a pretrained 2D ResNet and check it reproduces the original.

    Each image of ``data/dummy-dataset`` is fed to the 2D ResNet and, after
    being repeated ``args.frame_nb`` times along a new temporal axis, to the
    inflated ``I3ResNet``. The two outputs must agree to within
    ``_MAX_ABS_ERROR`` (largest absolute difference).

    Args:
        args: namespace with ``resnet_nb`` (50, 101 or 152), ``frame_nb``
            (number of frames of the inflated input), ``display_samples``
            (bool) and ``top_k`` (number of classes printed per sample).

    Raises:
        ValueError: if ``args.resnet_nb`` is not 50, 101 or 152.
        AssertionError: if the outputs of the two networks differ by
            ``_MAX_ABS_ERROR`` or more (or are NaN).
    """
    # Build (and validate) the model first so a bad depth fails before any
    # data is loaded.
    resnet = _build_resnet(args.resnet_nb)

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    dataset = datasets.ImageFolder('data/dummy-dataset',
                                   transforms.Compose([
                                       transforms.CenterCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       normalize,
                                   ]))
    with open('data/imagenet_class_index.json') as class_file:
        class_idx = json.load(class_file)
    imagenet_classes = [class_idx[str(k)][1] for k in range(len(class_idx))]

    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    # Inflate a copy so the original 2D network stays untouched for reference.
    i3resnet = I3ResNet(copy.deepcopy(resnet), args.frame_nb)
    # NOTE(review): only the inflated net is explicitly set to train mode;
    # `resnet` keeps the mode torchvision returned it in. Confirm both are
    # meant to use the same BatchNorm mode.
    i3resnet.train()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    i3resnet = i3resnet.to(device)
    resnet = resnet.to(device)

    for batch_idx, (input_2d, _target) in enumerate(loader):
        # Repeat the image frame_nb times along a new axis at dim 2:
        # (N, C, H, W) -> (N, C, frame_nb, H, W).
        input_3d = input_2d.unsqueeze(2).repeat(1, 1, args.frame_nb, 1, 1)
        # Gradients are never used, so skip building autograd graphs.
        with torch.no_grad():
            out2d = resnet(input_2d.to(device)).cpu()
            out3d = i3resnet(input_3d.to(device)).cpu()

        out_diff = out2d - out3d
        print('mean abs error {}'.format(out_diff.abs().mean()))
        print('mean abs val {}'.format(out2d.abs().mean()))
        # Computing errors between final predictions of inflated and uninflated
        # dense networks. Use the *absolute* difference: the signed max would
        # miss outputs where the inflated network overshoots.
        max_abs_err = out_diff.abs().max().item()
        print(
            'Batch {i} maximum error between 2d and inflated predictions: {err}'.
            format(i=batch_idx, err=max_abs_err))
        # Explicit raise (not `assert`) so the check survives `python -O`;
        # the negated comparison also fails on NaN.
        if not max_abs_err < _MAX_ABS_ERROR:
            raise AssertionError(
                'Batch {} maximum absolute error {} exceeds tolerance {}'
                .format(batch_idx, max_abs_err, _MAX_ABS_ERROR))

        if args.display_samples:
            for sample_idx in range(out3d.shape[0]):
                _print_top_k(out3d[sample_idx], imagenet_classes, args.top_k)
                _show_image(input_2d[sample_idx])
def build_arg_parser():
    """Return the command-line parser for the ResNet inflation check.

    The first positional argument of ``ArgumentParser`` is ``prog`` (the
    program name), so the explanatory text is passed as ``description``.
    """
    parser = argparse.ArgumentParser(
        description='Inflates ResNet and runs it on dummy dataset to compare '
        'outputs from original and inflated networks (should be the same)')
    parser.add_argument(
        '--resnet_nb',
        type=int,
        default=50,
        # Reject unsupported depths at parse time with a clear usage error.
        choices=[50, 101, 152],
        help='What version of ResNet to use, in [50|101|152]')
    parser.add_argument(
        '--display_samples',
        action='store_true',
        help='Whether to display samples and associated '
        'scores for 3d inflated resnet')
    parser.add_argument(
        '--top_k',
        type=int,
        default=5,
        help='When display_samples, number of top classes to display')
    parser.add_argument(
        '--frame_nb',
        type=int,
        default=16,
        help='Number of video_frames to use (should be a multiple of 8)')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    run_inflater(args)