-
Notifications
You must be signed in to change notification settings - Fork 47
/
functions.py
866 lines (739 loc) · 43 KB
/
functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
import torch
import copy
import sys
import time
import pickle
import numpy as np
import warnings
from scipy.interpolate import Rbf
from collections import OrderedDict
from constants import *
def update_progress(index, length, **kwargs):
    '''
        Display a one-line text progress bar on stdout.

        The line starts with a carriage return so that successive calls
        overwrite each other in a terminal.

        Input:
            `index`: (int) index of the current progress
            `length`: (int) total length of the progress
            `**kwargs`: info to display after the bar as `key: value` pairs
                        (e.g. accuracy)
    '''
    bar_length = 10  # Modify this to change the length of the progress bar
    # A non-positive total has no meaningful ratio; show an empty bar instead
    # of raising ZeroDivisionError.
    progress = index / length if length > 0 else 0.0
    # Clamp to [0, 1] so an out-of-range index cannot produce a malformed bar.
    progress = min(max(float(progress), 0.0), 1.0)
    filled = int(round(bar_length * progress))
    text = "\rPercent: [{0}] {1:.2f}% ({2}/{3}) ".format(
        "#" * filled + "-" * (bar_length - filled), round(progress * 100, 3),
        index, length)
    # Without extra info the trailing space after the counter is kept.
    text += ', '.join(str(key) + ': ' + str(value) for key, value in kwargs.items())
    sys.stdout.write(text)
    sys.stdout.flush()
def get_layer_by_param_name(model, param_name):
    '''
        Resolve the layer (e.g. torch.nn.Conv2d) that owns a parameter,
        given the parameter's full name (e.g. models.conv_layers.0.weight).

        Input:
            `model`: model we want to get a certain layer from
            `param_name`: (string) layer parameter name; its components are
                          separated by STRING_SEPARATOR and the last one is
                          the parameter itself

        Output:
            `layer`: the object reached by following every component except
                     the last one with getattr (`model` itself when the name
                     has a single component)
    '''
    # Drop the trailing component (the parameter) and keep the attribute path.
    *attribute_path, _ = param_name.split(STRING_SEPARATOR)
    current = model
    for attribute in attribute_path:
        current = getattr(current, attribute)
    return current
def get_keys_from_ordered_dict(ordered_dict):
    '''
        Get the list of keys of an ordered dict, in iteration order.

        Input:
            `ordered_dict`: mapping whose keys are wanted

        Output:
            `dict_keys`: (list) a new list holding the keys in the order the
                         mapping yields them
    '''
    return list(ordered_dict.keys())
def extract_feature_map_sizes(model, input_data_shape):
    '''
        Get conv and fc layerwise feature map sizes by running one forward
        pass on random data with forward hooks attached.

        The model is switched to eval mode and, when CUDA is available, moved
        to the GPU; both changes remain in effect after the call.

        Input:
            `model`: model which we want to get layerwise feature map size.
            `input_data_shape`: (list) [C, H, W].

        Output:
            `fmap_sizes_dict`: (dict) layerwise feature map sizes, keyed by
                               `id(layer)`. Each value maps
                               KEY_INPUT_FEATURE_MAP_SIZE and
                               KEY_OUTPUT_FEATURE_MAP_SIZE to the sizes (as
                               lists) seen during the forward pass, which
                               uses a batch size of 1.
    '''
    fmap_sizes_dict = {}
    hooks = []
    tracked_layer_types = CONV_LAYER_TYPES + FC_LAYER_TYPES

    # Fall back to the CPU when no GPU is present instead of failing.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()
    model.eval()

    def _hook(module, inputs, output):
        # Only conv and fc layers are recorded.
        if module.__class__.__name__ in tracked_layer_types:
            fmap_sizes_dict[id(module)] = {
                KEY_INPUT_FEATURE_MAP_SIZE: list(inputs[0].size()),
                KEY_OUTPUT_FEATURE_MAP_SIZE: list(output.size()),
            }

    def _register_hook(module):
        # Skip pure containers and the top-level model itself.
        is_container = isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList))
        if not is_container and module is not model:
            hooks.append(module.register_forward_hook(_hook))

    try:
        model.apply(_register_hook)
        dummy_input = torch.randn([1, *input_data_shape])
        if use_cuda:
            dummy_input = dummy_input.cuda()
        # Only the tensor sizes are needed, so no autograd graph is built.
        with torch.no_grad():
            model(dummy_input)
    finally:
        # Always detach the hooks, even if the forward pass raised.
        for hook in hooks:
            hook.remove()
    return fmap_sizes_dict
