diff --git a/nncf/onnx/graph/metatypes/onnx_metatypes.py b/nncf/onnx/graph/metatypes/onnx_metatypes.py index 35d532caeac..2644658f8ee 100644 --- a/nncf/onnx/graph/metatypes/onnx_metatypes.py +++ b/nncf/onnx/graph/metatypes/onnx_metatypes.py @@ -9,14 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Type +from typing import Dict, List, Optional, Type import onnx from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.operator_metatypes import OperatorMetatypeRegistry from nncf.common.hardware.opset import HWConfigOpName -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_parent +from nncf.onnx.graph.onnx_helper import get_parents_node_mapping +from nncf.onnx.graph.onnx_helper import get_tensor +from nncf.onnx.graph.onnx_helper import has_tensor ONNX_OPERATION_METATYPES = OperatorMetatypeRegistry("onnx_operator_metatypes") @@ -648,7 +651,12 @@ def get_metatype(model: onnx.ModelProto, node: onnx.NodeProto) -> ONNXOpMetatype return metatype -def get_tensor_edge_name(onnx_graph: ONNXGraph, node: onnx.NodeProto, port_id: int) -> Optional[str]: +def get_tensor_edge_name( + model: onnx.ModelProto, + node: onnx.NodeProto, + port_id: int, + parents_node_mapping: Dict[str, onnx.NodeProto], +) -> Optional[str]: """ Returns an edge name associated with a weight of a node laying on an input port_id. @@ -665,9 +673,10 @@ def get_tensor_edge_name(onnx_graph: ONNXGraph, node: onnx.NodeProto, port_id: i ONNXTransposeMetatype ONNXQuantizeLinearMetatype - :param onnx_graph: ONNXGraph. + :param model: ONNX model. :param node: Node. :param port_id: Port id on which a weight edge is seeking. + :param parents_node_mapping: Mapping from edge name to node which outputs this edge. :return: Edge name associated with a weight. """ PROPAGATING_NODES = ( @@ -678,14 +687,14 @@ def get_tensor_edge_name(onnx_graph: ONNXGraph, node: onnx.NodeProto, port_id: i + ONNXDequantizeLinearMetatype.get_all_aliases() ) END_NODES = ONNXConstantMetatype.get_all_aliases() - parent = onnx_graph.get_parent(node, port_id) + parent = get_parent(node, port_id, parents_node_mapping) if not parent: - if onnx_graph.has_tensor(node.input[port_id]): + if has_tensor(model, node.input[port_id]): return node.input[port_id] elif parent.op_type in END_NODES: return node.input[port_id] elif parent.op_type in PROPAGATING_NODES: - return get_tensor_edge_name(onnx_graph, parent, 0) + return get_tensor_edge_name(model, parent, 0, parents_node_mapping) return None @@ -734,12 +743,12 @@ def _is_embedding(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: :return: True if the layer is embedding, False - otherwise. """ tensor_port_id = ONNXEmbeddingMetatype.weight_port_ids[0] - onnx_graph = ONNXGraph(model) allowed_types_list = ["TensorProto.FLOAT"] - weight_edge_name = get_tensor_edge_name(onnx_graph, node, tensor_port_id) + parents_node_mapping = get_parents_node_mapping(model) + weight_edge_name = get_tensor_edge_name(model, node, tensor_port_id, parents_node_mapping) if weight_edge_name is not None: - tensor_data_type = onnx_graph.get_tensor(weight_edge_name).data_type + tensor_data_type = get_tensor(model, weight_edge_name).data_type if onnx.helper.tensor_dtype_to_string(tensor_data_type) in allowed_types_list: return True return False diff --git a/nncf/onnx/graph/model_transformer.py b/nncf/onnx/graph/model_transformer.py index a8f5355babf..b6db3d36b0d 100644 --- a/nncf/onnx/graph/model_transformer.py +++ b/nncf/onnx/graph/model_transformer.py @@ -19,7 +19,13 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout from nncf.onnx.graph.node_utils import get_input_edge -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_children +from nncf.onnx.graph.onnx_helper import get_children_node_mapping +from nncf.onnx.graph.onnx_helper import get_edge_dtype +from nncf.onnx.graph.onnx_helper import get_edge_info_mapping +from nncf.onnx.graph.onnx_helper import get_name_to_node_map +from nncf.onnx.graph.onnx_helper import get_node_index +from nncf.onnx.graph.onnx_helper import get_tensor from nncf.onnx.graph.transformations.commands import ONNXBiasCorrectionCommand from nncf.onnx.graph.transformations.commands import ONNXModelExtractionCommand from nncf.onnx.graph.transformations.commands import ONNXOutputInsertionCommand @@ -40,7 +46,8 @@ class ONNXModelTransformer(ModelTransformer): ZERO_POINT_NAME_PREFIX = "zero_point_" def __init__(self, model: onnx.ModelProto): - super().__init__(model) + infered_model = onnx.shape_inference.infer_shapes(model) + super().__init__(infered_model) self.onnx_model_extractor = onnx.utils.Extractor(self._model) def _get_target_edge( @@ -48,7 +55,7 @@ def _get_target_edge( port_id: int, node_name: str, transform_type: TargetType, - onnx_graph: ONNXGraph, + node_mapping: Dict[str, onnx.NodeProto], input_edges_mapping: Dict[str, str], ) -> str: """ @@ -57,16 +64,16 @@ def _get_target_edge( :param port_id: Edge number of port. :param node_name: Node name. :param transform_type: Type of transformation. - :param onnx_graph: ONNXGraph. + :param node_mapping: Mapping from a node name to the node. :param input_edges_mapping: Mapping between NNCF Input nodes and - the following ONNX nodes and corresponding input port id. + the following ONNX nodes and corresponding input port id. :return: Target edge name. """ if transform_type in [TargetType.PRE_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]: - return onnx_graph.get_node_edge_names(node_name)["input"][port_id] + return node_mapping[node_name].input[port_id] if node_name in input_edges_mapping: # ADD INPUT NODE CASE - return get_input_edge(node_name, input_edges_mapping, onnx_graph) - return onnx_graph.get_node_edge_names(node_name)["output"][port_id] + return get_input_edge(node_name, input_edges_mapping, node_mapping) + return node_mapping[node_name].output[port_id] def transform(self, transformation_layout: TransformationLayout) -> onnx.ModelProto: """ @@ -123,15 +130,15 @@ def _apply_output_insertion_transformations( :param transformations: ONNXOutputInsertionCommand transformations. :return: New model with inserted outputs. """ - onnx_graph = ONNXGraph(self._model) - model_outputs = set(output.name for output in onnx_graph.get_model_outputs()) + model_outputs = set(output.name for output in self._model.graph.output) + node_mapping = get_name_to_node_map(self._model) for transformation in transformations: port_id = transformation.target_point.port_id node_name = transformation.target_point.target_node_name transform_type = transformation.target_point.type input_edges_mapping = transformation.input_edges_mapping target_edge_name = self._get_target_edge( - port_id, node_name, transform_type, onnx_graph, input_edges_mapping + port_id, node_name, transform_type, node_mapping, input_edges_mapping ) model_outputs.add(target_edge_name) @@ -146,11 +153,11 @@ def _insert_outputs(model: onnx.ModelProto, outputs: Union[List[str], Set[str]]) :param outputs: Edge names to use as outputs. :return: New model with inserted outputs. """ - onnx_graph = ONNXGraph(model) model_outputs = [] + edge_info_mapping = get_edge_info_mapping(model) for output in outputs: - edge = onnx_graph.get_edge(output) - onnx_dtype = ONNXGraph.get_edge_dtype(edge) + edge = edge_info_mapping[output] + onnx_dtype = get_edge_dtype(edge) type_proto = onnx.helper.make_tensor_type_proto(onnx_dtype, shape=None) model_outputs.append(onnx.helper.make_value_info(name=output, type_proto=type_proto)) @@ -193,7 +200,8 @@ def _apply_quantizer_insertion_transformations( """ self._added_target_edges = Counter() for transformation in transformations: - model = self._insert_quantizer_dequantizer(model, transformation) + children_node_mapping = get_children_node_mapping(model) + model = self._insert_quantizer_dequantizer(model, transformation, children_node_mapping) return model def _get_quantize_dequantize_nodes( @@ -274,35 +282,39 @@ def _get_scale_zero_point_tensors( return onnx_scale_tensor, onnx_zero_point_tensor def _get_quantizer_dequantizer_edge_name( - self, transformation: ONNXQuantizerInsertionCommand, onnx_graph: ONNXGraph + self, transformation: ONNXQuantizerInsertionCommand, node_mapping: Dict[str, onnx.NodeProto] ) -> str: """ Returns an edge name on which QuantizeLinear-DequantizeLinear nodes pair has to be inserted. :param transformation: QuantizeLinear-DequantizeLinear insertion transformation. - :param onnx_graph: ONNXGraph. + :param node_mapping: Mapping from a node name to the node. :return: Edge name to insert QuantizeLinear-DequantizeLinear nodes pair. """ port_id = transformation.target_point.port_id node_name = transformation.target_point.target_node_name transform_type = transformation.target_point.type input_edges_mapping = transformation.input_edges_mapping - target_edge_name = self._get_target_edge(port_id, node_name, transform_type, onnx_graph, input_edges_mapping) + target_edge_name = self._get_target_edge(port_id, node_name, transform_type, node_mapping, input_edges_mapping) self._added_target_edges[target_edge_name] += 1 return target_edge_name def _insert_quantizer_dequantizer( - self, model: onnx.ModelProto, transformation: ONNXQuantizerInsertionCommand + self, + model: onnx.ModelProto, + transformation: ONNXQuantizerInsertionCommand, + children_node_mapping: Dict[str, List[onnx.ValueInfoProto]], ) -> onnx.ModelProto: """ Inserts QuantizeLinear-DequantizeLinear nodes pair. :param model: Model to insert new nodes. :param transformation: QuantizeLinear-DequantizeLinear insertion transformation. + :param children_node_mapping: Mapping from edge name to nodes which consume this edge as an input. :return: Updated model with inserted QuantizeLinear-DequantizeLinear pair. """ - onnx_graph = ONNXGraph(model) - target_edge_name = self._get_quantizer_dequantizer_edge_name(transformation, onnx_graph) + node_mapping = get_name_to_node_map(model) + target_edge_name = self._get_quantizer_dequantizer_edge_name(transformation, node_mapping) quantizer, dequantizer = self._get_quantize_dequantize_nodes(transformation, target_edge_name) onnx_scale_tensor, onnx_zero_point_tensor = ONNXModelTransformer._get_scale_zero_point_tensors( transformation, quantizer, dequantizer @@ -310,7 +322,7 @@ def _insert_quantizer_dequantizer( # If several nodes on one edge input_nodes = [] - input_nodes.extend(onnx_graph.get_nodes_by_input(target_edge_name)) + input_nodes.extend(children_node_mapping[target_edge_name]) if not input_nodes: raise RuntimeError( f"Can not add the quantizer to the {target_edge_name} edge. This edge does not have end node." @@ -318,7 +330,7 @@ def _insert_quantizer_dequantizer( if transformation.target_point.type == TargetType.PRE_LAYER_OPERATION: # If we need to change only target nodes input - target_node = onnx_graph.get_node_by_name(transformation.target_point.target_node_name) + target_node = node_mapping[transformation.target_point.target_node_name] for i, inp in enumerate(target_node.input): if inp == target_edge_name: target_node.input[i] = dequantizer.output[0] @@ -336,7 +348,7 @@ def _insert_quantizer_dequantizer( ) model.graph.initializer.extend([onnx_scale_tensor, onnx_zero_point_tensor]) model.graph.value_info.extend([onnx_scale_value_info, onnx_zero_point_info]) - insert_index = onnx_graph.get_node_index(input_nodes[0].name) + insert_index = get_node_index(model, input_nodes[0].name) model.graph.node.insert(insert_index, quantizer) model.graph.node.insert(insert_index + 1, dequantizer) return model @@ -351,13 +363,13 @@ def _apply_bias_correction_transformations( :param transformations: Bias correction transformations. :return: Copy of original model with updated biases. """ - onnx_graph = ONNXGraph(model) + node_mapping = get_name_to_node_map(model) for transformation in transformations: bias_tensor_position = transformation.target_point.port_id node_name = transformation.target_point.target_node_name - onnx_node = onnx_graph.get_node_by_name(node_name) + onnx_node = node_mapping[node_name] bias_initializer_name = onnx_node.input[bias_tensor_position] - bias_initializer = onnx_graph.get_tensor(bias_initializer_name) + bias_initializer = get_tensor(model, bias_initializer_name) new_bias_tensor = onnx.numpy_helper.from_array(transformation.bias_value, bias_initializer_name) bias_initializer.CopyFrom(new_bias_tensor) @@ -370,20 +382,19 @@ def _apply_model_extraction_transformation(self, transformation: ONNXModelExtrac :param transformation: Model extraction transformation. :return: Extracted sub-model. """ - onnx_graph = ONNXGraph(self._model) - input_tensor_names = [] + node_mapping = get_name_to_node_map(self._model) for input_node_name in transformation.inputs: - input_onnx_node = onnx_graph.get_node_by_name(input_node_name) + input_onnx_node = node_mapping[input_node_name] input_tensor_names.append(input_onnx_node.input[0]) output_tensor_names = [] for output_node_name in transformation.outputs: - output_onnx_node = onnx_graph.get_node_by_name(output_node_name) + output_onnx_node = node_mapping[output_node_name] output_tensor_names.append(output_onnx_node.output[0]) if not output_tensor_names: - output_tensor_names = [n.name for n in onnx_graph.get_model_outputs()] + output_tensor_names = [n.name for n in self._model.graph.output] return self.onnx_model_extractor.extract_model(input_tensor_names, output_tensor_names) @@ -397,11 +408,12 @@ def _apply_qdq_node_removing_transformations( :param transformations: Nodes removing transformations. :return: Model with removed nodes. """ - onnx_graph = ONNXGraph(model) for transformation in transformations: - node = onnx_graph.get_node_by_name(transformation.target_point.target_node_name) + node_mapping = get_name_to_node_map(model) + children_node_mapping = get_children_node_mapping(model) + node = node_mapping[transformation.target_point.target_node_name] - node_children = onnx_graph.get_children(node) + node_children = get_children(node, children_node_mapping) for node_child in node_children: for input_id, input_obj in enumerate(node_child.input): if input_obj == node.output[0]: diff --git a/nncf/onnx/graph/nncf_graph_builder.py b/nncf/onnx/graph/nncf_graph_builder.py index b728b4020a8..148de756c30 100644 --- a/nncf/onnx/graph/nncf_graph_builder.py +++ b/nncf/onnx/graph/nncf_graph_builder.py @@ -29,7 +29,16 @@ from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype from nncf.onnx.graph.metatypes.onnx_metatypes import get_metatype from nncf.onnx.graph.metatypes.onnx_metatypes import get_tensor_edge_name -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_children_node_mapping +from nncf.onnx.graph.onnx_helper import get_edge_dtype +from nncf.onnx.graph.onnx_helper import get_edge_info_mapping +from nncf.onnx.graph.onnx_helper import get_edge_shape +from nncf.onnx.graph.onnx_helper import get_input_port_id_for_node_after_input +from nncf.onnx.graph.onnx_helper import get_model_inputs +from nncf.onnx.graph.onnx_helper import get_output_port_id_for_node_before_output +from nncf.onnx.graph.onnx_helper import get_parents_node_mapping +from nncf.onnx.graph.onnx_helper import get_port_ids_between_nodes +from nncf.onnx.graph.onnx_helper import is_node_has_shared_weight class ONNXLayerAttributes(BaseLayerAttributes): @@ -103,23 +112,28 @@ def get_bias_tensor_port_id(metatype: ONNXOpWithWeightsMetatype) -> Optional[int return None -def _get_weight_port_ids(node: onnx.NodeProto, onnx_graph: ONNXGraph) -> Set[int]: +def _get_weight_port_ids( + node: onnx.NodeProto, + model: onnx.ModelProto, + parents_node_mapping: Dict[str, onnx.NodeProto], +) -> Set[int]: """ Returns all weight input ports. First, add constant weight port ids from metatype. Second, add weight port ids determined dynamically if metatype could have them. :param node: ONNX node. - :param onnx_graph: ONNXGraph. + :param model: ONNX model. + :param parents_node_mapping: Mapping from edge name to node which outputs this edge. :return: Port ids with weights. """ port_ids = set() - metatype = get_metatype(onnx_graph.onnx_model, node) + metatype = get_metatype(model, node) constant_port_ids = get_constant_weight_port_ids(metatype) port_ids.update(constant_port_ids) possible_port_ids = get_possible_weight_port_ids(metatype) for port_id in possible_port_ids: - if get_tensor_edge_name(onnx_graph, node, port_id): + if get_tensor_edge_name(model, node, port_id, parents_node_mapping): port_ids.add(port_id) return port_ids @@ -129,7 +143,7 @@ def _is_node_with_bias(node: onnx.NodeProto, model: onnx.ModelProto) -> bool: Returns True if node has bias tensor, otherwise - False. :param node: ONNX node. - :param onnx_graph: ONNXGraph. + :param model: ONNX model. :return: True if node has bias tensor, otherwise - False. """ metatype = get_metatype(model, node) @@ -139,23 +153,6 @@ def _is_node_with_bias(node: onnx.NodeProto, model: onnx.ModelProto) -> bool: return False -def _get_weight_attr(node: onnx.NodeProto, onnx_graph: ONNXGraph, weight_port_id: int) -> Dict[int, Dict]: - """ - Returns weight attributes. - - :param node: ONNX node. - :param onnx_graph: ONNXGraph. - :param weight_port_ids: Port ids with weights location. - :return: Weight attributes. - """ - weight_attrs = {} - weight_edge_name = node.input[weight_port_id] - edge = onnx_graph.get_edge(weight_edge_name) - weight_shape = ONNXGraph.get_edge_shape(edge) - weight_attrs[weight_port_id] = {"name": weight_edge_name, "shape": weight_shape} - return weight_attrs - - def _get_gemm_attrs(node: onnx.NodeProto) -> Dict[str, int]: """ Returns transpose attrbiutes of GEMM node. @@ -176,7 +173,7 @@ def _get_node_attrs(node: onnx.NodeProto, model: onnx.ModelProto) -> Dict[str, A Returns node attributes. :param node: Node. - :param onnx_graph: ONNXGraph. + :param model: ONNX model. :return : Node attributes. """ metatype = get_metatype(model, node) @@ -185,19 +182,24 @@ def _get_node_attrs(node: onnx.NodeProto, model: onnx.ModelProto) -> Dict[str, A return {} -def _get_bias_attr(node: onnx.NodeProto, onnx_graph: ONNXGraph) -> Dict[str, str]: +def _get_bias_attr( + node: onnx.NodeProto, + model: onnx.ModelProto, + parents_node_mapping: Dict[str, onnx.NodeProto], +) -> Dict[str, str]: """ Returns bias tensor attributes. :param node: ONNX node. - :param onnx_graph: ONNXGraph. + :param model: ONNX model. + :param parents_node_mapping: Mapping from edge name to node which outputs this edge. :return: Bias tensor attributes. """ bias_attrs = {} - metatype = get_metatype(onnx_graph.onnx_model, node) - if _is_node_with_bias(node, onnx_graph.onnx_model): + metatype = get_metatype(model, node) + if _is_node_with_bias(node, model): bias_tensor_port_id = get_bias_tensor_port_id(metatype) - bias_edge_name = get_tensor_edge_name(onnx_graph, node, bias_tensor_port_id) + bias_edge_name = get_tensor_edge_name(model, node, bias_tensor_port_id, parents_node_mapping) bias_attrs["name"] = bias_edge_name return bias_attrs @@ -232,15 +234,22 @@ def _replace_empty_node_name(model: onnx.ModelProto) -> onnx.ModelProto: return model @staticmethod - def _add_nncf_input_nodes(onnx_graph: ONNXGraph, nncf_graph: NNCFGraph) -> None: + def _add_nncf_input_nodes( + model: onnx.ModelProto, + nncf_graph: NNCFGraph, + edge_info_mapping: Dict[str, onnx.ValueInfoProto], + children_node_mapping: Dict[str, List[onnx.NodeProto]], + ) -> None: """ Adds special NNCF Input nodes to NNCFGraph. For all the ONNX model inputs, the special NNCF Input node is placed and then corresponding edges are added. - :param onnx_graph: ONNXGraph, which helps to get information about the ONNX model. + :param model: ONNX model. :param nncf_graph: NNCFGraph, in which the new nodes will be added. + :param edge_info_mapping: Mapping from edge name to the edge info. + :param children_node_mapping: Mapping from edge name to nodes which consume this edge as an input. :return: None. """ - for i, _input in enumerate(onnx_graph.get_model_inputs()): + for i, _input in enumerate(get_model_inputs(model)): input_name = _input.name layer_attributes = ONNXLayerAttributes() input_node = nncf_graph.add_nncf_node( @@ -249,18 +258,18 @@ def _add_nncf_input_nodes(onnx_graph: ONNXGraph, nncf_graph: NNCFGraph) -> None: node_metatype=InputNoopMetatype, layer_attributes=layer_attributes, ) - to_nodes = onnx_graph.get_nodes_by_input(input_name) + to_nodes = children_node_mapping[input_name] input_node_node_id = input_node.node_id - edge = onnx_graph.get_edge(input_name) - input_shape = ONNXGraph.get_edge_shape(edge) - onnx_dtype = ONNXGraph.get_edge_dtype(edge) + edge = edge_info_mapping[input_name] + input_shape = get_edge_shape(edge) + onnx_dtype = get_edge_dtype(edge) nncf_dtype = GraphConverter.convert_onnx_dtype_to_nncf_dtype(onnx_dtype) output_port_id = 0 for node in to_nodes: to_node_id = nncf_graph.get_node_by_name(node.name).node_id - input_port_id = ONNXGraph.get_input_port_id_for_node_after_input(input_name, node) + input_port_id = get_input_port_id_for_node_after_input(input_name, node) nncf_graph.add_edge_between_nncf_nodes( from_node_id=input_node_node_id, to_node_id=to_node_id, @@ -272,15 +281,22 @@ def _add_nncf_input_nodes(onnx_graph: ONNXGraph, nncf_graph: NNCFGraph) -> None: output_port_id += 1 @staticmethod - def _add_nncf_output_nodes(onnx_graph: ONNXGraph, nncf_graph: NNCFGraph) -> None: + def _add_nncf_output_nodes( + model: onnx.ModelProto, + nncf_graph: NNCFGraph, + edge_info_mapping: Dict[str, onnx.ValueInfoProto], + parents_node_mapping: Dict[str, onnx.NodeProto], + ) -> None: """ Adds special NNCF Output nodes to NNCFGraph. For all the ONNX model outputs, the special NNCF Output node is placed and then corresponding edges are added. - :param onnx_graph: ONNXGraph, which helps to get information about the ONNX model. + :param model: ONNX model. :param nncf_graph: NNCFGraph, in which the new nodes will be added. + :param edge_info_mapping: Mapping from edge name to the edge info. + :param parents_node_mapping: Mapping from edge name to node which outputs this edge. :return: None. """ - for i, _output in enumerate(onnx_graph.get_model_outputs()): + for i, _output in enumerate(model.graph.output): output_name = _output.name layer_attributes = ONNXLayerAttributes() output_node = nncf_graph.add_nncf_node( @@ -289,16 +305,16 @@ def _add_nncf_output_nodes(onnx_graph: ONNXGraph, nncf_graph: NNCFGraph) -> None node_metatype=OutputNoopMetatype, layer_attributes=layer_attributes, ) - from_node = onnx_graph.get_node_by_output(output_name) + from_node = parents_node_mapping[output_name] output_node_node_id = output_node.node_id - edge = onnx_graph.get_edge(output_name) - output_shape = ONNXGraph.get_edge_shape(edge) - onnx_dtype = ONNXGraph.get_edge_dtype(edge) + edge = edge_info_mapping[output_name] + output_shape = get_edge_shape(edge) + onnx_dtype = get_edge_dtype(edge) nncf_dtype = GraphConverter.convert_onnx_dtype_to_nncf_dtype(onnx_dtype) input_port_id = 0 from_node_id = nncf_graph.get_node_by_name(from_node.name).node_id - output_port_id = ONNXGraph.get_output_port_id_for_node_before_output(output_name, from_node) + output_port_id = get_output_port_id_for_node_before_output(output_name, from_node) nncf_graph.add_edge_between_nncf_nodes( from_node_id=from_node_id, to_node_id=output_node_node_id, @@ -330,21 +346,27 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph: :return: NNCFGraph. """ onnx_model = GraphConverter._replace_empty_node_name(onnx_model) + onnx_model = onnx.shape_inference.infer_shapes(onnx_model) + edge_info_mapping = get_edge_info_mapping(onnx_model) + children_node_mapping = get_children_node_mapping(onnx_model) + parents_node_mapping = get_parents_node_mapping(onnx_model) nncf_graph = NNCFGraph() - onnx_graph = ONNXGraph(onnx_model) - for node in onnx_graph.get_all_nodes(): + for node in onnx_model.graph.node: metatype = get_metatype(onnx_model, node) - weight_port_ids = _get_weight_port_ids(node, onnx_graph) + weight_port_ids = _get_weight_port_ids(node, onnx_model, parents_node_mapping) is_shared = None weight_attrs = {} node_attrs = _get_node_attrs(node, onnx_model) - bias_attrs = _get_bias_attr(node, onnx_graph) + bias_attrs = _get_bias_attr(node, onnx_model, parents_node_mapping) if weight_port_ids: # If node has weight weight_edge_names = [] for weight_port_id in weight_port_ids: - weight_edge_names.append(node.input[weight_port_id]) - weight_attrs.update(_get_weight_attr(node, onnx_graph, weight_port_id)) - if not is_shared and onnx_graph.is_node_has_shared_weight(node, weight_port_id): + weight_edge_name = node.input[weight_port_id] + weight_edge_names.append(weight_edge_name) + edge = edge_info_mapping[weight_edge_name] + weight_shape = get_edge_shape(edge) + weight_attrs[weight_port_id] = {"name": weight_edge_name, "shape": weight_shape} + if not is_shared and is_node_has_shared_weight(node, weight_port_id, children_node_mapping): is_shared = True layer_attributes = ONNXLayerAttributes( @@ -357,22 +379,23 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph: layer_attributes=layer_attributes, is_shared=is_shared, ) - for output_node in onnx_graph.get_all_nodes(): - output_edges = onnx_graph.get_node_edge_names(output_node.name)["output"] + + for output_node in onnx_model.graph.node: + output_edges = output_node.output for output_edge in output_edges: - edge = onnx_graph.get_edge(output_edge) + edge = edge_info_mapping.get(output_edge) if edge is None: # If the edge is None it means that the edge was not added during shape inference of ONNX model. # BatchNorm exported in Training mode has unused outputs edges: mean, var, saved_mean, saved_var. # NNCFGraph should not contain such edges. continue - tensor_shape = ONNXGraph.get_edge_shape(edge) - onnx_dtype = ONNXGraph.get_edge_dtype(edge) + tensor_shape = get_edge_shape(edge) + onnx_dtype = get_edge_dtype(edge) nncf_dtype = GraphConverter.convert_onnx_dtype_to_nncf_dtype(onnx_dtype) output_node_id = nncf_graph.get_node_by_name(output_node.name).node_id - input_nodes = onnx_graph.get_nodes_by_input(output_edge) + input_nodes = children_node_mapping[output_edge] for input_node in input_nodes: - port_ids = ONNXGraph.get_port_ids_between_nodes(output_node, input_node) + port_ids = get_port_ids_between_nodes(output_node, input_node) input_port_id = port_ids["input_port_id"] output_port_id = port_ids["output_port_id"] in_node_id = nncf_graph.get_node_by_name(input_node.name).node_id @@ -384,6 +407,7 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph: output_port_id=output_port_id, dtype=Dtype(nncf_dtype), ) - GraphConverter._add_nncf_input_nodes(onnx_graph, nncf_graph) - GraphConverter._add_nncf_output_nodes(onnx_graph, nncf_graph) + + GraphConverter._add_nncf_input_nodes(onnx_model, nncf_graph, edge_info_mapping, children_node_mapping) + GraphConverter._add_nncf_output_nodes(onnx_model, nncf_graph, edge_info_mapping, parents_node_mapping) return nncf_graph diff --git a/nncf/onnx/graph/node_utils.py b/nncf/onnx/graph/node_utils.py index 6575dff6f1c..1e9a162211d 100644 --- a/nncf/onnx/graph/node_utils.py +++ b/nncf/onnx/graph/node_utils.py @@ -21,7 +21,7 @@ from nncf.common.tensor_statistics.collectors import ReductionAxes from nncf.onnx.graph.metatypes import onnx_metatypes as om from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDequantizeLinearMetatype -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_tensor_value from nncf.onnx.graph.transformations.commands import ONNXTargetPoint @@ -45,10 +45,9 @@ def get_bias_value(node_with_bias: NNCFNode, model: onnx.ModelProto) -> np.ndarr :param model: The model that contains this operation. :return: The bias value that is applied to the output tensor of the node's operation. """ - onnx_graph = ONNXGraph(model) assert node_with_bias.layer_attributes.has_bias() bias_name = node_with_bias.layer_attributes.bias_attrs["name"] - return onnx_graph.get_tensor_value(bias_name) + return get_tensor_value(model, bias_name) def get_input_edges_mapping(nncf_graph: NNCFGraph) -> Dict[str, Tuple[str, int]]: @@ -68,20 +67,25 @@ def get_input_edges_mapping(nncf_graph: NNCFGraph) -> Dict[str, Tuple[str, int]] return input_edges_mapping -def get_input_edge(input_node_name: str, input_edges_mapping: Dict[str, Tuple[str, int]], onnx_graph: ONNXGraph) -> str: +def get_input_edge( + input_node_name: str, + input_edges_mapping: Dict[str, Tuple[str, int]], + node_mapping: Dict[str, onnx.NodeProto], +) -> str: """ Returns input edge corresponding to the NNCF input node with the name input_node_name. :param input_node_name: Name of NNCF input node. :param input_edges_mapping: A mapping of NNCF input node names and - a tuple with the consumed node names and their input port ids. - :param onnx_graph: Instance of ONNXGraph of the model. + a tuple with the consumed node names and their input port ids. + :param node_mapping: Mapping of node names to the nodes. :return: Input edge name. """ input_edges = set() for node_info in input_edges_mapping[input_node_name]: name, port_id = node_info - input_edges.add(onnx_graph.get_node_edge_names(name)["input"][port_id]) + node = node_mapping[name] + input_edges.add(node.input[port_id]) assert len(input_edges) == 1 return input_edges.pop() diff --git a/nncf/onnx/graph/onnx_graph.py b/nncf/onnx/graph/onnx_graph.py deleted file mode 100644 index df754263c99..00000000000 --- a/nncf/onnx/graph/onnx_graph.py +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright (c) 2023 Intel Corporation -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Iterator, List, Optional, Union - -import numpy as np -import onnx -from onnx import numpy_helper - - -class ONNXGraph: - """ - The class provides the interface to get the necessary information from ONNX model. - """ - - def __init__(self, onnx_model: onnx.ModelProto): - self.onnx_model = onnx_model - self._node_name_to_node = None # type: Dict[str, onnx.NodeProto] - self._edge_name_to_value_info = None # type: Dict[str, onnx.ValueInfoProto] - - def _update_edges(self) -> None: - self.onnx_model = onnx.shape_inference.infer_shapes(self.onnx_model) - value_infos = [ - *self.onnx_model.graph.value_info, - *self.onnx_model.graph.input, - *self.onnx_model.graph.output, - *self.onnx_model.graph.initializer, - ] - self._edge_name_to_value_info = {tensor.name: tensor for tensor in value_infos} - - def _update_node_names(self) -> None: - self._node_name_to_node = {n.name: n for n in self.onnx_model.graph.node} - - def _get_all_tensors(self) -> Iterator[onnx.TensorProto]: - """ - Iterate over all tensors of ONNX model. - - :yield: tensors of ONNX model. - """ - for initializer in self.onnx_model.graph.initializer: - yield initializer - for node in self.onnx_model.graph.node: - for attribute in node.attribute: - if attribute.HasField("t"): - yield attribute.t - yield from attribute.tensors - - def get_all_nodes(self) -> List[onnx.NodeProto]: - """ - Returns model nodes in the original order. - - :return: model nodes. - """ - return self.onnx_model.graph.node - - def get_node_by_name(self, node_name: str) -> Optional[onnx.NodeProto]: - """ - Returns a model node with the name equals to 'node_name' from self._node_name_to_node. - If the self._node_name_to_node is None, fills it with the nodes from the self.onnx_model. - If there is no node with such name returns None. - - :param node_name: Name of the node. - :return: None if the node with the specified name exists - otherwise returns the node. - """ - if self._node_name_to_node is None: - self._update_node_names() - return self._node_name_to_node[node_name] if node_name in self._node_name_to_node else None - - def get_edge(self, edge_name: str) -> Optional[onnx.ValueInfoProto]: - """ - Returns edge by its name or None if the model has no such edge. - If self._edge_name_to_value_info is not initialized runs an initialization. - - :param edge_name: Name of edge. - :return: Edge. - """ - if self._edge_name_to_value_info is None: - self._update_edges() - return self._edge_name_to_value_info.get(edge_name, None) - - def get_model_inputs(self) -> List[onnx.ValueInfoProto]: - """ - Returns all model inputs. - - :return: Model Inputs. - """ - inputs = [] - input_all = [node.name for node in self.onnx_model.graph.input] - input_initializer = [node.name for node in self.onnx_model.graph.initializer] - net_feed_input = list(set(input_all) - set(input_initializer)) - for node in self.onnx_model.graph.input: - if node.name in net_feed_input: - inputs.append(node) - return inputs - - def get_model_outputs(self) -> List[onnx.ValueInfoProto]: - """ - Returns all model outputs. - - :return: Model Outputs. - """ - return list(self.onnx_model.graph.output) - - def get_node_by_output(self, output_name: str) -> Optional[onnx.NodeProto]: - """ - Returns node that have output edge with the name 'output_name'. - - :param output_name: The name of output edge. - :return: Node with corresponding output. - """ - for node in self.get_all_nodes(): - if output_name in node.output: - return node - return None - - def get_nodes_by_input(self, input_name: str) -> List[onnx.NodeProto]: - """ - Returns all nodes that have input with the name 'input_name'. - - :param input_name: The name of input edge. - :return: Nodes with corresponding input. - """ - output = [] - for node in self.get_all_nodes(): - if input_name in node.input: - output.append(node) - return output - - def get_node_edge_names(self, node_name: str) -> Dict[str, List[str]]: - """ - Returns node edge names. - - :param node_name: The name of the node. - :return: Dict with two keys: 'input' and 'output', - which are corresponding to input and output edges accordingly. - """ - if self._node_name_to_node is None: - self._update_node_names() - if node_name in self._node_name_to_node: - return { - "input": list(self._node_name_to_node[node_name].input), - "output": list(self._node_name_to_node[node_name].output), - } - raise RuntimeError("There is no node with the name {}".format(node_name)) - - @staticmethod - def get_input_port_id_for_node_after_input(input_name: str, to_node: onnx.NodeProto) -> int: - """ - Returns input_port_id for 'to_node' connected with the model input with the name 'input_name'. - - :param input_name: Name of the ONNX model Input. - :param to_node: Node, which has input edge with 'input_name' name. - :return: input port number for 'to_node', which is connected to 'input_name'. - """ - for input_port_id, port in enumerate(to_node.input): - if port == input_name: - return input_port_id - raise RuntimeError(f"The node {to_node} does not have input edge with the name {input_name}") - - @staticmethod - def get_output_port_id_for_node_before_output(output_name: str, from_node: onnx.NodeProto) -> int: - """ - Returns output_port_id for 'from_node' connected with the model output with the name 'output_name'. - - :param output_name: Name of the ONNX model Output. - :param from_node: Node, which has output edge with 'output_name' name. - :return: output port number for 'from_node', which is connected to 'output_name'. - """ - for output_port_id, port in enumerate(from_node.output): - if port == output_name: - return output_port_id - raise RuntimeError(f"The node {from_node} does not have output edge with the name {output_name}") - - @staticmethod - def get_port_ids_between_nodes(from_node: onnx.NodeProto, to_node: onnx.NodeProto) -> Dict[str, int]: - """ - Returns input_port_id and output_port_id between 'from_node' and 'to_node'. - - :param from_node: Node, whose output is connected to 'to_node' node. - :param to_node: Node, whose input is connected to 'from_node' node. - :return: Dict{'input_port_id': input port id, 'output_port_id': output port id} - """ - output = {"input_port_id": None, "output_port_id": None} - for port_id, port in enumerate(to_node.input): - if port in from_node.output: - output["input_port_id"] = port_id - for port_id, port in enumerate(from_node.output): - if port in to_node.input: - output["output_port_id"] = port_id - if output["output_port_id"] is None or output["input_port_id"] is None: - raise RuntimeError(f"The nodes {from_node.name} and {to_node.name} do not have edges between.") - return output - - def get_node_index(self, node_name: str) -> int: - """ - Returns the node index in the model. - - :param node_name: Name of the node. - :return: Node index, -1 if there is no such node. - """ - for i, node in enumerate(self.get_all_nodes()): - if node.name == node_name: - return i - return -1 - - def has_tensor(self, tensor_name: str) -> bool: - """ - Returns True whether the model has the tensor with the name equals to tensor_name. - - :param tensor_name: Name of the tensor. - :return: True if the model has such tensor, False - otherwise. - """ - for tensor in self._get_all_tensors(): - if tensor.name == tensor_name: - return True - return False - - def get_tensor_value(self, tensor_name: str) -> np.ndarray: - """ - Returns tensor value of a tensor with the name 'tensor_name'. - - :param tensor_name: Name of the tensor. - :return: The value of the tensor. - """ - tensor = self.get_tensor(tensor_name) - return numpy_helper.to_array(tensor) - - def get_tensor(self, tensor_name: str) -> onnx.TensorProto: - """ - Returns a tensor with the name 'tensor_name'. - - :param initializer_name: Name of the Initializer. - :return: The Initializer. - """ - for tensor in self._get_all_tensors(): - if tensor.name == tensor_name: - return tensor - raise RuntimeError("There is no tensor with the name {}".format(tensor_name)) - - @staticmethod - def get_edge_shape(edge: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> List[int]: - """ - Returns edge shape. - - :param edge: The edge. - :return: Shape of the Tensor. - """ - if isinstance(edge, onnx.TensorProto): - return list(edge.dims) - tensor_type = edge.type.tensor_type - shape = [] - if tensor_type.HasField("shape"): - for d in tensor_type.shape.dim: - if d.HasField("dim_value"): - dim_value = d.dim_value - if isinstance(dim_value, int): - shape.append(dim_value) - else: - return shape - elif d.HasField("dim_param"): - # flexible shape make manually -1 - shape.append(-1) - else: - return shape - return shape - - @staticmethod - def get_edge_dtype(edge: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> int: - """ - Returns the data type of the edge. - - :param edge: The edge. - :return: Data type of the edge. - """ - if isinstance(edge, onnx.ValueInfoProto): - return edge.type.tensor_type.elem_type - return edge.data_type - - def get_parent(self, node: onnx.NodeProto, port_id: int) -> Optional[onnx.NodeProto]: - """ - Returns parents of the node. If there is no parent node, returns None. - - :param node: The child node. - :param port_id: Input port id on which the parent is seeked. - :return: Parent node. - """ - if port_id < len(node.input): - return self.get_node_by_output(node.input[port_id]) - return None - - def get_children(self, node: onnx.NodeProto) -> List[onnx.NodeProto]: - """ - Returns children of the node. - - :param node: The parent node. - :return: All children nodes. - """ - output = [] - node_edges = self.get_node_edge_names(node.name)["output"] - for node_edge in node_edges: - output.extend(self.get_nodes_by_input(node_edge)) - return output - - def is_node_has_shared_weight(self, node: onnx.NodeProto, weight_port_id: int) -> bool: - """ - Returns whether the node share a weight. - - :param node: Node. - :return: True whether node shares a weight - otherwise False. - """ - weight_tensor_edge = self.get_node_edge_names(node.name)["input"][weight_port_id] - nodes = self.get_nodes_by_input(weight_tensor_edge) - return len(nodes) > 1 diff --git a/nncf/onnx/graph/onnx_helper.py b/nncf/onnx/graph/onnx_helper.py new file mode 100644 index 00000000000..f6b082050a0 --- /dev/null +++ b/nncf/onnx/graph/onnx_helper.py @@ -0,0 +1,290 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict +from typing import Dict, Iterator, List, Optional, Union + +import numpy as np +import onnx +from onnx import numpy_helper + + +def get_name_to_node_map(model: onnx.ModelProto) -> Dict[str, onnx.NodeProto]: + """ + Returns mapping from node name to the node. + + :param model: Model from mapping is built. + :return: Mapping. + """ + return {node.name: node for node in model.graph.node} + + +def get_edge_info_mapping(model: onnx.ModelProto) -> Dict[str, onnx.ValueInfoProto]: + """ + Retuns mapping from edge name to the edge info. + + :param model: Model from mapping is built. + :return: Mapping. + """ + return { + tensor.name: tensor + for tensor in (*model.graph.value_info, *model.graph.input, *model.graph.output, *model.graph.initializer) + } + + +def get_children_node_mapping(model: onnx.ModelProto) -> Dict[str, List[onnx.NodeProto]]: + """ + Returns a mapping from edge name to nodes which consume this edge as an input. + + :param model: ONNX model. + :return: Mapping from edge name to nodes which consume this edge as an input. + """ + output = defaultdict(list) + for node in model.graph.node: + for edge in node.input: + output[edge].append(node) + return output + + +def get_parents_node_mapping(model: onnx.ModelProto) -> Dict[str, onnx.NodeProto]: + """ + Returns a mapping from edge name to node which outputs this edge. + + :param model: ONNX model. + :return: Mapping from edge name to node which outputs this edge. + """ + output = {} + for node in model.graph.node: + for edge in node.output: + output[edge] = node + return output + + +def get_model_inputs(model: onnx.ModelProto) -> List[onnx.ValueInfoProto]: + """ + Returns all model inputs. + + :param model: ONNX model. + :return: Model Inputs. + """ + inputs = [] + input_all = [node.name for node in model.graph.input] + input_initializer = [node.name for node in model.graph.initializer] + net_feed_input = list(set(input_all) - set(input_initializer)) + for node in model.graph.input: + if node.name in net_feed_input: + inputs.append(node) + return inputs + + +def get_input_port_id_for_node_after_input(input_name: str, to_node: onnx.NodeProto) -> int: + """ + Returns input_port_id for 'to_node' connected with the model input with the name 'input_name'. + + :param input_name: Name of the ONNX model Input. + :param to_node: Node, which has input edge with 'input_name' name. + :return: input port number for 'to_node', which is connected to 'input_name'. + """ + for input_port_id, port in enumerate(to_node.input): + if port == input_name: + return input_port_id + raise RuntimeError(f"The node {to_node} does not have input edge with the name {input_name}") + + +def get_output_port_id_for_node_before_output(output_name: str, from_node: onnx.NodeProto) -> int: + """ + Returns output_port_id for 'from_node' connected with the model output with the name 'output_name'. + + :param output_name: Name of the ONNX model Output. + :param from_node: Node, which has output edge with 'output_name' name. + :return: output port number for 'from_node', which is connected to 'output_name'. + """ + for output_port_id, port in enumerate(from_node.output): + if port == output_name: + return output_port_id + raise RuntimeError(f"The node {from_node} does not have output edge with the name {output_name}") + + +def get_port_ids_between_nodes(from_node: onnx.NodeProto, to_node: onnx.NodeProto) -> Dict[str, int]: + """ + Returns input_port_id and output_port_id between 'from_node' and 'to_node'. + + :param from_node: Node, whose output is connected to 'to_node' node. + :param to_node: Node, whose input is connected to 'from_node' node. + :return: Dict{'input_port_id': input port id, 'output_port_id': output port id} + """ + output = {"input_port_id": None, "output_port_id": None} + for port_id, port in enumerate(to_node.input): + if port in from_node.output: + output["input_port_id"] = port_id + for port_id, port in enumerate(from_node.output): + if port in to_node.input: + output["output_port_id"] = port_id + if output["output_port_id"] is None or output["input_port_id"] is None: + raise RuntimeError(f"The nodes {from_node.name} and {to_node.name} do not have edges between.") + return output + + +def get_node_index(model: onnx.ModelProto, node_name: str) -> Optional[int]: + """ + Returns the node index in the model. + + :param model: ONNX model. + :param node_name: Name of the node. + :return: Node index, -1 if there is no such node. + """ + for i, node in enumerate(model.graph.node): + if node.name == node_name: + return i + return None + + +def _get_all_tensors(model: onnx.ModelProto) -> Iterator[onnx.TensorProto]: + """ + Iterate over all tensors of ONNX model. + + :param model: ONNX model. + :yield: tensors of ONNX model. + """ + for initializer in model.graph.initializer: + yield initializer + for node in model.graph.node: + for attribute in node.attribute: + if attribute.HasField("t"): + yield attribute.t + yield from attribute.tensors + + +def has_tensor(model: onnx.ModelProto, tensor_name: str) -> bool: + """ + Returns True whether the model has the tensor with the name equals to tensor_name. + + :param model: ONNX model. + :param tensor_name: Name of the tensor. + :return: True if the model has such tensor, False - otherwise. + """ + for tensor in _get_all_tensors(model): + if tensor.name == tensor_name: + return True + return False + + +def get_tensor(model: onnx.ModelProto, tensor_name: str) -> onnx.TensorProto: + """ + Returns a tensor with the name 'tensor_name'. + + :param model: ONNX model. + :param tensor_name: Name of the tensor. + :return: The Initializer. + """ + for tensor in _get_all_tensors(model): + if tensor.name == tensor_name: + return tensor + raise RuntimeError("There is no tensor with the name {}".format(tensor_name)) + + +def get_tensor_value(model: onnx.ModelProto, tensor_name: str) -> np.ndarray: + """ + Returns tensor value of a tensor with the name 'tensor_name'. + + :param model: ONNX model. + :param tensor_name: Name of the tensor. + :return: The value of the tensor. + """ + return numpy_helper.to_array(get_tensor(model, tensor_name)) + + +def get_edge_shape(edge: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> List[int]: + """ + Returns edge shape. + + :param edge: The edge. + :return: Shape of the Tensor. + """ + if isinstance(edge, onnx.TensorProto): + return list(edge.dims) + tensor_type = edge.type.tensor_type + shape = [] + if tensor_type.HasField("shape"): + for d in tensor_type.shape.dim: + if d.HasField("dim_value"): + dim_value = d.dim_value + if isinstance(dim_value, int): + shape.append(dim_value) + else: + return shape + elif d.HasField("dim_param"): + # flexible shape make manually -1 + shape.append(-1) + else: + return shape + return shape + + +def get_edge_dtype(edge: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> int: + """ + Returns the data type of the edge. + + :param edge: The edge. + :return: Data type of the edge. + """ + if isinstance(edge, onnx.ValueInfoProto): + return edge.type.tensor_type.elem_type + return edge.data_type + + +def get_parent( + node: onnx.NodeProto, + port_id: int, + parents_node_mapping: Dict[str, onnx.NodeProto], +) -> Optional[onnx.NodeProto]: + """ + Returns parents of the node. If there is no parent node, returns None. + + :param node: The child node. + :param port_id: Input port id on which the parent is seeked. + :param edge_node_mapping: Mapping describing start and consumed nodes of the edges. + :return: Parent node. + """ + if port_id < len(node.input): + return parents_node_mapping.get(node.input[port_id]) + return None + + +def get_children(node: onnx.NodeProto, children_node_mapping: Dict[str, List[onnx.NodeProto]]) -> List[onnx.NodeProto]: + """ + Returns children of the node. + + :param node: The parent node. + :param edge_node_mapping: Mapping describing start and consumed nodes of the edges. + :return: All children nodes. + """ + output = [] + for node_edge in node.output: + output.extend(children_node_mapping[node_edge]) + return output + + +def is_node_has_shared_weight( + node: onnx.NodeProto, + weight_port_id: int, + children_node_mapping: Dict[str, List[onnx.NodeProto]], +) -> bool: + """ + Returns whether the node share a weight. + + :param node: Node. + :param weight_port_id: Port id on which there is a weight. + :param edge_node_mapping: Mapping describing start and consumed nodes of the edges. + :return: True whether node shares a weight - otherwise False. + """ + weight_tensor_edge = node.input[weight_port_id] + nodes = children_node_mapping[weight_tensor_edge] + return len(nodes) > 1 diff --git a/nncf/onnx/statistics/aggregator.py b/nncf/onnx/statistics/aggregator.py index e3435382b5d..a768a855258 100644 --- a/nncf/onnx/statistics/aggregator.py +++ b/nncf/onnx/statistics/aggregator.py @@ -22,7 +22,7 @@ from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.onnx.graph.node_utils import get_input_edge from nncf.onnx.graph.node_utils import get_input_edges_mapping -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_name_to_node_map from nncf.onnx.graph.transformations.commands import ONNXOutputInsertionCommand from nncf.onnx.tensor import ONNXNNCFTensor @@ -30,28 +30,30 @@ class ONNXStatisticsAggregator(StatisticsAggregator): def collect_statistics(self, model: onnx.ModelProto, graph: NNCFGraph) -> None: self.input_edges_mapping = get_input_edges_mapping(graph) - self._onnx_graph = ONNXGraph(model) + self.node_mapping = get_name_to_node_map(model) self._registered_weights = set() super().collect_statistics(model, graph) def _register_statistics( self, outputs: Dict[str, ONNXNNCFTensor], statistic_points: StatisticPointsContainer ) -> None: - for node_name, _statistic_points in statistic_points.items(): + for _statistic_points in statistic_points.values(): for statistic_point in _statistic_points: target_point = statistic_point.target_point port_id = target_point.port_id if target_point.target_node_name in self.input_edges_mapping: # Input case edge_name = get_input_edge( - target_point.target_node_name, self.input_edges_mapping, self._onnx_graph + target_point.target_node_name, + self.input_edges_mapping, + self.node_mapping, ) - statistic_point.register_tensor(outputs[edge_name]) elif target_point.type == TargetType.POST_LAYER_OPERATION: - edge_name = self._onnx_graph.get_node_edge_names(node_name)["output"][port_id] - statistic_point.register_tensor(outputs[edge_name]) + node = self.node_mapping[target_point.target_node_name] + edge_name = node.output[port_id] elif target_point.type in [TargetType.PRE_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]: - edge_name = self._onnx_graph.get_node_edge_names(node_name)["input"][port_id] - statistic_point.register_tensor(outputs[edge_name]) + node = self.node_mapping[target_point.target_node_name] + edge_name = node.input[port_id] + statistic_point.register_tensor(outputs[edge_name]) def _get_transformation_layout_extra_outputs( self, statistic_points: StatisticPointsContainer diff --git a/nncf/quantization/algorithms/bias_correction/onnx_backend.py b/nncf/quantization/algorithms/bias_correction/onnx_backend.py index 364c93acc5a..d7f34936bfd 100644 --- a/nncf/quantization/algorithms/bias_correction/onnx_backend.py +++ b/nncf/quantization/algorithms/bias_correction/onnx_backend.py @@ -22,7 +22,7 @@ from nncf.onnx.graph.node_utils import get_bias_value from nncf.onnx.graph.node_utils import is_any_weight_quantized from nncf.onnx.graph.node_utils import is_node_with_bias -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_name_to_node_map from nncf.onnx.graph.transformations.command_creation import create_bias_correction_command from nncf.onnx.graph.transformations.commands import ONNXBiasCorrectionCommand from nncf.onnx.graph.transformations.commands import ONNXModelExtractionCommand @@ -101,15 +101,13 @@ def get_bias_value(node: NNCFNode, model: onnx.ModelProto, nncf_graph: NNCFGraph @staticmethod def get_input_name(model: onnx.ModelProto, node_name: str) -> str: - onnx_graph = ONNXGraph(model) - node = onnx_graph.get_node_by_name(node_name) - return node.input[0] + node_mapping = get_name_to_node_map(model) + return node_mapping[node_name].input[0] @staticmethod def get_output_name(model: onnx.ModelProto, node_name: str, output_id: int) -> List[str]: - onnx_graph = ONNXGraph(model) - node = onnx_graph.get_node_by_name(node_name) - return node.output[output_id] + node_mapping = get_name_to_node_map(model) + return node_mapping[node_name].output[output_id] @staticmethod def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: diff --git a/tests/onnx/quantization/common.py b/tests/onnx/quantization/common.py index 01a916b61a5..a5b9c8f47e3 100644 --- a/tests/onnx/quantization/common.py +++ b/tests/onnx/quantization/common.py @@ -18,7 +18,9 @@ from nncf import Dataset from nncf.experimental.tensor import Tensor from nncf.onnx.graph.nncf_graph_builder import GraphConverter -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_edge_dtype +from nncf.onnx.graph.onnx_helper import get_edge_info_mapping +from nncf.onnx.graph.onnx_helper import get_edge_shape from nncf.onnx.statistics.statistics import ONNXMinMaxTensorStatistic from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization @@ -62,15 +64,15 @@ def _get_input_keys(original_model: onnx.ModelProto) -> str: def get_random_dataset_for_test(model: onnx.ModelProto, has_batch_dim: bool, length: Optional[int] = 10): keys = _get_input_keys(model) - onnx_graph = ONNXGraph(model) + edge_info_mapping = get_edge_info_mapping(model) def transform_fn(i): output = {} for key in keys: - edge = onnx_graph.get_edge(key) - input_dtype = ONNXGraph.get_edge_dtype(edge) + edge = edge_info_mapping[key] + input_dtype = get_edge_dtype(edge) input_np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_dtype) - shape = ONNXGraph.get_edge_shape(edge) + shape = get_edge_shape(edge) rng = get_random_generator() tensor = rng.uniform(-1, 1, shape).astype(input_np_dtype) if has_batch_dim: diff --git a/tests/onnx/quantization/test_qdq_params_calculation.py b/tests/onnx/quantization/test_qdq_params_calculation.py index bf16eb152b2..1b3367ab6fa 100644 --- a/tests/onnx/quantization/test_qdq_params_calculation.py +++ b/tests/onnx/quantization/test_qdq_params_calculation.py @@ -15,7 +15,7 @@ import pytest from nncf.common.quantization.structs import QuantizationPreset -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_tensor_value from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import OverflowFix from tests.onnx.conftest import ONNX_TEST_ROOT @@ -36,11 +36,10 @@ def get_q_nodes_params(model: onnx.ModelProto) -> Dict[str, np.ndarray]: output = {} - onnx_graph = ONNXGraph(model) - for node in onnx_graph.get_all_nodes(): + for node in model.graph.node: if node.op_type == "QuantizeLinear": - scale = onnx_graph.get_tensor_value(node.input[1]) - zero_point = onnx_graph.get_tensor_value(node.input[2]) + scale = get_tensor_value(model, node.input[1]) + zero_point = get_tensor_value(model, node.input[2]) output[node.name] = {"scale": scale, "zero_point": zero_point} return output diff --git a/tests/onnx/test_model_transformer.py b/tests/onnx/test_model_transformer.py index 4cf5cb4e332..da039ee2d1a 100644 --- a/tests/onnx/test_model_transformer.py +++ b/tests/onnx/test_model_transformer.py @@ -20,7 +20,8 @@ from nncf.common.graph.transformations.layout import TransformationLayout from nncf.onnx.graph.model_transformer import ONNXModelTransformer from nncf.onnx.graph.nncf_graph_builder import GraphConverter -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_tensor +from nncf.onnx.graph.onnx_helper import get_tensor_value from nncf.onnx.graph.transformations.commands import ONNXBiasCorrectionCommand from nncf.onnx.graph.transformations.commands import ONNXOutputInsertionCommand from nncf.onnx.graph.transformations.commands import ONNXQDQNodeRemovingCommand @@ -60,7 +61,7 @@ def test_quantizer_insertion(target_layers, should_raise, quantizer_number): if should_raise: try: _ = model_transformer.transform(transformation_layout) - except RuntimeError: + except KeyError: return transformed_model = model_transformer.transform(transformation_layout) onnx.checker.check_model(transformed_model) @@ -124,17 +125,15 @@ def test_inserted_quantizer_parameters(test_parameters): transformed_model = model_transformer.transform(transformation_layout) onnx.checker.check_model(transformed_model) - onnx_graph = ONNXGraph(transformed_model) - # pylint:disable=no-member for node in transformed_model.graph.node: op_type = node.op_type if op_type == "QuantizeLinear": for attr in node.attribute: assert test_parameters.onnx_attributes[attr.name] == onnx.helper.get_attribute_value(attr) - assert np.allclose(onnx_graph.get_tensor_value(node.input[1]), np.array(test_parameters.scale)) - assert np.allclose(onnx_graph.get_tensor_value(node.input[2]), np.array(test_parameters.zero_point)) - assert onnx_graph.get_tensor_value(node.input[2]).dtype == test_parameters.onnx_dtype + assert np.allclose(get_tensor_value(transformed_model, node.input[1]), np.array(test_parameters.scale)) + assert np.allclose(get_tensor_value(transformed_model, node.input[2]), np.array(test_parameters.zero_point)) + assert get_tensor_value(transformed_model, node.input[2]).dtype == test_parameters.onnx_dtype TARGET_LAYERS = [["ReLU1"], ["Conv1", "BN1"], ["Conv1", "BN1", "ReLU1"]] @@ -160,8 +159,7 @@ def test_output_insertion(target_layers, target_layer_outputs): transformed_model = model_transformer.transform(transformation_layout) - onnx_graph = ONNXGraph(transformed_model) - assert Counter([out.name for out in onnx_graph.get_model_outputs()]) == Counter(target_layer_outputs) + assert Counter([out.name for out in transformed_model.graph.output]) == Counter(target_layer_outputs) CONV_LAYERS = [["Conv1", "Conv2"]] @@ -182,11 +180,11 @@ def test_bias_correction(layers, values, refs): model_transformer = ONNXModelTransformer(model) transformed_model = model_transformer.transform(transformation_layout) - onnx_graph = ONNXGraph(transformed_model) + node_dict = {node.name: node for node in transformed_model.graph.node} for conv_layer, bias_reference in zip(layers, refs): - bias_tensor_name = onnx_graph.get_node_by_name(conv_layer).input[2] - bias_tensor = onnx_graph.get_tensor(bias_tensor_name) + bias_tensor_name = node_dict[conv_layer].input[2] + bias_tensor = get_tensor(transformed_model, bias_tensor_name) bias_value = onnx.numpy_helper.to_array(bias_tensor) assert np.all(bias_value == bias_reference) diff --git a/tests/onnx/weightless_model.py b/tests/onnx/weightless_model.py index 046568df8eb..6f34347ba38 100644 --- a/tests/onnx/weightless_model.py +++ b/tests/onnx/weightless_model.py @@ -19,8 +19,6 @@ from onnx import TensorProto # pylint:disable=no-name-in-module from onnx.external_data_helper import uses_external_data -from nncf.onnx.graph.onnx_graph import ONNXGraph - # pylint: disable=no-member @@ -32,8 +30,7 @@ def load_model_topology_with_zeros_weights(model_path: Union[str, Path]) -> onnx :return: Onnx model with filled the all external tensors by random values. """ model = onnx.load_model(model_path, load_external_data=False) - onnx_graph = ONNXGraph(model) - for tensor in onnx_graph.onnx_model.graph.initializer: + for tensor in model.graph.initializer: if uses_external_data(tensor): np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor.data_type) np_tensor = np.zeros(list(tensor.dims)).astype(np_dtype)