fusion_layernorm.py
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------
from typing import Dict
from logging import getLogger
from onnx import helper
from onnx_model import OnnxModel
from fusion_base import Fusion
logger = getLogger(__name__)


class FusionLayerNormalization(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "ReduceMean")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse the Layer Normalization subgraph into a single LayerNormalization node:
              +----------------------+
              |                      |
              |                      v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
                     (axis=2 or -1)  |       (Y=2)   (axis=2 or -1) (E-6 or E-12 or 0) ^
                                     |                                                 |
                                     +-------------------------------------------------+

        It also handles the duplicated Sub nodes exported from older versions of PyTorch:
              +----------------------+
              |                      v
              |           +-------> Sub -----------------------------------------------+
              |           |                                                            |
              |           |                                                            v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
              |                      ^
              |                      |
              +----------------------+
        """
        children = self.model.get_children(node, input_name_to_nodes)
        if len(children) == 0 or len(children) > 2:
            return

        root_input = node.input[0]

        if children[0].op_type != 'Sub' or children[0].input[0] != root_input:
            return

        if len(children) == 2:
            if children[1].op_type != 'Sub' or children[1].input[0] != root_input:
                return

        div_node = None
        for child in children:
            div_node = self.model.find_first_child_by_type(child, 'Div', input_name_to_nodes, recursive=False)
            if div_node is not None:
                break
        if div_node is None:
            return

        # Trace back from Div to the Sub node, optionally through a Cast.
        path_id, parent_nodes, _ = self.model.match_parent_paths(
            div_node, [(['Sqrt', 'Add', 'ReduceMean', 'Pow', 'Sub'], [1, 0, 0, 0, 0]),
                       (['Sqrt', 'Add', 'ReduceMean', 'Pow', 'Cast', 'Sub'], [1, 0, 0, 0, 0, 0])], output_name_to_node)
        if path_id < 0:
            return

        sub_node = parent_nodes[-1]
        if sub_node not in children:
            return

        # The constant added before Sqrt is the epsilon; reject values that do not look like one.
        second_add_node = parent_nodes[1]
        _, add_weight = self.model.get_constant_input(second_add_node)
        if add_weight is None or add_weight <= 0 or add_weight > 1.0E-4:
            logger.warning(f"epsilon value is not expected: {add_weight}")
            return

        pow_node = parent_nodes[3]
        if self.model.find_constant_input(pow_node, 2.0) != 1:
            return

        mul_node = input_name_to_nodes[div_node.output[0]][0]
        if mul_node.op_type != 'Mul':
            return

        last_add_node = input_name_to_nodes[mul_node.output[0]][0]
        if last_add_node.op_type != 'Add':
            return

        subgraph_nodes = [node]
        subgraph_nodes.extend(children)
        subgraph_nodes.extend(parent_nodes[:-1])
        subgraph_nodes.extend([last_add_node, mul_node, div_node])

        if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, last_add_node.output, input_name_to_nodes,
                                                output_name_to_node):
            logger.debug("It is not safe to fuse LayerNormalization node. Skipping.")
            return

        weight_input = mul_node.input[1 - self.model.input_index(div_node.output[0], mul_node)]
        if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"):
            return

        bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
        if not self.model.is_constant_with_specified_dimension(bias_input, 1, "layernorm bias"):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        # Replace the whole subgraph with a single LayerNormalization node.
        normalize_node = helper.make_node('LayerNormalization',
                                          inputs=[node.input[0], weight_input, bias_input],
                                          outputs=[last_add_node.output[0]],
                                          name=self.model.create_node_name("LayerNormalization",
                                                                           name_prefix="LayerNorm"))
        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])
        self.nodes_to_add.append(normalize_node)
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name


class FusionLayerNormalizationTF(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "Add", "TF")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Layer Norm from a TensorFlow model (exported with keras2onnx or tf2onnx):
          +------------------------------------+
          |                                    |
          |                                    |
        (Cast_1)                               |
          |                                    |
          |                                    v                                         (B)                            (B)             (A)
         Add --> (Cast_1) --> ReduceMean --> Sub --> Mul --> ReduceMean --> (Cast_3) --> Add --> Sqrt --> Reciprocal --> Mul --> Mul --> Sub --> Add
          |                        |                                                                                      |       ^              ^
          |                        |                                                                                      |       |              |
          |                        +-----------------------------------------(Cast_2)------------------------------------|-------+              |
          |                                                                                                               v                      |
          +-------------------------------------------------------------------------------------------------------------> Mul ------------------+
        """
        return_indice = []
        # Two patterns are tried: with and without a Cast before the epsilon Add.
        _, parent_nodes, return_indice = self.model.match_parent_paths(
            node,
            [(['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'],
              [    1,     1,  None,            0,      0,     0,         None,     0,     0,         None]),
             (['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'Cast', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'],
              [    1,     1,  None,            0,      0,     0,      0,         None,     0,     0,         None])],
            output_name_to_node)  # yapf: disable
        if parent_nodes is None:
            return

        assert len(return_indice) == 3
        if not (return_indice[0] in [0, 1] and return_indice[1] in [0, 1] and return_indice[2] in [0, 1]):
            logger.debug(f"return indices are expected to be in [0, 1], but got {return_indice}")
            return

        sub_node_0, mul_node_0, mul_node_1, reciprocal_node, sqrt_node, add_node_0 = parent_nodes[:6]
        reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes[-4:]

        cast_node_3 = None
        if len(parent_nodes) == 11:
            cast_node_3 = parent_nodes[6]
            assert cast_node_3.op_type == 'Cast'

        mul_node_3 = self.model.match_parent(node, 'Mul', 0, output_name_to_node)
        if mul_node_3 is None:
            logger.debug("mul_node_3 not found")
            return

        node_before_reduce = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node)
        root_node = node_before_reduce if cast_node_3 is None else self.model.get_parent(
            node_before_reduce, 0, output_name_to_node)
        if root_node is None:
            logger.debug("root node is none")
            return

        _, epsilon = self.model.get_constant_input(add_node_0)
        if epsilon is None or epsilon <= 0 or (epsilon > 1.0E-5 and cast_node_3 is None):
            logger.debug("epsilon is not matched")
            return

        if cast_node_3 is None and (reduce_mean_node_1.input[0] not in mul_node_3.input
                                    or reduce_mean_node_1.input[0] not in sub_node_1.input):
            logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
            return

        if cast_node_3 is not None and (node_before_reduce.input[0] not in mul_node_3.input
                                        or reduce_mean_node_1.input[0] not in sub_node_1.input):
            logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
            return

        if mul_node_2.input[0] != mul_node_2.input[1]:
            logger.debug("mul_node_2 shall have two same inputs")
            return

        subgraph_nodes = [
            node, sub_node_0, mul_node_0, mul_node_1, reciprocal_node, sqrt_node, add_node_0, reduce_mean_node_0,
            mul_node_2, sub_node_1, reduce_mean_node_1, mul_node_3
        ]

        if cast_node_3 is not None:
            cast_node_2 = self.model.match_parent(mul_node_0, 'Cast', 0, output_name_to_node)
            if cast_node_2 is None:
                logger.debug("cast_node_2 not found")
                return
            subgraph_nodes.extend([node_before_reduce, cast_node_2, cast_node_3])

        if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, self.model.input_name_to_nodes(),
                                                self.model.output_name_to_node()):
            logger.debug("not safe to fuse layer normalization")
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        # mul_node_1 multiplies the reciprocal of the standard deviation by the scale (gamma);
        # sub_node_0 subtracts from the bias (beta).
        weight_input = mul_node_1.input[1]
        bias_input = sub_node_0.input[0]

        fused_node = helper.make_node('LayerNormalization',
                                      inputs=[mul_node_3.input[0], weight_input, bias_input],
                                      outputs=[node.output[0]],
                                      name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"))
        fused_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
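

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one way these
# fusions could be driven on a standalone ONNX model.  It assumes that the
# Fusion base class in fusion_base.py provides an apply() helper that walks
# the graph, calls fuse() for each matching node, and then commits
# nodes_to_remove / nodes_to_add back into the OnnxModel; the file paths are
# hypothetical.
if __name__ == "__main__":
    import onnx

    model_proto = onnx.load("model.onnx")  # hypothetical input path
    onnx_model = OnnxModel(model_proto)

    # Run both pattern matchers; each rewrites matched subgraphs in place.
    FusionLayerNormalization(onnx_model).apply()    # PyTorch-exported subgraphs
    FusionLayerNormalizationTF(onnx_model).apply()  # keras2onnx / tf2onnx subgraphs

    onnx.save(onnx_model.model, "model_fused.onnx")  # hypothetical output path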