def get_network_def_from_model(model, input_data_shape):
    '''
        Build the network definition (OrderedDict) of the input model.

        Only layers whose class name is in CONV_LAYER_TYPES or FC_LAYER_TYPES
        are described; other layers (e.g. batchnorm) are left out.

        Input:
            `model`: model we want to get network_def from
            `input_data_shape`: (list) [C, H, W].

        Output:
            `network_def`: (OrderedDict)
                           keys(): layer name (e.g. model.0.1, feature.2 ...)
                           values(): layer properties (dict)
    '''
    network_def = OrderedDict()
    supported_layer_types = CONV_LAYER_TYPES + FC_LAYER_TYPES

    # Parameter names in model order, then the recorded feature map sizes.
    param_names = get_keys_from_ordered_dict(model.state_dict())
    fmap_sizes_dict = extract_feature_map_sizes(model, input_data_shape)

    # Bookkeeping for pixel shuffle.
    last_layer_name = None
    last_out_channels = None
    shuffle_factor = 1

    for param_name in param_names:
        layer = get_layer_by_param_name(model, param_name)
        layer_type = layer.__class__.__name__

        # Only the weight entries of conv/fc layers populate the definition.
        # WARNING: ignores maxpool and upsampling layers.
        if layer_type not in supported_layer_types or WEIGHTSTRING not in param_name:
            continue

        # Everything before the last separator is the layer name.
        layer_name = param_name.rpartition(STRING_SEPARATOR)[0]
        fmap_sizes = fmap_sizes_dict[id(layer)]
        is_fc = layer_type in FC_LAYER_TYPES

        if is_fc:
            # FC layers are described as 1x1 convolutions on 1x1 feature maps.
            properties = {
                KEY_IS_DEPTHWISE: False,
                KEY_NUM_IN_CHANNELS: layer.in_features,
                KEY_NUM_OUT_CHANNELS: layer.out_features,
                KEY_KERNEL_SIZE: (1, 1),
                KEY_STRIDE: (1, 1),
                KEY_PADDING: (0, 0),
                KEY_GROUPS: 1,
                KEY_INPUT_FEATURE_MAP_SIZE: [1, fmap_sizes[KEY_INPUT_FEATURE_MAP_SIZE][1], 1, 1],
                KEY_OUTPUT_FEATURE_MAP_SIZE: [1, fmap_sizes[KEY_OUTPUT_FEATURE_MAP_SIZE][1], 1, 1],
            }
        else:
            # Any layer with more than one group is treated as depthwise.
            # Note: a depth-wise layer with only one filter also has one group
            # and is therefore indistinguishable from a point-wise layer here.
            properties = {
                KEY_IS_DEPTHWISE: layer.groups != 1,
                KEY_NUM_IN_CHANNELS: layer.in_channels,
                KEY_NUM_OUT_CHANNELS: layer.out_channels,
                KEY_KERNEL_SIZE: layer.kernel_size,
                KEY_STRIDE: layer.stride,
                KEY_PADDING: layer.padding,
                KEY_GROUPS: layer.groups,
                # (1, C, H, W)
                KEY_INPUT_FEATURE_MAP_SIZE: fmap_sizes[KEY_INPUT_FEATURE_MAP_SIZE],
                KEY_OUTPUT_FEATURE_MAP_SIZE: fmap_sizes[KEY_OUTPUT_FEATURE_MAP_SIZE],
            }
        properties[KEY_LAYER_TYPE_STR] = layer_type
        network_def[layer_name] = properties

        # Support pixel shuffle: the factor is the ratio between the previous
        # conv layer's output channels and this layer's input channels.
        if is_fc or last_out_channels is None:
            shuffle_factor = 1
        else:
            if last_out_channels % layer.in_channels != 0:
                raise ValueError('previous_out_channels is not divisible by layer.in_channels.')
            shuffle_factor = int(last_out_channels / layer.in_channels)
        if not is_fc:
            last_out_channels = layer.out_channels

        if last_layer_name is not None:
            network_def[last_layer_name][KEY_AFTER_SQUSRED_PIXEL_SHUFFLE_FACTOR] = shuffle_factor
        network_def[layer_name][KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR] = shuffle_factor
        last_layer_name = layer_name

    # The last recorded layer gets the most recently computed factor as its "after" factor.
    if last_layer_name:
        network_def[last_layer_name][KEY_AFTER_SQUSRED_PIXEL_SHUFFLE_FACTOR] = shuffle_factor
    return network_def
def compute_weights_and_macs(network_def):
    '''
        Compute the number of weights and MACs of a whole network.

        Per layer, the number of weights is
        out_channels * in_channels * kernel_h * kernel_w / groups, and the
        number of MACs is that count times the output height and width.
        Floor division is used so integer channel counts yield integer
        results.

        Input:
            `network_def`: defined in get_network_def_from_model()

        Output:
            `layer_weights_dict`: (OrderedDict) records layerwise num of weights.
            `total_num_weights`: (int) total num of weights.
            `layer_macs_dict`: (OrderedDict) records layerwise num of MACs.
            `total_num_macs`: (int) total num of MACs.
    '''
    total_num_weights, total_num_macs = 0, 0

    # Init dicts to store num resources for each layer.
    layer_weights_dict = OrderedDict()
    layer_macs_dict = OrderedDict()

    for layer_name, layer_properties in network_def.items():
        kernel_size = layer_properties[KEY_KERNEL_SIZE]
        # Multiply first and divide by the groups last; the division is exact
        # whenever the number of output channels is a multiple of the groups.
        layer_num_weights = (layer_properties[KEY_NUM_OUT_CHANNELS] *
                             layer_properties[KEY_NUM_IN_CHANNELS] *
                             kernel_size[0] * kernel_size[1]) // layer_properties[KEY_GROUPS]
        layer_weights_dict[layer_name] = layer_num_weights
        total_num_weights += layer_num_weights

        # Determine num macs for layer using output size, laid out as (N, C, H, W).
        output_size = layer_properties[KEY_OUTPUT_FEATURE_MAP_SIZE]
        output_height, output_width = output_size[2], output_size[3]
        layer_num_macs = layer_num_weights * output_width * output_height
        layer_macs_dict[layer_name] = layer_num_macs
        total_num_macs += layer_num_macs

    return layer_weights_dict, total_num_weights, layer_macs_dict, total_num_macs
