import copy
import typing

import torch
import torch.nn as nn
import torch.distributed as dist

from tinynn.graph import modifier
from tinynn.graph.tracer import ignore_mod_param_update_warning
from tinynn.util.util import get_logger
from tinynn.prune.base_pruner import BasePruner
from tinynn.graph.modifier import is_dw_conv

log = get_logger(__name__)
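
# Typical usage (a minimal sketch; the model, the input shape and the config
# values below are illustrative, not taken from this file):
#
#   import torch
#   import torchvision
#
#   model = torchvision.models.mobilenet_v2()
#   pruner = OneShotChannelPruner(
#       model, torch.rand(1, 3, 224, 224), {"sparsity": 0.25, "metrics": "l2_norm"}
#   )
#   pruner.prune()  # compute the channel masks and shrink the model in place
#   pruner.generate_config("pruned_config.json")  # hypothetical output path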


class OneShotChannelPruner(BasePruner):
    required_params = ('sparsity', 'metrics')

    metrics: str
    sparsity: typing.Union[typing.Dict[str, float], float]
    default_sparsity: float
    metric_func: typing.Callable[[torch.Tensor, torch.nn.Module], float]
    skip_last_fc: bool
    bn_compensation: bool
    exclude_ops: list
    multiple: int

    def __init__(self, model, dummy_input, config):
        """Constructs a new OneShotChannelPruner (supporting the random, l1_norm, l2_norm and fpgm metrics)

        Args:
            model: The model to be pruned
            dummy_input: A viable input to the model
            config (dict, str): Configuration of the pruner (either a dict or the path to a JSON file)

        Raises:
            Exception: If a model without parameters or prunable ops is given, an exception will be raised
        """

        super(OneShotChannelPruner, self).__init__(model, dummy_input, config)

        self.center_nodes = []
        self.sparsity = {}
        self.bn_compensation = False
        self.parse_config()

        for n in self.graph.forward_nodes:
            # Only prune the supported operator types
            if (
                (
                    n.type() in [nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d, nn.ConvTranspose1d]
                    and not is_dw_conv(n.module)
                )
                or (n.type() in [nn.Linear])
                or (n.type() in [nn.RNN, nn.GRU, nn.LSTM] and len(n.prev_tensors) == 1)
            ):
                self.center_nodes.append(n)
                if n.unique_name not in self.sparsity:
                    self.sparsity[n.unique_name] = self.default_sparsity

        last_center_node = self.center_nodes[-1] if len(self.center_nodes) > 0 else None
        if last_center_node is None:
            raise Exception("No operations to prune")

        # Unless a pruning ratio is specified manually, the last center node is not pruned when it is a
        # Linear layer (the final fully-connected layer is usually tied to the number of classes)
        if self.skip_last_fc and last_center_node.type() in [nn.Linear]:
            if self.sparsity[last_center_node.unique_name] == self.default_sparsity:
                self.sparsity[last_center_node.unique_name] = 0.0

        self.graph_modifier = modifier.GraphChannelModifier(self.graph, self.center_nodes, self.bn_compensation)

        # TODO: To preserve pruning accuracy, all pruning subgraphs involving constants are disabled for now
        self.exclude_ops.append('weight')

        for sub_graph in self.graph_modifier.sub_graphs.values():
            exclude = False

            for m in sub_graph.modifiers:
                if m.node.type() in self.exclude_ops:
                    exclude = True
                    break

                if m.node.type() in (nn.RNN, nn.GRU, nn.LSTM) and len(m.node.prev_tensors) > 1:
                    exclude = True
                    break

            if exclude:
                sub_graph.skip = True

    def parse_config(self):
        """Parses the configuration and copies the needed items to the pruner"""

        super().parse_config()

        all_param_keys = list(self.required_params) + list(self.default_values.keys())
        for param_key in all_param_keys:
            if param_key not in ['sparsity', 'metrics', 'skip_last_fc', 'exclude_ops', 'bn_compensation']:
                setattr(self, param_key, self.config[param_key])

        sparsity = self.config['sparsity']
        metrics = self.config['metrics']
        self.skip_last_fc = self.config.get('skip_last_fc', True)
        self.multiple = self.config.get('multiple', None)
        self.bn_compensation = self.config.get('bn_compensation', self.bn_compensation)
        self.exclude_ops = self.config.get('exclude_ops', [])
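
        # The two accepted forms of `sparsity`, as parsed below (a sketch; the
        # layer key `conv_1` is hypothetical, real keys are the `unique_name`s
        # of the traced center nodes):
        #
        #   {"sparsity": 0.5, "metrics": "l1_norm"}
        #   {"sparsity": {"default": 0.5, "conv_1": 0.25}, "metrics": "l1_norm"}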

        if isinstance(sparsity, float):
            self.default_sparsity = sparsity
        elif isinstance(sparsity, dict):
            if 'default' in sparsity:
                self.default_sparsity = sparsity['default']
                del sparsity['default']
            self.sparsity.update(sparsity)
        else:
            raise Exception(f'The type of `sparsity` should either be float or dict for {type(self).__name__}')

        if not hasattr(self, 'default_sparsity'):
            raise Exception(f'Please specify a default sparsity using the key `default` for {type(self).__name__}')

        if hasattr(modifier, metrics):
            self.metric_func = getattr(modifier, metrics)
        else:
            raise Exception(f'{metrics} is not a known metric for {type(self).__name__}')

    def prune(self):
        """The main entry point for pruning.

        For one-shot pruning, it simply computes the masks (`register_mask`) and then applies them (`apply_mask`).
        """

        self.register_mask()
        with ignore_mod_param_update_warning():
            self.apply_mask()
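
    # `prune` is shorthand for the two phases below; they can also be driven
    # manually (a sketch, with `pruner` standing for any OneShotChannelPruner):
    #
    #   pruner.register_mask()
    #   with ignore_mod_param_update_warning():
    #       pruner.apply_mask()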

    def register_mask(self):
        """Computes the masks for the parameters in the model and registers them through the maskers"""

        log.info("Register a mask for each operator")

        importance = {}
        for sub_graph in self.graph_modifier.sub_graphs.values():
            if sub_graph.skip:
                log.info(f"skip subgraph {sub_graph.center}")
                continue

            for m in sub_graph.modifiers:
                if m.node in self.center_nodes and m in sub_graph.dependent_centers:
                    modifier.update_weight_metric(importance, self.metric_func, m.module(), m.unique_name())

            sub_graph.calc_prune_idx(importance, self.sparsity)
            log.info(f"subgraph [{sub_graph.center}] compute over")

        for m in self.graph_modifier.modifiers.values():
            m.register_mask(self.graph_modifier.modifiers, importance, self.sparsity)

    def apply_mask(self):
        """Applies the masks to the parameters and updates the shapes and properties of the tensors and modules"""

        log.info("Apply the mask of each operator")

        for m in self.graph_modifier.modifiers.values():
            m.apply_mask(self.graph_modifier.modifiers)

        self.graph_modifier.unregister_modifier()

        # Sync the parameters across ranks after the tensors have been shrunk
        if dist.is_available() and dist.is_initialized():
            for ps in self.model.parameters():
                dist.broadcast(ps, 0)

    def generate_config(self, path: str) -> None:
        """Generates a copy of the updated configuration and saves it to the given path"""

        config = copy.deepcopy(self.config)
        config['sparsity'] = dict()
        config['sparsity']['default'] = self.default_sparsity
        config['sparsity'].update(self.sparsity)

        super().generate_config(path, config)
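
    # The saved file keeps the per-layer ratios next to the default, e.g.
    # (hypothetical layer name, other config keys elided):
    #
    #   {"sparsity": {"default": 0.5, "conv_1": 0.25}, "metrics": "l1_norm", ...}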

    def reset(self) -> None:
        """Regenerates the TraceGraph and the GraphChannelModifier when they are invalidated"""

        self.graph = self.trace()
        self.center_nodes.clear()

        for n in self.graph.forward_nodes:
            if (n.type() in [nn.Conv2d, nn.ConvTranspose2d] and not is_dw_conv(n.module)) or (n.type() in [nn.Linear]):
                self.center_nodes.append(n)
                if n.unique_name not in self.sparsity:
                    self.sparsity[n.unique_name] = self.default_sparsity

        self.graph_modifier = modifier.GraphChannelModifier(self.graph, self.center_nodes, self.bn_compensation)