-
Notifications
You must be signed in to change notification settings - Fork 1
/
pruning_utils_3.py
137 lines (95 loc) · 3.34 KB
/
pruning_utils_3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import copy
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
def pruning_model(model, px):
    """Globally prune a fraction ``px`` of all Conv2d weights by L1 magnitude.

    Pruning is applied in place via torch's pruning hooks: each Conv2d gains
    ``weight_orig`` and ``weight_mask`` buffers, and ``weight`` becomes their
    product. Only Conv2d layers are considered; biases and other layer types
    are untouched.
    """
    print('start unstructured pruning')
    targets = tuple(
        (module, 'weight')
        for _, module in model.named_modules()
        if isinstance(module, nn.Conv2d)
    )
    prune.global_unstructured(
        targets,
        pruning_method=prune.L1Unstructured,
        amount=px,
    )
def pruning_model_random(model, px):
    """Globally prune a fraction ``px`` of all Conv2d weights at random.

    Identical to ``pruning_model`` except the weights to remove are chosen
    uniformly at random (RandomUnstructured) instead of by L1 magnitude.
    """
    targets = tuple(
        (module, 'weight')
        for _, module in model.named_modules()
        if isinstance(module, nn.Conv2d)
    )
    prune.global_unstructured(
        targets,
        pruning_method=prune.RandomUnstructured,
        amount=px,
    )
def prune_model_custom(model, mask_dict):
    """Re-apply previously saved pruning masks to every Conv2d layer.

    ``mask_dict`` maps ``'<module name>.weight_mask'`` to a mask tensor, the
    key format produced by ``extract_mask`` on a pruned model's state_dict.
    Raises KeyError if a Conv2d layer has no entry in ``mask_dict``.
    """
    print('start unstructured pruning with custom mask')
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.Conv2d):
            continue
        mask = mask_dict[module_name + '.weight_mask']
        prune.CustomFromMask.apply(module, 'weight', mask=mask)
def remove_prune(model):
    """Make pruning permanent on every Conv2d layer.

    Folds ``weight_orig * weight_mask`` back into a plain ``weight`` parameter
    and removes the pruning hooks/buffers, so the model's state_dict regains
    its original key layout.
    """
    print('remove pruning')
    for _, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.remove(module, 'weight')
def extract_mask(model_dict):
    """Return deep copies of every state_dict entry whose key contains 'mask'.

    On a model pruned via torch's pruning utilities this collects the
    ``*.weight_mask`` buffers; deep-copying detaches the result from any
    later changes to the source dict's tensors.
    """
    return {
        key: copy.deepcopy(value)
        for key, value in model_dict.items()
        if 'mask' in key
    }
def reverse_mask(mask_dict):
    """Invert binary masks: kept weights (1) become 0 and pruned (0) become 1."""
    return {key: 1 - mask for key, mask in mask_dict.items()}
def check_sparsity(model):
    """Report and return the percentage of remaining (non-zero) Conv2d weights.

    Counts zeros across the ``weight`` tensors of all Conv2d layers (zeroed
    entries are treated as pruned) and returns
    ``100 * (1 - zeros / total)`` as a Python float.

    Fix: the original divided by the total unconditionally, raising
    ZeroDivisionError on a model with no Conv2d layers; that case now
    reports 100% remaining.
    """
    sum_list = 0
    zero_sum = 0
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            sum_list = sum_list + float(m.weight.nelement())
            zero_sum = zero_sum + float(torch.sum(m.weight == 0))
    if sum_list == 0:
        # No prunable layers: nothing was (or can be) pruned.
        print('* remain weight = ', 100.0, '%')
        return 100.0
    remain = 100 * (1 - zero_sum / sum_list)
    print('* remain weight = ', remain, '%')
    return remain
def return_current_mask(model, px, pruned=False):
    """Compute the mask that pruning ``px`` now would produce, without keeping it.

    The model is pruned temporarily, the resulting ``*.weight_mask`` tensors
    are copied out, and then the model is restored to its saved state. If
    ``pruned`` is True the caller asserts the model already carries pruning
    hooks, whose masks are saved and re-applied before the state is reloaded.
    Returns the extracted mask dict for this epoch.
    """
    # Snapshot everything so the model can be restored afterwards.
    snapshot = copy.deepcopy(model.state_dict())
    previous_mask = extract_mask(snapshot) if pruned else None
    # Prune temporarily just to materialise the would-be mask.
    pruning_model(model, px)
    epoch_mask = extract_mask(model.state_dict())
    # Undo the temporary pruning and restore the original pruning state.
    remove_prune(model)
    if pruned:
        prune_model_custom(model, previous_mask)
    model.load_state_dict(snapshot)
    return epoch_mask
def calculate_hamming_distance(last_mask, current_mask, remain_parameters):
    """Count mask entries that changed between two epochs, normalised.

    Both dicts must share keys and per-key tensor shapes. Returns
    (number of differing entries) / ``remain_parameters`` as a 0-dim tensor.
    """
    flipped = 0
    for key, previous in last_mask.items():
        matches = (previous == current_mask[key]).float().sum()
        flipped += previous.nelement() - matches
    return flipped / remain_parameters
def cnt_remain_para(mask_dict):
    """Total number of unpruned (mask == 1) weights across all mask tensors.

    Returns a Python float via ``.item()``; assumes ``mask_dict`` is
    non-empty and holds binary tensors.
    """
    total = 0
    for mask in mask_dict.values():
        total = total + mask.sum().float()
    return total.item()
def cnt_model_para(model):
    """Count the weight elements of all Conv2d layers (biases excluded)."""
    return sum(
        layer.weight.nelement()
        for layer in model.modules()
        if isinstance(layer, nn.Conv2d)
    )
def FIFO(dis_queue, ele):
    """Push ``ele`` onto the front of a fixed-length 1-D history buffer.

    Returns a new tensor with every entry shifted right by one, ``ele`` at
    index 0, and the original last entry dropped. ``dis_queue`` itself is
    not modified.
    """
    # torch.roll returns a fresh tensor, so writing index 0 cannot alias
    # the input; the wrapped-around last element is immediately overwritten.
    shifted = torch.roll(dis_queue, 1)
    shifted[0] = ele
    return shifted