def measure_latency(model, input_data_shape, runtimes=500):
    '''
        Measure latency of 'model'

        Randomly sample 'runtimes' inputs with normal distribution and
        measure the latencies of the forward passes. Inputs are created on
        the device of the model's first parameter (CPU when the model has no
        parameters). The model's train/eval mode is left unchanged.

        Input:
            `model`: model to be measured (e.g. torch.nn.Conv2d)
            `input_data_shape`: (list) input shape of the model (e.g. (B, C, H, W))
            `runtimes`: (int) number of forward passes to average over

        Output:
            average time (float), in seconds, of one forward pass

        Raises:
            ValueError: if `runtimes` is not positive.
    '''
    if runtimes <= 0:
        raise ValueError('`runtimes` must be a positive integer.')

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    is_cuda = device.type == 'cuda'

    total_time = .0
    with torch.no_grad():
        for _ in range(runtimes):
            sample = torch.randn(tuple(input_data_shape), device=device)
            if is_cuda:
                # CUDA kernels run asynchronously: wait for the input to be
                # generated so that its cost is not attributed to the model.
                torch.cuda.synchronize(device)
            start = time.perf_counter()
            model(sample)
            if is_cuda:
                # Wait for the forward pass to actually finish on the GPU.
                torch.cuda.synchronize(device)
            total_time += time.perf_counter() - start
    return total_time / float(runtimes)
def compute_latency_from_lookup_table(network_def, lookup_table_path):
    '''
        Compute the latency of all layers defined in `network_def` (only including Conv and FC).

        When the latency of a (num_in_channels, num_out_channels) pair is not
        in the lookup table, it is interpolated from the measured pairs of
        that layer with a cubic radial basis function.

        Input:
            `network_def`: defined in get_network_def_from_model()
            `lookup_table_path`: (string) path to lookup table

        Output:
            `latency`: (float) latency

        Raises:
            ValueError: if a layer of `network_def` is missing from the lookup
                        table, or has no latency measurements to interpolate from.
    '''
    latency = .0
    # NOTE: pickle can execute arbitrary code while loading; only load lookup
    # tables that come from a trusted source.
    with open(lookup_table_path, 'rb') as file_id:
        lookup_table = pickle.load(file_id)

    for layer_name, layer_properties in network_def.items():
        if layer_name not in lookup_table:
            raise ValueError('Layer name {} in network def not found in lookup table'.format(layer_name))
        layer_latencies = lookup_table[layer_name][KEY_LATENCY]
        num_in_channels = layer_properties[KEY_NUM_IN_CHANNELS]
        num_out_channels = layer_properties[KEY_NUM_OUT_CHANNELS]

        if (num_in_channels, num_out_channels) in layer_latencies:
            latency += layer_latencies[(num_in_channels, num_out_channels)]
            continue

        # Not found in the lookup table, then interpolate the latency.
        if not layer_latencies:
            raise ValueError('No latency measurements for layer {} in lookup table'.format(layer_name))
        feature_samples = np.array(list(layer_latencies.keys()))
        feature_samples_in = feature_samples[:, 0]
        feature_samples_out = feature_samples[:, 1]
        measurement = np.array(list(layer_latencies.values()))
        # Sanity checks: one measurement per (in, out) sample.
        assert feature_samples_in.shape == feature_samples_out.shape
        assert feature_samples_in.shape == measurement.shape
        rbf = Rbf(feature_samples_in, feature_samples_out, measurement, function='cubic')
        estimated_latency = rbf(np.array([num_in_channels]), np.array([num_out_channels]))
        latency += estimated_latency[0]
    return latency
def compute_resource(network_def, resource_type, lookup_table_path=None):
    '''
        Compute the resource consumption of a network for one resource type.

        Input:
            `network_def`: defined in get_network_def_from_model()
            `resource_type`: (string) (FLOPS/WEIGHTS/LATENCY)
            `lookup_table_path`: (string) path to lookup table; only used
                                 for `LATENCY`

        Output:
            `resource`: (float) total MACs for `FLOPS`, total weights for
                        `WEIGHTS`, lookup-table latency for `LATENCY`

        Raises:
            ValueError: for any other `resource_type`.
    '''
    if resource_type == 'LATENCY':
        return compute_latency_from_lookup_table(network_def, lookup_table_path)
    if resource_type not in ('FLOPS', 'WEIGHTS'):
        raise ValueError('Only support the resource type `FLOPS`, `WEIGHTS`, and `LATENCY`.')
    # Both remaining types come from the same weights/MACs computation.
    _, num_weights, _, num_macs = compute_weights_and_macs(network_def)
    return num_macs if resource_type == 'FLOPS' else num_weights
