Skip to content

Commit

Permalink
[ONNX] Remove ONNXGraph (#2173)
Browse files Browse the repository at this point in the history
### Changes

Remove ONNXGraph and move all of its methods into the module onnx_helper.py.
To optimize performance, four mappings are introduced for ONNX model
manipulation.

`get_name_to_node_map(model: onnx.ModelProto) -> Dict[str, onnx.NodeProto]`
Mapping from node name to the corresponding node. It is needed to avoid
iterating through all nodes of a model.
`get_edge_info_mapping(model: onnx.ModelProto) -> Dict[str,
onnx.ValueInfoProto]`
Mapping from edge name to the corresponding edge information. It is needed
to avoid iterating through all edge infos of a model.
`get_children_node_mapping(model: onnx.ModelProto) -> Dict[str,
List[onnx.NodeProto]]`
Mapping from edge name to the corresponding nodes that consume that edge as
an input. Used to traverse forward with optimal performance.
`get_parents_node_mapping(model: onnx.ModelProto) -> Dict[str,
onnx.NodeProto]`
Mapping from edge name to the node which outputs this edge. Used to traverse
backward with optimal performance.

Locally measured perf difference after removing ONNXGraph.

Model | PR time (sec) | develop time (sec) | SpeedUp
-- | -- | -- | --
swinv2_cr_tiny_224 | 91.434 | 105.73 | 15.64%
visformer_small | 57.265 | 59.097 | 3.2%
deit3_small_patch16_224 | 52.31 | 55.503 | 6.1%

### Reason for changes

Code refactor

### Related tickets

96982

### Tests

N/A
  • Loading branch information
kshpv authored Oct 10, 2023
1 parent bbb7e56 commit 48f8723
Show file tree
Hide file tree
Showing 12 changed files with 490 additions and 476 deletions.
29 changes: 19 additions & 10 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Type
from typing import Dict, List, Optional, Type

import onnx

from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.operator_metatypes import OperatorMetatypeRegistry
from nncf.common.hardware.opset import HWConfigOpName
from nncf.onnx.graph.onnx_graph import ONNXGraph
from nncf.onnx.graph.onnx_helper import get_parent
from nncf.onnx.graph.onnx_helper import get_parents_node_mapping
from nncf.onnx.graph.onnx_helper import get_tensor
from nncf.onnx.graph.onnx_helper import has_tensor

ONNX_OPERATION_METATYPES = OperatorMetatypeRegistry("onnx_operator_metatypes")

Expand Down Expand Up @@ -648,7 +651,12 @@ def get_metatype(model: onnx.ModelProto, node: onnx.NodeProto) -> ONNXOpMetatype
return metatype


def get_tensor_edge_name(onnx_graph: ONNXGraph, node: onnx.NodeProto, port_id: int) -> Optional[str]:
def get_tensor_edge_name(
model: onnx.ModelProto,
node: onnx.NodeProto,
port_id: int,
parents_node_mapping: Dict[str, onnx.NodeProto],
) -> Optional[str]:
"""
Returns an edge name associated with a weight of a node laying on an input port_id.
Expand All @@ -665,9 +673,10 @@ def get_tensor_edge_name(onnx_graph: ONNXGraph, node: onnx.NodeProto, port_id: i
ONNXTransposeMetatype
ONNXQuantizeLinearMetatype
:param onnx_graph: ONNXGraph.
:param model: ONNX model.
:param node: Node.
:param port_id: Port id on which a weight edge is seeking.
:param parents_node_mapping: Mapping from edge name to node which outputs this edge.
:return: Edge name associated with a weight.
"""
PROPAGATING_NODES = (
Expand All @@ -678,14 +687,14 @@ def get_tensor_edge_name(onnx_graph: ONNXGraph, node: onnx.NodeProto, port_id: i
+ ONNXDequantizeLinearMetatype.get_all_aliases()
)
END_NODES = ONNXConstantMetatype.get_all_aliases()
parent = onnx_graph.get_parent(node, port_id)
parent = get_parent(node, port_id, parents_node_mapping)
if not parent:
if onnx_graph.has_tensor(node.input[port_id]):
if has_tensor(model, node.input[port_id]):
return node.input[port_id]
elif parent.op_type in END_NODES:
return node.input[port_id]
elif parent.op_type in PROPAGATING_NODES:
return get_tensor_edge_name(onnx_graph, parent, 0)
return get_tensor_edge_name(model, parent, 0, parents_node_mapping)
return None


Expand Down Expand Up @@ -734,12 +743,12 @@ def _is_embedding(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
:return: True if the layer is embedding, False - otherwise.
"""
tensor_port_id = ONNXEmbeddingMetatype.weight_port_ids[0]
onnx_graph = ONNXGraph(model)
allowed_types_list = ["TensorProto.FLOAT"]
weight_edge_name = get_tensor_edge_name(onnx_graph, node, tensor_port_id)
parents_node_mapping = get_parents_node_mapping(model)
weight_edge_name = get_tensor_edge_name(model, node, tensor_port_id, parents_node_mapping)

if weight_edge_name is not None:
tensor_data_type = onnx_graph.get_tensor(weight_edge_name).data_type
tensor_data_type = get_tensor(model, weight_edge_name).data_type
if onnx.helper.tensor_dtype_to_string(tensor_data_type) in allowed_types_list:
return True
return False
82 changes: 47 additions & 35 deletions nncf/onnx/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.onnx.graph.node_utils import get_input_edge
from nncf.onnx.graph.onnx_graph import ONNXGraph
from nncf.onnx.graph.onnx_helper import get_children
from nncf.onnx.graph.onnx_helper import get_children_node_mapping
from nncf.onnx.graph.onnx_helper import get_edge_dtype
from nncf.onnx.graph.onnx_helper import get_edge_info_mapping
from nncf.onnx.graph.onnx_helper import get_name_to_node_map
from nncf.onnx.graph.onnx_helper import get_node_index
from nncf.onnx.graph.onnx_helper import get_tensor
from nncf.onnx.graph.transformations.commands import ONNXBiasCorrectionCommand
from nncf.onnx.graph.transformations.commands import ONNXModelExtractionCommand
from nncf.onnx.graph.transformations.commands import ONNXOutputInsertionCommand
Expand All @@ -40,15 +46,16 @@ class ONNXModelTransformer(ModelTransformer):
ZERO_POINT_NAME_PREFIX = "zero_point_"

def __init__(self, model: onnx.ModelProto):
super().__init__(model)
infered_model = onnx.shape_inference.infer_shapes(model)
super().__init__(infered_model)
self.onnx_model_extractor = onnx.utils.Extractor(self._model)

def _get_target_edge(
self,
port_id: int,
node_name: str,
transform_type: TargetType,
onnx_graph: ONNXGraph,
node_mapping: Dict[str, onnx.NodeProto],
input_edges_mapping: Dict[str, str],
) -> str:
"""
Expand All @@ -57,16 +64,16 @@ def _get_target_edge(
:param port_id: Edge number of port.
:param node_name: Node name.
:param transform_type: Type of transformation.
:param onnx_graph: ONNXGraph.
:param node_mapping: Mapping from a node name to the node.
:param input_edges_mapping: Mapping between NNCF Input nodes and
the following ONNX nodes and corresponding input port id.
the following ONNX nodes and corresponding input port id.
:return: Target edge name.
"""
if transform_type in [TargetType.PRE_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]:
return onnx_graph.get_node_edge_names(node_name)["input"][port_id]
return node_mapping[node_name].input[port_id]
if node_name in input_edges_mapping: # ADD INPUT NODE CASE
return get_input_edge(node_name, input_edges_mapping, onnx_graph)
return onnx_graph.get_node_edge_names(node_name)["output"][port_id]
return get_input_edge(node_name, input_edges_mapping, node_mapping)
return node_mapping[node_name].output[port_id]

def transform(self, transformation_layout: TransformationLayout) -> onnx.ModelProto:
"""
Expand Down Expand Up @@ -123,15 +130,15 @@ def _apply_output_insertion_transformations(
:param transformations: ONNXOutputInsertionCommand transformations.
:return: New model with inserted outputs.
"""
onnx_graph = ONNXGraph(self._model)
model_outputs = set(output.name for output in onnx_graph.get_model_outputs())
model_outputs = set(output.name for output in self._model.graph.output)
node_mapping = get_name_to_node_map(self._model)
for transformation in transformations:
port_id = transformation.target_point.port_id
node_name = transformation.target_point.target_node_name
transform_type = transformation.target_point.type
input_edges_mapping = transformation.input_edges_mapping
target_edge_name = self._get_target_edge(
port_id, node_name, transform_type, onnx_graph, input_edges_mapping
port_id, node_name, transform_type, node_mapping, input_edges_mapping
)
model_outputs.add(target_edge_name)

Expand All @@ -146,11 +153,11 @@ def _insert_outputs(model: onnx.ModelProto, outputs: Union[List[str], Set[str]])
:param outputs: Edge names to use as outputs.
:return: New model with inserted outputs.
"""
onnx_graph = ONNXGraph(model)
model_outputs = []
edge_info_mapping = get_edge_info_mapping(model)
for output in outputs:
edge = onnx_graph.get_edge(output)
onnx_dtype = ONNXGraph.get_edge_dtype(edge)
edge = edge_info_mapping[output]
onnx_dtype = get_edge_dtype(edge)
type_proto = onnx.helper.make_tensor_type_proto(onnx_dtype, shape=None)
model_outputs.append(onnx.helper.make_value_info(name=output, type_proto=type_proto))

Expand Down Expand Up @@ -193,7 +200,8 @@ def _apply_quantizer_insertion_transformations(
"""
self._added_target_edges = Counter()
for transformation in transformations:
model = self._insert_quantizer_dequantizer(model, transformation)
children_node_mapping = get_children_node_mapping(model)
model = self._insert_quantizer_dequantizer(model, transformation, children_node_mapping)
return model

def _get_quantize_dequantize_nodes(
Expand Down Expand Up @@ -274,51 +282,55 @@ def _get_scale_zero_point_tensors(
return onnx_scale_tensor, onnx_zero_point_tensor

def _get_quantizer_dequantizer_edge_name(
self, transformation: ONNXQuantizerInsertionCommand, onnx_graph: ONNXGraph
self, transformation: ONNXQuantizerInsertionCommand, node_mapping: Dict[str, onnx.NodeProto]
) -> str:
"""
Returns an edge name on which QuantizeLinear-DequantizeLinear nodes pair has to be inserted.
:param transformation: QuantizeLinear-DequantizeLinear insertion transformation.
:param onnx_graph: ONNXGraph.
:param node_mapping: Mapping from a node name to the node.
:return: Edge name to insert QuantizeLinear-DequantizeLinear nodes pair.
"""
port_id = transformation.target_point.port_id
node_name = transformation.target_point.target_node_name
transform_type = transformation.target_point.type
input_edges_mapping = transformation.input_edges_mapping
target_edge_name = self._get_target_edge(port_id, node_name, transform_type, onnx_graph, input_edges_mapping)
target_edge_name = self._get_target_edge(port_id, node_name, transform_type, node_mapping, input_edges_mapping)
self._added_target_edges[target_edge_name] += 1
return target_edge_name

def _insert_quantizer_dequantizer(
self, model: onnx.ModelProto, transformation: ONNXQuantizerInsertionCommand
self,
model: onnx.ModelProto,
transformation: ONNXQuantizerInsertionCommand,
children_node_mapping: Dict[str, List[onnx.ValueInfoProto]],
) -> onnx.ModelProto:
"""
Inserts QuantizeLinear-DequantizeLinear nodes pair.
:param model: Model to insert new nodes.
:param transformation: QuantizeLinear-DequantizeLinear insertion transformation.
:param children_node_mapping: Mapping from edge name to nodes which consume this edge as an input.
:return: Updated model with inserted QuantizeLinear-DequantizeLinear pair.
"""
onnx_graph = ONNXGraph(model)
target_edge_name = self._get_quantizer_dequantizer_edge_name(transformation, onnx_graph)
node_mapping = get_name_to_node_map(model)
target_edge_name = self._get_quantizer_dequantizer_edge_name(transformation, node_mapping)
quantizer, dequantizer = self._get_quantize_dequantize_nodes(transformation, target_edge_name)
onnx_scale_tensor, onnx_zero_point_tensor = ONNXModelTransformer._get_scale_zero_point_tensors(
transformation, quantizer, dequantizer
)

# If several nodes on one edge
input_nodes = []
input_nodes.extend(onnx_graph.get_nodes_by_input(target_edge_name))
input_nodes.extend(children_node_mapping[target_edge_name])
if not input_nodes:
raise RuntimeError(
f"Can not add the quantizer to the {target_edge_name} edge. This edge does not have end node."
)

if transformation.target_point.type == TargetType.PRE_LAYER_OPERATION:
# If we need to change only target nodes input
target_node = onnx_graph.get_node_by_name(transformation.target_point.target_node_name)
target_node = node_mapping[transformation.target_point.target_node_name]
for i, inp in enumerate(target_node.input):
if inp == target_edge_name:
target_node.input[i] = dequantizer.output[0]
Expand All @@ -336,7 +348,7 @@ def _insert_quantizer_dequantizer(
)
model.graph.initializer.extend([onnx_scale_tensor, onnx_zero_point_tensor])
model.graph.value_info.extend([onnx_scale_value_info, onnx_zero_point_info])
insert_index = onnx_graph.get_node_index(input_nodes[0].name)
insert_index = get_node_index(model, input_nodes[0].name)
model.graph.node.insert(insert_index, quantizer)
model.graph.node.insert(insert_index + 1, dequantizer)
return model
Expand All @@ -351,13 +363,13 @@ def _apply_bias_correction_transformations(
:param transformations: Bias correction transformations.
:return: Copy of original model with updated biases.
"""
onnx_graph = ONNXGraph(model)
node_mapping = get_name_to_node_map(model)
for transformation in transformations:
bias_tensor_position = transformation.target_point.port_id
node_name = transformation.target_point.target_node_name
onnx_node = onnx_graph.get_node_by_name(node_name)
onnx_node = node_mapping[node_name]
bias_initializer_name = onnx_node.input[bias_tensor_position]
bias_initializer = onnx_graph.get_tensor(bias_initializer_name)
bias_initializer = get_tensor(model, bias_initializer_name)

new_bias_tensor = onnx.numpy_helper.from_array(transformation.bias_value, bias_initializer_name)
bias_initializer.CopyFrom(new_bias_tensor)
Expand All @@ -370,20 +382,19 @@ def _apply_model_extraction_transformation(self, transformation: ONNXModelExtrac
:param transformation: Model extraction transformation.
:return: Extracted sub-model.
"""
onnx_graph = ONNXGraph(self._model)

input_tensor_names = []
node_mapping = get_name_to_node_map(self._model)
for input_node_name in transformation.inputs:
input_onnx_node = onnx_graph.get_node_by_name(input_node_name)
input_onnx_node = node_mapping[input_node_name]
input_tensor_names.append(input_onnx_node.input[0])

output_tensor_names = []
for output_node_name in transformation.outputs:
output_onnx_node = onnx_graph.get_node_by_name(output_node_name)
output_onnx_node = node_mapping[output_node_name]
output_tensor_names.append(output_onnx_node.output[0])

if not output_tensor_names:
output_tensor_names = [n.name for n in onnx_graph.get_model_outputs()]
output_tensor_names = [n.name for n in self._model.graph.output]

return self.onnx_model_extractor.extract_model(input_tensor_names, output_tensor_names)

Expand All @@ -397,11 +408,12 @@ def _apply_qdq_node_removing_transformations(
:param transformations: Nodes removing transformations.
:return: Model with removed nodes.
"""
onnx_graph = ONNXGraph(model)
for transformation in transformations:
node = onnx_graph.get_node_by_name(transformation.target_point.target_node_name)
node_mapping = get_name_to_node_map(model)
children_node_mapping = get_children_node_mapping(model)
node = node_mapping[transformation.target_point.target_node_name]

node_children = onnx_graph.get_children(node)
node_children = get_children(node, children_node_mapping)
for node_child in node_children:
for input_id, input_obj in enumerate(node_child.input):
if input_obj == node.output[0]:
Expand Down
Loading

0 comments on commit 48f8723

Please sign in to comment.