Support fusion for TNLR based model #10432

Merged (2 commits) on Feb 4, 2022
179 changes: 179 additions & 0 deletions onnxruntime/python/tools/transformers/onnx_model_tnlr.py
@@ -0,0 +1,179 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------
import logging
import numpy as np
from fusion_attention import FusionAttention, AttentionMask
from fusion_utils import NumpyHelper
from onnx import helper, numpy_helper, TensorProto, NodeProto
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel
from typing import Union

logger = logging.getLogger(__name__)


class FusionTnlrAttention(FusionAttention):
"""
Fuse TNLR Attention subgraph into one Attention node.
"""
def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, attention_mask: AttentionMask):
super().__init__(model, hidden_size, num_heads, attention_mask)

def create_attention_node(self, mask_index: str, matmul: NodeProto, add: NodeProto, num_heads: int,
hidden_size: int, input: str, output: str, add_qk_str: str) -> Union[NodeProto, None]:

assert num_heads > 0
if hidden_size > 0 and (hidden_size % num_heads) != 0:
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
return None

weight = self.model.get_initializer(matmul.input[1])
bias = self.model.get_initializer(add.input[1]) or self.model.get_initializer(add.input[0])

if weight is None or bias is None:
return None

qkv_weight = NumpyHelper.to_array(weight)
qkv_bias = NumpyHelper.to_array(bias)

attention_node_name = self.model.create_node_name('Attention')

weight = helper.make_tensor(name=attention_node_name + '_qkv_weight',
data_type=TensorProto.FLOAT,
dims=[hidden_size, 3 * hidden_size],
vals=qkv_weight.flatten().tolist())

# Sometimes weights and bias are stored in fp16 (data_type 10 is TensorProto.FLOAT16)
if weight.data_type == 10:
weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
self.model.add_initializer(weight, self.this_graph_name)

bias = helper.make_tensor(name=attention_node_name + '_qkv_bias',
data_type=TensorProto.FLOAT,
dims=[3 * hidden_size],
vals=qkv_bias.flatten().tolist())
if bias.data_type == 10:
bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
self.model.add_initializer(bias, self.this_graph_name)

# Attention inputs: (input, weight, bias, mask_index, past, add_qk)
attention_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias']
if mask_index is not None:
attention_inputs.append(mask_index)
else:
attention_inputs.append("")

if add_qk_str is not None:
# Leave the optional past input empty so that add_qk_str binds to the add_qk slot
attention_inputs.append("")
attention_inputs.append(add_qk_str)

attention_node = helper.make_node('Attention',
inputs=attention_inputs,
outputs=[output],
name=attention_node_name)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

return attention_node

def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# Sometimes we cannot fuse SkipLayerNormalization because the Add before the LayerNorm has an output that is used by nodes outside the SkipLayerNorm.
# Conceptually we treat that Add as a SkipLayerNorm node since they share the same pattern.
start_node = normalize_node
if normalize_node.op_type != 'SkipLayerNormalization':
return

# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
qkv_nodes = self.model.match_parent_path(start_node,
['Where', 'Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'],
[1, 1, 1, 0, 0, 0])
if qkv_nodes is not None:
(_, _, matmul_below, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
else:
return

other_inputs = []
for i, input in enumerate(start_node.input):
if input not in output_name_to_node:
continue

if input == qkv_nodes[0].output[0]:
continue
other_inputs.append(input)
if len(other_inputs) != 1:
return

root_input = other_inputs[0]

v_nodes = self.model.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Slice', 'Add', 'MatMul'],
[1, 0, 0, 0, 1])
if v_nodes is None:
return
(_, _, _, add, matmul) = v_nodes

upper_nodes = self.model.match_parent_path(matmul, ['Transpose'], [0])
transpose = upper_nodes[0]

qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Add', 'MatMul'], [0, 0, 0])
if qk_nodes is None:
return
(_, add_qk, matmul_qk) = qk_nodes

q_nodes = self.model.match_parent_path(matmul_qk, ['Mul', 'Transpose', 'Reshape', 'Slice', 'Add', 'MatMul'],
[0, 0, 0, 0, 0, 1])
if q_nodes is None:
return
add = q_nodes[-2]
matmul = q_nodes[-1]

k_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Slice', 'Add', 'MatMul'],
[1, 0, 0, 0, 1])
if k_nodes is None:
return
add = k_nodes[-2]
matmul = k_nodes[-1]

extra_add_qk_nodes = self.model.match_parent_path(add_qk, ['Reshape', 'Where'], [1, 0])
if extra_add_qk_nodes is None:
return

if matmul.input[0] == root_input:
mask_index = None
attention_last_node = reshape_qkv
# The number of heads is the same for all paths, so self.num_heads is passed when creating the attention node.
# self.hidden_size is the input hidden size; the hidden sizes for Q and K are extracted from the graph as needed.
new_node = self.create_attention_node(mask_index, matmul, add, self.num_heads, self.hidden_size, root_input,
attention_last_node.output[0], extra_add_qk_nodes[0].input[0])
if new_node is None:
return

self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name

# Add a transpose node after the attention in Offensive V4
back_transpose = helper.make_node("Transpose", ["back_transpose_in_" + new_node.name], [new_node.output[0]],
"back_transpose_" + new_node.name,
perm=[1, 0, 2])
self.model.add_node(back_transpose, self.this_graph_name)
new_node.input[0] = transpose.input[0]
new_node.output[0] = "back_transpose_in_" + new_node.name

self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
self.nodes_to_remove.extend(qk_nodes)
self.nodes_to_remove.extend(q_nodes)
self.nodes_to_remove.extend(k_nodes)
self.nodes_to_remove.extend(v_nodes)

# Use prune graph to remove mask nodes since they are shared by all attention nodes.
#self.nodes_to_remove.extend(mask_nodes)
self.prune_graph = True


class TnlrOnnxModel(BertOnnxModel):
def __init__(self, model, num_heads, hidden_size):
super().__init__(model, num_heads, hidden_size)
self.attention_mask = AttentionMask(self)
self.attention_fusion = FusionTnlrAttention(self, self.hidden_size, self.num_heads, self.attention_mask)

def fuse_attention(self):
self.attention_fusion.apply()
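
A minimal usage sketch (not part of this PR) of how the new class might be driven directly; the model path and the num_heads/hidden_size values are placeholders:

import onnx
from onnx_model_tnlr import TnlrOnnxModel

# Load an exported TNLR ONNX model (path and shape parameters are placeholders).
model = onnx.load("tnlr.onnx")
tnlr_model = TnlrOnnxModel(model, num_heads=16, hidden_size=1024)

# Run the attention fusion added in this PR, then save the fused graph.
tnlr_model.fuse_attention()
tnlr_model.save_model_to_file("tnlr_fused.onnx")
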
18 changes: 10 additions & 8 deletions onnxruntime/python/tools/transformers/optimizer.py
@@ -28,6 +28,7 @@
from onnx_model_bert_tf import BertOnnxModelTF
from onnx_model_bert_keras import BertOnnxModelKeras
from onnx_model_gpt2 import Gpt2OnnxModel
from onnx_model_tnlr import TnlrOnnxModel
from fusion_options import FusionOptions

logger = logging.getLogger(__name__)
@@ -39,7 +40,8 @@
"bert_tf": (BertOnnxModelTF, "tf2onnx", 0),
"bert_keras": (BertOnnxModelKeras, "keras2onnx", 0),
"gpt2": (Gpt2OnnxModel, "pytorch", 1),
"gpt2_tf": (Gpt2OnnxModel, 'tf2onnx', 0), # might add a class for GPT2OnnxModel for TF later.
"tnlr": (TnlrOnnxModel, "pytorch", 1),
}


@@ -115,9 +117,9 @@ def optimize_by_fusion(model: ModelProto,
model (ModelProto): model object
model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'.
num_heads (int, optional): number of attention heads. Defaults to 0.
0 allows the parameter to be detected from the graph automatically (for model_type "bert" only).
hidden_size (int, optional): hidden size. Defaults to 0.
0 allows the parameter to be detected from the graph automatically (for model_type "bert" only).
optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. Defaults to None.

Returns:
@@ -159,7 +161,7 @@ def optimize_model(input: str,
only_onnxruntime: bool = False):
""" Optimize Model by OnnxRuntime and/or python fusion logic.

ONNX Runtime has graph optimizations (https://onnxruntime.ai/docs/resources/graph-optimizations.html).
However, the coverage is limited. We also have graph fusions implemented in Python to improve the coverage.
They can be combined: ONNX Runtime will run first when opt_level > 0, then graph fusions in Python will be applied.

@@ -170,8 +172,8 @@

When opt_level is 0 and only_onnxruntime is False, only python fusion logic is used and onnxruntime is disabled.

When opt_level > 1, use_gpu shall be set properly since the optimized graph might contain operators for GPU or CPU only.
If your model is intended for GPU inference only (especially a float16 or mixed-precision model), it is recommended to
set use_gpu to True; otherwise the model is not optimized for GPU inference.

For BERT models, num_heads and hidden_size are optional. For other model types, you need to specify these parameters.
@@ -180,9 +182,9 @@
input (str): input model path.
model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'.
num_heads (int, optional): number of attention heads. Defaults to 0.
0 allows the parameter to be detected from the graph automatically (for model_type "bert" only).
hidden_size (int, optional): hidden size. Defaults to 0.
0 allows the parameter to be detected from the graph automatically (for model_type "bert" only).
optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. Defaults to None.
opt_level (int, optional): onnxruntime graph optimization level (0, 1, 2 or 99) or None. Defaults to None.
When the value is None, the default value (1 for bert and gpt2, 0 for other model types) will be used.
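
With the "tnlr" entry registered above, a TNLR model could then be optimized end to end roughly as follows (a sketch, not from the PR; the file names and the num_heads/hidden_size values are placeholders, and they must be given explicitly since automatic detection only applies to model_type "bert"):

from optimizer import optimize_model

# model_type="tnlr" selects TnlrOnnxModel; num_heads/hidden_size are example values.
opt_model = optimize_model("tnlr.onnx",
                           model_type="tnlr",
                           num_heads=16,
                           hidden_size=1024,
                           opt_level=1)
opt_model.save_model_to_file("tnlr_optimized.onnx")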