def build_latency_lookup_table(network_def_full, lookup_table_path, min_conv_feature_size=8,
                               min_fc_feature_size=128, measure_latency_batch_size=4,
                               measure_latency_sample_times=500, verbose=False):
    '''
        Build lookup table for latencies of layers defined by `network_def_full`.

        Supported layers: Conv2d, Linear, ConvTranspose2d
        Modify get_network_def_from_model() and this function to include more layer types.

        For every layer, the latency is measured for each combination of
        reduced input/output channel counts (stepping down from the full size
        by the minimal feature size). Layers that have the same properties as
        an earlier layer (pixel-shuffle factors aside) share that layer's
        table entry. `network_def_full` is not modified.

        Input:
            `network_def_full`: defined in get_network_def_from_model()
            `lookup_table_path`: (string) path to save the file of lookup table
            `min_conv_feature_size`: (int) The size of feature maps of simplified layers (conv layer)
                along channel dimension are multiples of 'min_conv_feature_size'.
                The reason is that on mobile devices, the computation of (B, 7, H, W) tensors
                would take longer time than that of (B, 8, H, W) tensors.
            `min_fc_feature_size`: (int) The size of features of simplified FC layers are
                multiples of 'min_fc_feature_size'.
            `measure_latency_batch_size`: (int) the batch size of input data
                when running forward functions to measure latency.
            `measure_latency_sample_times`: (int) the number of times to run the forward function of
                a layer in order to get its latency.
            `verbose`: (bool) set True to display detailed information.

        Raises:
            ValueError: if a layer type is not supported.
    '''
    # Pixel-shuffle factors do not influence the latency of a single layer,
    # so they are ignored when looking for an identical earlier layer.
    pixel_shuffle_keys = (KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR,
                          KEY_AFTER_SQUSRED_PIXEL_SHUFFLE_FACTOR)

    def _comparable(properties):
        # Copy of the properties without the pixel-shuffle factors.
        return {key: value for key, value in properties.items() if key not in pixel_shuffle_keys}

    conv_constructors = {'Conv2d': torch.nn.Conv2d, 'ConvTranspose2d': torch.nn.ConvTranspose2d}
    use_cuda = torch.cuda.is_available()

    # Generate the lookup table.
    lookup_table = OrderedDict()
    for layer_name, layer_properties in network_def_full.items():
        if verbose:
            print('-------------------------------------------')
            print('Measuring layer', layer_name, ':')

        # If the layer has the same properties as a previous layer, directly use the previous lookup table.
        current_comparable = _comparable(layer_properties)
        for layer_name_pre, layer_properties_pre in network_def_full.items():
            if layer_name_pre == layer_name:
                break
            if _comparable(layer_properties_pre) == current_comparable:
                lookup_table[layer_name] = lookup_table[layer_name_pre]
                if verbose:
                    print(' Find previous layer', layer_name_pre, 'that has the same properties')
                break
        if layer_name in lookup_table:
            continue

        is_depthwise = layer_properties[KEY_IS_DEPTHWISE]
        num_in_channels = layer_properties[KEY_NUM_IN_CHANNELS]
        num_out_channels = layer_properties[KEY_NUM_OUT_CHANNELS]
        kernel_size = layer_properties[KEY_KERNEL_SIZE]
        stride = layer_properties[KEY_STRIDE]
        padding = layer_properties[KEY_PADDING]
        groups = layer_properties[KEY_GROUPS]
        layer_type_str = layer_properties[KEY_LAYER_TYPE_STR]
        full_input_shape = layer_properties[KEY_INPUT_FEATURE_MAP_SIZE]

        latency_table = {}
        lookup_table[layer_name] = {
            KEY_IS_DEPTHWISE: is_depthwise,
            KEY_NUM_IN_CHANNELS: num_in_channels,
            KEY_NUM_OUT_CHANNELS: num_out_channels,
            KEY_KERNEL_SIZE: kernel_size,
            KEY_STRIDE: stride,
            KEY_PADDING: padding,
            KEY_GROUPS: groups,
            KEY_LAYER_TYPE_STR: layer_type_str,
            KEY_INPUT_FEATURE_MAP_SIZE: full_input_shape,
            KEY_LATENCY: latency_table,
        }

        # NOTE(review): these are printed regardless of `verbose`.
        print('Is depthwise:', is_depthwise)
        print('Num in channels:', num_in_channels)
        print('Num out channels:', num_out_channels)
        print('Kernel size:', kernel_size)
        print('Stride:', stride)
        print('Padding:', padding)
        print('Groups:', groups)
        print('Input feature map size:', full_input_shape)
        print('Layer type:', layer_type_str)

        # The channel counts are not required to be multiples of the minimal
        # feature size; the sweep simply starts from the full size.
        if layer_type_str in CONV_LAYER_TYPES:
            min_feature_size = min_conv_feature_size
        elif layer_type_str in FC_LAYER_TYPES:
            min_feature_size = min_fc_feature_size
        else:
            raise ValueError('Layer type {} not supported'.format(layer_type_str))

        for reduced_num_in_channels in range(num_in_channels, 0, -min_feature_size):
            if verbose:
                print(' Start measuring num_in_channels =', reduced_num_in_channels)

            if is_depthwise:
                # A depthwise layer has as many outputs as inputs.
                reduced_num_out_channels_list = [reduced_num_in_channels]
            else:
                reduced_num_out_channels_list = list(range(num_out_channels, 0, -min_feature_size))

            for index, reduced_num_out_channels in enumerate(reduced_num_out_channels_list, start=1):
                if layer_type_str in conv_constructors:
                    # Depthwise layers keep one group per input channel.
                    test_groups = reduced_num_in_channels if is_depthwise else groups
                    layer_test = conv_constructors[layer_type_str](
                        reduced_num_in_channels, reduced_num_out_channels,
                        kernel_size, stride, padding, groups=test_groups)
                    # Keep the spatial size of the original input feature map.
                    test_input_shape = (measure_latency_batch_size,
                                        reduced_num_in_channels, *full_input_shape[2:])
                elif layer_type_str == 'Linear':
                    layer_test = torch.nn.Linear(reduced_num_in_channels, reduced_num_out_channels)
                    test_input_shape = (measure_latency_batch_size, reduced_num_in_channels)
                else:
                    raise ValueError('Not support this type of layer.')

                if use_cuda:
                    layer_test = layer_test.cuda()
                measurement = measure_latency(layer_test, test_input_shape, measure_latency_sample_times)

                # Add the measurement into the lookup table.
                latency_table[(reduced_num_in_channels, reduced_num_out_channels)] = measurement
                if verbose:
                    update_progress(index, len(reduced_num_out_channels_list), latency=str(measurement))

            if verbose:
                print(' ')
                print(' Finish measuring num_in_channels =', reduced_num_in_channels)

    # Save the lookup table.
    with open(lookup_table_path, 'wb') as file_id:
        pickle.dump(lookup_table, file_id)
    return
