-
Notifications
You must be signed in to change notification settings - Fork 44
/
convert.py
197 lines (169 loc) · 8.14 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
from model import NetworkImageNet
from genotypes import PNASNet
from operations import *
from utils import preprocess_for_eval
import sys
import os
sys.path.append('../PNASNet.TF')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import tensorflow as tf
from pnasnet import build_pnasnet_large, pnasnet_large_arg_scope
slim = tf.contrib.slim
class ConvertPNASNet(object):
    """Convert the TensorFlow PNASNet-5 Large checkpoint to a PyTorch state dict.

    Instantiating the class runs the whole pipeline:

    1. ``read_tf_weight`` builds the TF graph, restores the checkpoint and runs
       the test image through it.
    2. ``write_pytorch_weight`` copies every TF variable into a freshly built
       ``NetworkImageNet``, checks that every TF variable was used exactly once
       and that TF and PyTorch logits agree, then saves the ``state_dict``.

    All checks raise ``AssertionError`` explicitly (instead of using ``assert``)
    so they still run when Python is started with ``-O``.
    """

    # Default file locations. Kept as class attributes so the methods also work
    # on instances that did not go through ``__init__``; the constructor can
    # override them per instance.
    image_path = 'data/cat.jpg'
    tf_ckpt_path = '../PNASNet.TF/data/model.ckpt'
    model_path = 'data/PNASNet-5_Large.pth'
    # Largest accepted absolute difference between TF and PyTorch logits.
    logits_tolerance = 1e-5

    def __init__(self, image_path=None, tf_ckpt_path=None, model_path=None):
        """Run the full TF -> PyTorch conversion.

        Args:
            image_path: image fed through both networks for the logit check;
                defaults to ``'data/cat.jpg'``.
            tf_ckpt_path: TensorFlow checkpoint to restore; defaults to
                ``'../PNASNet.TF/data/model.ckpt'``.
            model_path: output path of the PyTorch ``state_dict``; defaults to
                ``'data/PNASNet-5_Large.pth'``.
        """
        if image_path is not None:
            self.image_path = image_path
        if tf_ckpt_path is not None:
            self.tf_ckpt_path = tf_ckpt_path
        if model_path is not None:
            self.model_path = model_path
        # The TF placeholder is (None, None, 3) uint8, so force RGB: grayscale,
        # RGBA or CMYK files would otherwise fail to feed. ``with`` closes the
        # file handle once convert() has loaded the pixels.
        with Image.open(self.image_path) as img:
            self.image = img.convert('RGB')
        self.read_tf_weight()
        self.write_pytorch_weight()

    def read_tf_weight(self):
        """Build the TF PNASNet graph, restore the checkpoint and cache results.

        Sets:
            self.weight_dict: TF variable name (without its ``:0`` output
                suffix) -> numpy value.
            self.tf_logits, self.tf_end_points: TF outputs for ``self.image``.
            self.tf_image_proc: the preprocessed image fed to TF; it is reused
                as the PyTorch input so both networks see identical pixels.
        """
        image_ph = tf.placeholder(tf.uint8, (None, None, 3))
        image_proc = preprocess_for_eval(image_ph, 323, 323)
        with slim.arg_scope(pnasnet_large_arg_scope()):
            logits, end_points = build_pnasnet_large(
                tf.expand_dims(image_proc, 0), num_classes=1001, is_training=False)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # Context manager so the session (and its device memory) is released.
        with tf.Session(config=config) as sess:
            tf.train.Saver().restore(sess, self.tf_ckpt_path)
            variables = tf.global_variables()
            values = sess.run(variables)
            # [:-2] drops the trailing ':0' from the variable's tensor name.
            self.weight_dict = {
                var.name[:-2]: val for var, val in zip(variables, values)}
            self.tf_logits, self.tf_end_points, self.tf_image_proc = sess.run(
                [logits, end_points, image_proc], feed_dict={image_ph: self.image})

    def write_pytorch_weight(self):
        """Copy all TF weights into a PyTorch PNASNet, verify it, and save it.

        Raises:
            AssertionError: if a TF variable is left unused or used twice, or if
                the PyTorch logits differ from the TF logits by
                ``logits_tolerance`` or more.
        """
        model = NetworkImageNet(216, 1001, 12, False, PNASNet)
        model.drop_path_prob = 0
        model.eval()
        self.used_keys = []
        self.convert_conv(model.conv0, 'conv0/weights')
        self._convert_bn_scope(model.conv0_bn, 'conv0_bn/')
        self.convert_cell(model.stem1, 'cell_stem_0/')
        self.convert_cell(model.stem2, 'cell_stem_1/')
        for i in range(12):
            self.convert_cell(model.cells[i], 'cell_{}/'.format(i))
        self.convert_fc(model.classifier, 'final_layer/FC/weights',
                        'final_layer/FC/biases')
        print('Conversion complete!')

        print('Check 1: whether all TF variables are used...')
        self._check_all_keys_used()
        print('Pass!')

        model = model.cuda()
        # The TF preprocessed image is HWC; PyTorch expects CHW.
        image = self.tf_image_proc.transpose((2, 0, 1))
        image = Variable(self.Tensor(image)).cuda()
        logits, _ = model(image.unsqueeze(0))
        self.pytorch_logits = logits.data.cpu().numpy()
        print('Check 2: whether logits have small diff...')
        max_diff = np.max(np.abs(self.tf_logits - self.pytorch_logits))
        self._require(
            max_diff < self.logits_tolerance,
            'max |TF - PyTorch| logit difference {} is not below {}'.format(
                max_diff, self.logits_tolerance))
        print('Pass!')

        torch.save(model.state_dict(), self.model_path)
        print('PyTorch model saved to {}'.format(self.model_path))

    def _check_all_keys_used(self):
        """Raise AssertionError unless every TF variable was copied exactly once.

        Comparing only the lengths of ``weight_dict`` and ``used_keys`` would
        pass when one key is consumed twice and another is never consumed.
        """
        seen = set()
        duplicated = set()
        for key in self.used_keys:
            if key in seen:
                duplicated.add(key)
            seen.add(key)
        unused = set(self.weight_dict) - seen
        self._require(
            not duplicated and not unused,
            'duplicated TF keys: {}; unused TF keys: {}'.format(
                sorted(duplicated), sorted(unused)))

    def convert_cell(self, cell, name):
        """Copy the TF variables under scope ``name`` into the PyTorch ``cell``.

        Args:
            cell: a stem or regular cell of ``NetworkImageNet``.
            name: TF variable-scope prefix ending in '/', e.g. ``'cell_3/'``.
        """
        # cell.preprocess0
        self._require(
            isinstance(cell.preprocess0, (FactorizedReduce, ReLUConvBN, Identity)),
            'unexpected preprocess0 type {}'.format(type(cell.preprocess0).__name__))
        if isinstance(cell.preprocess0, FactorizedReduce):
            self.convert_conv(cell.preprocess0.conv_1, name + 'path1_conv/weights')
            self.convert_conv(cell.preprocess0.conv_2, name + 'path2_conv/weights')
            self._convert_bn_scope(cell.preprocess0.bn, name + 'final_path_bn/')
        elif name + 'prev_1x1/weights' in self.weight_dict:
            self.convert_conv(cell.preprocess0.op[1], name + 'prev_1x1/weights')
            self._convert_bn_scope(cell.preprocess0.op[2], name + 'prev_bn/')
        # Otherwise preprocess0 is Identity or equal to preprocess1: nothing to copy.

        # cell.preprocess1
        self._require(
            isinstance(cell.preprocess1, ReLUConvBN),
            'unexpected preprocess1 type {}'.format(type(cell.preprocess1).__name__))
        self.convert_conv(cell.preprocess1.op[1], name + '1x1/weights')
        self._convert_bn_scope(cell.preprocess1.op[2], name + 'beginning_bn/')

        # cell._ops: even indices map to 'left/', odd to 'right/' of comb_iter_{i // 2}.
        for i, op in enumerate(cell._ops):
            side = 'left/' if i % 2 == 0 else 'right/'
            prefix = name + 'comb_iter_{}/'.format(i // 2) + side
            if isinstance(op, SepConv):
                self._convert_sep_conv(op, prefix)
            elif isinstance(op, ReLUConvBN):
                # skip_connect with stride > 1
                self.convert_conv(op.op[1], prefix + '1x1/weights')
                self._convert_bn_scope(op.op[2], prefix + 'bn_1/')
            elif isinstance(op, nn.Sequential):
                # max_pool or avg_pool with C_in != C_out
                self.convert_conv(op[1], prefix + '1x1/weights')
                self._convert_bn_scope(op[2], prefix + 'bn_1/')
            # Other op types are assumed parameter-free; Check 1 in
            # write_pytorch_weight reports any TF variable left unused.

    def _convert_sep_conv(self, sep_conv, prefix):
        """Copy both (depthwise, pointwise, BN) stages of a SepConv op."""
        # Indices 1-3 and 5-7 of sep_conv.op are the (depthwise conv,
        # pointwise conv, BN) layers of stage 1 and stage 2 respectively.
        suffix = '{0}x{0}'.format(sep_conv.op[1].kernel_size[0])
        for stage, (dw, pw, bn) in enumerate(((1, 2, 3), (5, 6, 7)), start=1):
            tag = '{}_{}'.format(suffix, stage)
            self.convert_conv(sep_conv.op[dw],
                              prefix + 'separable_' + tag + '/depthwise_weights',
                              sep=True)
            self.convert_conv(sep_conv.op[pw],
                              prefix + 'separable_' + tag + '/pointwise_weights',
                              sep=False)
            self._convert_bn_scope(sep_conv.op[bn], prefix + 'bn_sep_' + tag + '/')

    def convert_conv(self, conv2d, weights_key, sep=False):
        """Replace ``conv2d.weight`` with the TF kernel ``weights_key``.

        Args:
            conv2d: PyTorch conv layer whose ``weight`` parameter is replaced.
            weights_key: TF variable name of the kernel.
            sep: True for a TF depthwise kernel, False for a regular or
                pointwise kernel.

        Raises:
            AssertionError: if the transposed TF kernel's shape differs from
                ``conv2d.weight``'s shape.
        """
        weights = self.weight_dict[weights_key]
        if sep:
            # TF depthwise: [filter_height, filter_width, in_channels, channel_multiplier]
            # -> [in_channels, channel_multiplier, H, W]. That equals PyTorch's
            # grouped layout [out_channels, in_channels // groups, H, W] when
            # channel_multiplier == 1; the shape check below catches other cases.
            weights = np.transpose(weights, (2, 3, 0, 1))
        else:
            # TF: [filter_height, filter_width, in_channels, out_channels]
            # PyTorch: [out_channels, in_channels, *kernel_size]
            weights = np.transpose(weights, (3, 2, 0, 1))
        param = self.Param(weights)
        self._require(conv2d.weight.shape == param.shape,
                      '{}: {} vs {}'.format(weights_key, conv2d.weight.shape, param.shape))
        conv2d.weight = param
        self.used_keys.append(weights_key)

    def _convert_bn_scope(self, bn, scope):
        """Call convert_bn with the four standard TF BN variable names under ``scope``."""
        self.convert_bn(bn, scope + 'gamma', scope + 'beta',
                        scope + 'moving_mean', scope + 'moving_variance')

    def convert_bn(self, bn, gamma_key, beta_key, moving_mean_key, moving_var_key):
        """Copy TF batch-norm variables into ``bn``.

        gamma/beta become the ``weight``/``bias`` Parameters; the moving
        statistics are assigned as plain tensors to ``running_mean`` and
        ``running_var``. All shapes are checked before anything is assigned.

        Raises:
            AssertionError: on any shape mismatch.
        """
        gamma = self.Param(self.weight_dict[gamma_key])
        beta = self.Param(self.weight_dict[beta_key])
        moving_mean = self.Tensor(self.weight_dict[moving_mean_key])
        moving_var = self.Tensor(self.weight_dict[moving_var_key])
        for target, source, key in ((bn.weight, gamma, gamma_key),
                                    (bn.bias, beta, beta_key),
                                    (bn.running_mean, moving_mean, moving_mean_key),
                                    (bn.running_var, moving_var, moving_var_key)):
            self._require(target.shape == source.shape,
                          '{}: {} vs {}'.format(key, target.shape, source.shape))
        bn.weight = gamma
        bn.bias = beta
        bn.running_mean = moving_mean
        bn.running_var = moving_var
        self.used_keys += [gamma_key, beta_key, moving_mean_key, moving_var_key]

    def convert_fc(self, fc, weights_key, biases_key):
        """Copy a TF fully-connected layer into the PyTorch linear layer ``fc``.

        The TF weight matrix is transposed before assignment, then checked
        against ``fc.weight``'s shape.

        Raises:
            AssertionError: on any shape mismatch.
        """
        weights = self.Param(np.transpose(self.weight_dict[weights_key]))
        biases = self.Param(self.weight_dict[biases_key])
        self._require(fc.weight.shape == weights.shape,
                      '{}: {} vs {}'.format(weights_key, fc.weight.shape, weights.shape))
        self._require(fc.bias.shape == biases.shape,
                      '{}: {} vs {}'.format(biases_key, fc.bias.shape, biases.shape))
        fc.weight = weights
        fc.bias = biases
        self.used_keys += [weights_key, biases_key]

    @staticmethod
    def _require(condition, message):
        """Raise AssertionError(message) when ``condition`` is false (not skipped by -O)."""
        if not condition:
            raise AssertionError(message)

    def Param(self, x):
        """Wrap numpy array ``x`` as a Parameter (shares memory with ``x``)."""
        return torch.nn.Parameter(torch.from_numpy(x))

    def Tensor(self, x):
        """Wrap numpy array ``x`` as a tensor (shares memory with ``x``)."""
        return torch.from_numpy(x)
if __name__ == '__main__':
    # Constructing the converter runs the full TF -> PyTorch conversion.
    ConvertPNASNet()