banti2chamanti.py
import pickle

from utils import cp, pp, den


def banti2chamanti(banti_pkl_file_name,
                   dont_pool_along_width=True,
                   conv_trainable=False,
                   dense_trainable=True):
    """
    Input to a Conv layer is (bz, SW, SH, m).
    Kernels: k filters of size (ksw, ksh).
    Output shape is (bz, SW, SH, k).
    Chamanti kernel weights are laid out as (ksw, ksh, m, k),
    while Banti kernel weights are (k, m, ksh, ksw),
    so a plain transpose converts them before loading into Chamanti.
    """
    print("Loading weights from", banti_pkl_file_name)
    with open(banti_pkl_file_name, 'rb') as f:
        d = pickle.load(f)

    bspecs = d["layers"]
    bwts = d["allwts"]

    layer_args = []
    last_pool = None
    n, nconv, npool, nmaps = 0, 0, 0, 1

    # Walk the convolutional front-end: Elastic, Conv and Pool layers.
    while n < len(bspecs):
        name, spec = bspecs[n]
        if name == 'ElasticLayer':
            pass  # Data-augmentation layer; nothing to port.
        elif name == 'ConvLayer':
            if 'mode' in spec:
                assert spec['mode'] == 'same'
            nconv += 1
            nmaps = spec['num_maps']
            w, b = bwts[n]
            # Transpose (k, m, ksh, ksw) -> (ksw, ksh, m, k); see docstring.
            layer_args.append(cp(f'conv{nconv}', nmaps, spec['filter_sz'], spec['actvn'],
                                 weights=(w.T, b), trainable=conv_trainable))
        elif name == 'PoolLayer':
            npool += 1
            last_pool = pp(f'pool{npool}', spec['pool_sz'])
            layer_args.append(last_pool)
        else:
            break  # First non-conv layer; handled below.
        n += 1

    if dont_pool_along_width and last_pool is not None:
        # Disable pooling along the width axis; keep the height pooling.
        last_pool['pool_size'] = (1, last_pool['pool_size'][1])
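    # The trailing dense layer: Banti flattens an (nmaps, d, d) receptive field
    # into nin = nmaps*d*d inputs, whereas here only the middle column
    # (nmaps*d inputs) is kept, presumably because Chamanti applies the dense
    # layer one column at a time as it slides along the width.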
    if bspecs[n][0] in ('SoftmaxLayer', 'HiddenLayer'):
        w, b = bwts[n]
        nin, nout = w.shape
        d = int((nin // nmaps) ** .5)
        assert nin == nmaps * d * d  # e.g. 5832 = 162*6*6
        w = w.reshape((nmaps, d, d, nout))
        w = w[:, :, d//2 - 1, :].reshape((nmaps * d, nout))  # Extract the middle column
        try:
            actvn = bspecs[n][1]["actvn"]
        except KeyError:
            actvn = "linear"
        layer_args.append(den('dense1', len(b), actvn,
                              weights=(w, b), trainable=dense_trainable))

    return layer_args
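

# Usage sketch (not part of the original file): convert a Banti pickle and
# inspect the resulting Chamanti layer specs. The command-line handling here
# is hypothetical; the returned layer_args would normally be passed to
# whatever builds the Chamanti network from the cp/pp/den specs.
if __name__ == '__main__':
    import sys

    for layer_arg in banti2chamanti(sys.argv[1]):
        print(layer_arg)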