def simplify_network_def_based_on_constraint(network_def, block, constraint, resource_type,
                                             lookup_table_path=None, skip_connection_block_sets=(),
                                             min_feature_size=8):
    '''
        Derive how much a certain block of layers ('block') should be simplified
        based on resource constraints.

        Here we treat one block as one layer although a block can contain several layers.
        Candidate numbers of output channels are tried from the largest to
        the smallest; the search stops at the first one whose resource
        consumption is below `constraint`. `network_def` is not modified.

        Input:
            `network_def`: simplifiable network definition (conv & fc). defined in self.get_network_def_from_model(...)
            `block`: (int) index of block to simplify
            `constraint`: (float) representing the FLOPs/weights/latency constraint the simplified model should satisfy
            `resource_type`: (string) `FLOPS`, `WEIGHTS`, or `LATENCY`
            `lookup_table_path`: (string) path to latency lookup table. Needed only when resource_type == 'LATENCY'
            `skip_connection_block_sets`: (list or tuple) the list of sets of blocks. Blocks in the same sets will have the
                same number of output channels as the corresponding feature maps will be summed later.
                (default: empty)
                For example, if the outputs of block 0 and block 4 are summed and
                the outputs of block 1 and block 5 are summed, then
                skip_connection_block_sets = [(0, 4), (1, 5)] or ((0, 4), (1, 5)).
                Note that we currently support addition.
            `min_feature_size`: (int) the number of output channels of simplified (pruned) layer would be
                multiples of min_feature_size. (default: 8)

        Output:
            `simplified_network_def`: simplified network definition. Indicates how much the network should
                be simplified/pruned.
            `simplified_resource`: (float) the estimated resource consumption of simplified models.

        Raises:
            ValueError: if `block` does not match any non-depthwise layer, if a
                        target layer turns out to be depthwise, or if a channel
                        count is not divisible by a pixel-shuffle factor.
    '''
    # Check whether the block has a skip connection.
    block = [block]
    for skip_connection_block_set in skip_connection_block_sets:
        if block[0] in skip_connection_block_set:
            block = sorted(skip_connection_block_set)
            break
    print(' simplify_def> constraint: ', constraint)
    print(' simplify_def> target block:', block)

    # Find the target layer and other layers whose output would later be added to that of target layer
    # (i.e. skip connection)
    # (contains layer index)
    target_layer_indices = []
    max_num_out_channels = None
    block_counter = 0
    for layer_idx, (layer_name, layer_properties) in enumerate(network_def.items()):
        # Neglect the depthwise layers: they do not count as blocks.
        if layer_properties[KEY_IS_DEPTHWISE]:
            continue
        if block_counter == block[0]:
            target_layer_indices.append(layer_idx)
            if max_num_out_channels is not None:
                if max_num_out_channels != layer_properties[KEY_NUM_OUT_CHANNELS]:
                    print('The blocks involved in this skip connection do not have compatible numbers of output '
                          'channels.')
                    sys.stdout.flush()
            max_num_out_channels = layer_properties[KEY_NUM_OUT_CHANNELS]
            print(' simplify_def> target layer: {}, layer index: {}'.format(layer_name, layer_idx))
            del block[0]
            if not block:
                break
        block_counter += 1

    # No layer matched: the block index is beyond the last non-depthwise layer.
    if not target_layer_indices:
        raise ValueError('`Block` seems out of bound.')

    # Determine the number of filters and the resource consumption.
    simplified_network_def = copy.deepcopy(network_def)
    simplified_resource = None
    return_with_constraint_satisfied = False

    if max_num_out_channels >= min_feature_size:
        # Try numbers of channels that are multiples of `min_feature_size`, largest first.
        num_out_channels_try = list(range(max_num_out_channels // min_feature_size * min_feature_size,
                                          min_feature_size - 1, -min_feature_size))
    else:
        num_out_channels_try = [max_num_out_channels]

    # For every candidate:
    #   Update # of output channels of target layers.
    #   Update # of input/output channels of all depthwise layers between target layers and
    #       other subsequent non-depthwise layers (assuming # of groups == # of input channels)
    #   Update # of input channels of one non-depthwise layer following the target layers.
    for current_num_out_channels in num_out_channels_try:
        for target_layer_index in target_layer_indices:
            update_num_out_channels = True
            current_num_out_channels_after_pixel_shuffle = current_num_out_channels

            for layer_idx, (layer_name, layer_properties) in enumerate(simplified_network_def.items()):
                if layer_idx < target_layer_index:
                    continue

                # for the block to be simplified (# of output channels is simplified)
                if update_num_out_channels:
                    if layer_properties[KEY_IS_DEPTHWISE]:
                        raise ValueError('Expected a non-depthwise layer but got a depthwise layer.')
                    layer_properties[KEY_NUM_OUT_CHANNELS] = current_num_out_channels
                    update_num_out_channels = False
                    print(' simplify_def> layer {}: num of output channel changed to {}'.format(layer_name, str(current_num_out_channels)))
                    continue

                # for blocks following the target blocks (# of input channels is simplified)
                pixel_shuffle_factor = layer_properties[KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR]
                if current_num_out_channels_after_pixel_shuffle % pixel_shuffle_factor != 0:
                    raise ValueError('current_num_out_channels or current_num_out_channels_after_pixel_shuffle is '
                                     'not divisible by the scaling factor of pixel shuffling.')
                # Exact thanks to the check above; floor division keeps the channel count an int.
                current_num_out_channels_after_pixel_shuffle = (
                    current_num_out_channels_after_pixel_shuffle // pixel_shuffle_factor)
                layer_properties[KEY_NUM_IN_CHANNELS] = current_num_out_channels_after_pixel_shuffle
                print(' simplify_def> layer {}: num of input channel changed to {}'.format(layer_name, str(current_num_out_channels_after_pixel_shuffle)))

                # Consider the case that a FC layer is placed after a Conv and Flatten:
                #     FC: input feature size: Cin
                #         output feature size: Cout
                #     Conv: output feature map size: H x W x C
                # So Cin = H x W x C.
                # If C -> C' based on constraints, then Cin -> H x W x C'
                if pixel_shuffle_factor == 1:
                    original_num_in_channels = network_def[layer_name][KEY_NUM_IN_CHANNELS]
                    if original_num_in_channels > max_num_out_channels:
                        assert original_num_in_channels % max_num_out_channels == 0
                        # H x W here
                        spatial_factor = original_num_in_channels // max_num_out_channels
                        layer_properties[KEY_NUM_IN_CHANNELS] = spatial_factor*current_num_out_channels
                        print(' simplify_def> [Update] layer {}: num of input channel changed to {}'.format(layer_name, str(spatial_factor*current_num_out_channels)))

                # The first non-depthwise layer after the target ends the propagation.
                if not layer_properties[KEY_IS_DEPTHWISE]:
                    break
                layer_properties[KEY_NUM_OUT_CHANNELS] = current_num_out_channels_after_pixel_shuffle
                layer_properties[KEY_GROUPS] = current_num_out_channels_after_pixel_shuffle
                print(' simplify_def> depthwise layer {}: num of output channel changed to {}'.format(layer_name, str(current_num_out_channels_after_pixel_shuffle)))

        # Get the current resource consumption
        simplified_resource = compute_resource(simplified_network_def, resource_type,
                                               lookup_table_path)
        print(' simplify_def> finish trying num of output channel: {}, resource: {}'.format(current_num_out_channels, simplified_resource))

        # Terminate the simplification when the constraint has been satisfied.
        if simplified_resource < constraint:
            return_with_constraint_satisfied = True
            print(' simplify_def> constraint {} met when trying num of output channel: {}'.format(constraint, current_num_out_channels))
            break

    if not return_with_constraint_satisfied:
        warnings.warn(
            'Constraint not satisfied: constraint = {}, simplified_resource = {}'.format(constraint,
                                                                                         simplified_resource))
    return simplified_network_def, simplified_resource
def _expand_grouped_indices(group_idx, group_size):
    '''
        Expand group indices into the sorted indices of all elements of those groups.

        Group `g` covers the consecutive elements
        `g*group_size, ..., g*group_size + group_size - 1`.

        Input:
            `group_idx`: 1-D integer tensor holding the indices of the kept groups.

            `group_size`: (int) number of consecutive elements per group (>= 1).

        Output:
            1-D tensor with `len(group_idx) * group_size` element indices,
            sorted in ascending order.
    '''
    first_element_idx = group_idx * group_size
    expanded_idx = torch.cat(
        [first_element_idx + offset for offset in range(group_size)], dim=0)
    expanded_idx, _ = expanded_idx.sort()
    return expanded_idx


def simplify_model_based_on_network_def(simplified_network_def, model):
    '''
        Choose which filters to preserve

        Here filters with largest L2 magnitude will be kept

        The parameters are visited in `state_dict()` order. Whenever a
        non-depthwise layer listed in `simplified_network_def` has more
        filters than requested, the indices of the filters with the largest
        squared L2 norm are stored in `kept_filter_idx`. These indices are
        then used to slice the bias of that layer and the input channels /
        features of the layers that follow, until a layer whose number of
        filters already matches the network_def resets them to None.

        Input:
            `simplified_network_def`: network_def shows how a model will be pruned.
            defined in get_network_def_from_model()

            `model`: model to be simplified. It is deep-copied, the argument
            itself is not modified.

        Output:
            `simplified_model`: simplified model.

        Raise:
            ValueError: if a parameter name or layer type is not supported, or
            if the target number of filters is not divisible by the layer's
            after-pixel-shuffle factor.
    '''
    simplified_model = copy.deepcopy(model)
    simplified_state_dict = simplified_model.state_dict()
    kept_filter_idx = None

    for layer_param_full_name in simplified_state_dict.keys():
        layer = get_layer_by_param_name(simplified_model, layer_param_full_name)
        layer_param_full_name_split = layer_param_full_name.split(STRING_SEPARATOR)
        layer_name_str = STRING_SEPARATOR.join(layer_param_full_name_split[:-1])
        layer_param_name = layer_param_full_name_split[-1]
        layer_type_str = layer.__class__.__name__

        # Reduce the number of input channels based on the layer and data type.
        # Reduce the number of biases of simplified layers
        if kept_filter_idx is None:
            pass
        elif layer_type_str in CONV_LAYER_TYPES:
            if layer_param_name == WEIGHTSTRING:
                # Support pixel shuffle: map the kept output channels of the
                # previous layer to the input channels of this layer.
                # This is done for the weight only; the bias is indexed with
                # the (already mapped or freshly selected) output indices and
                # must not be mapped a second time.
                before_squared_pixel_shuffle_factor = simplified_network_def[layer_name_str][
                    KEY_BEFORE_SQUARED_PIXEL_SHUFFLE_FACTOR]
                # Integer floor division keeps the indices an integer tensor;
                # true division would give a float tensor, which cannot be
                # used for indexing.
                kept_filter_idx = (kept_filter_idx[::before_squared_pixel_shuffle_factor] //
                                   before_squared_pixel_shuffle_factor)

                if layer.groups == 1:  # Pointwise layer or depthwise layer with only one filter.
                    setattr(layer, layer_param_name,
                            torch.nn.Parameter(getattr(layer, layer_param_name)[:, kept_filter_idx, :, :]))
                    layer.in_channels = len(kept_filter_idx)
                    print(' simplify_model> simplify Conv layer {}: ipnut channel weights {}'.format(layer_name_str,
                          len(kept_filter_idx)))
                else:  # depthwise
                    # One filter per input channel, so filters, input channels,
                    # output channels and groups shrink together.
                    setattr(layer, layer_param_name,
                            torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx, :, :, :]))
                    layer.in_channels = len(kept_filter_idx)
                    layer.out_channels = len(kept_filter_idx)
                    layer.groups = len(kept_filter_idx)
                    print(' simplify_model> simplify Conv layer {}: ipnut/output channel weights {} and groups {}'.format(layer_name_str,
                          len(kept_filter_idx), len(kept_filter_idx)))
            elif layer_param_name == BIASSTRING:
                print(' simplify_model> simplify Conv layer {}: output channel biases {}'.format(layer_name_str,
                      len(kept_filter_idx)))
                setattr(layer, layer_param_name,
                        torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx]))
            else:
                raise ValueError('The layer_param_name `{}` is not supported.'.format(layer_param_name))
        elif layer_type_str in FC_LAYER_TYPES:
            if layer_param_name == BIASSTRING:
                print(' simplify_model> simplify FC layer {}: output channel biases {}'.format(layer_name_str,
                      len(kept_filter_idx)))
                # the weights of this layer is already reduced
                setattr(layer, layer_param_name,
                        torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx]))
            else:
                '''
                    the input features should be modified
                    as its previous layer has different output features

                    Consider the case that a FC layer is placed after a Conv and Flatten:
                        FC: input feature size: Cin
                            output feature size: Cout
                        Conv: output feature map size: H x W x C
                    So Cin = H x W x C.
                    If C -> C' based on constraints, then Cin -> H x W x C'
                '''
                num_in_features = simplified_network_def[layer_name_str][KEY_NUM_IN_CHANNELS]
                if num_in_features > len(kept_filter_idx):
                    assert num_in_features % len(kept_filter_idx) == 0
                    # H x W here
                    spatial_ratio = num_in_features // len(kept_filter_idx)
                    # Every kept channel contributes `spatial_ratio`
                    # consecutive input features.
                    kept_filter_idx_fc = _expand_grouped_indices(kept_filter_idx, spatial_ratio)
                    setattr(layer, layer_param_name,
                            torch.nn.Parameter(getattr(layer, layer_param_name)[:, kept_filter_idx_fc]))
                    layer.in_features = len(kept_filter_idx_fc)
                    assert len(kept_filter_idx_fc) == num_in_features
                else:
                    setattr(layer, layer_param_name,
                            torch.nn.Parameter(getattr(layer, layer_param_name)[:, kept_filter_idx]))
                    layer.in_features = len(kept_filter_idx)
                print(' simplify_model> simplify FC layer {}: input channel weights {}'.format(layer_name_str,
                      layer.in_features))
        elif layer_type_str in BNORM_LAYER_TYPES:
            if layer_param_name in (WEIGHTSTRING, BIASSTRING):
                setattr(layer, layer_param_name,
                        torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx], requires_grad=True))
                layer.num_features = len(kept_filter_idx)
                print(' simplify_model> simplify {} layer {}: {} {}'.format(layer_type_str,
                      layer_name_str, layer_param_name, layer.num_features))
            elif layer_param_name in (RUNNING_MEANSTRING, RUNNING_VARSTRING):
                # NOTE(review): the running statistics are re-assigned as
                # non-trainable Parameters here -- confirm this is intended
                # rather than keeping them as buffers.
                setattr(layer, layer_param_name,
                        torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx], requires_grad=False))
                layer.num_features = len(kept_filter_idx)
                print(' simplify_model> simplify {} layer {}: {} {}'.format(layer_type_str,
                      layer_name_str, layer_param_name, layer.num_features))
            elif NUM_BATCHES_TRACKED == layer_param_name:
                getattr(layer, layer_param_name).zero_()
                print(' simplify_model> simplify {} layer {}: {} set to 0'.format(layer_type_str,
                      layer_name_str, layer_param_name))
            else:
                raise ValueError('The layer_param_name `{}` is not supported.'.format(layer_param_name))
        else:
            raise ValueError('The layer type `{}` is not supported.'.format(type(layer)))

        # Reduce the number of filters and update kept_filter_idx if it is in network_def and
        # not a depth-wise layer.
        # Reduce the number of output feature maps of simplified layers
        if (layer_param_name == WEIGHTSTRING and
                layer_name_str in simplified_network_def and
                not simplified_network_def[layer_name_str][KEY_IS_DEPTHWISE]):
            num_filters = simplified_network_def[layer_name_str][KEY_NUM_OUT_CHANNELS]
            weight = layer.weight.data
            if num_filters == weight.shape[0]:
                # Not target layer thus not simplify
                # Means the current model and simplified network def
                # have the same number of output channels
                kept_filter_idx = None
            else:
                # Based on L2 norm, determine `kept_filter_idx`
                # `kept_filter_idx` is used to simplify the current layer (conv & fc) and
                # is also used to simplify some related following layers
                after_squared_pixel_shuffle_factor = simplified_network_def[layer_name_str][
                    KEY_AFTER_SQUSRED_PIXEL_SHUFFLE_FACTOR]
                if num_filters % after_squared_pixel_shuffle_factor != 0:
                    raise ValueError('num_filters is not divisible by after_squared_pixel_shuffle_factor.')
                # Filters are selected in groups of
                # `after_squared_pixel_shuffle_factor` consecutive filters.
                num_filters //= after_squared_pixel_shuffle_factor
                if layer_type_str in CONV_LAYER_TYPES:
                    filter_norm = (weight * weight).sum((1, 2, 3))
                    filter_norm = filter_norm.view(-1, after_squared_pixel_shuffle_factor).sum(1)
                elif layer_type_str in FC_LAYER_TYPES:
                    filter_norm = (weight * weight).sum(1)
                else:
                    # Without this guard `filter_norm` would be undefined or
                    # left over from a previously processed layer.
                    raise ValueError('The layer type `{}` is not supported.'.format(type(layer)))
                _, kept_filter_idx = filter_norm.topk(num_filters, sorted=False)

                # consider pixel shuffle
                kept_filter_idx = _expand_grouped_indices(kept_filter_idx,
                                                          after_squared_pixel_shuffle_factor)

                if layer_type_str in CONV_LAYER_TYPES:
                    setattr(layer, layer_param_name,
                            torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx, :, :, :]))
                    layer.out_channels = len(kept_filter_idx)
                    print(' simplify_model> simplify Conv layer {}: output channel weights {}'.format(layer_name_str,
                          len(kept_filter_idx)))
                elif layer_type_str in FC_LAYER_TYPES:
                    setattr(layer, layer_param_name,
                            torch.nn.Parameter(getattr(layer, layer_param_name)[kept_filter_idx, :]))
                    layer.out_features = len(kept_filter_idx)
                    print(' simplify_model> simplify FC layer {}: output channel weights {}'.format(layer_name_str,
                          len(kept_filter_idx)))
    return simplified_model