Skip to content

Commit

Permalink
Unification of layout and data types of Parameter and Result in MO (#…
Browse files Browse the repository at this point in the history
…7630)

* Added transposes insertion for Parameter and Result.

* Separated into several transformations.

* Corrected runtime_info format.

* Fixed runtime info serialization.

* Code refactoring.

* Corrected checks.

* Added debug output.

* Added check.

* Fixed unit tests.

* Changed old api map format, removed debug output.

* Moved serialize to rt_info property, code corrections.

* Refactored RTInfo class.

* Small corrections.

* Small corrections.

* Removed redurant import.

* Added tests, added undefined default type.

* Code reformat.

* Fixed serialization unit tests.

* Added comment.

* Added comment.

* Small test correction.

* Changed default values of old_api_map to values from old API IR.

* np.array -> int64_array

* Update MO to use FE to read IR; Swith MO IR version to 11

* Preserve output node name when inserting Transpose

* Codestyle

* Fix layer tests

* Pylint fix

* Disable ref_graphs comparision in layer tests

* codestyle

* Updated MO IR reader.

* Moved version initialization to constructor of OldApiMap.

* Added shape infer after transpose insertion.

* Fixed Pylint

* Removed wrong attribute removal.

* Added transposes insertion for Parameter and Result.

* Separated into several transformations.

* Corrected runtime_info format.

* Fixed runtime info serialization.

* Code refactoring.

* Corrected checks.

* Added debug output.

* Added check.

* Fixed unit tests.

* Changed old api map format, removed debug output.

* Moved serialize to rt_info property, code corrections.

* Refactored RTInfo class.

* Small corrections.

* Small corrections.

* Removed redurant import.

* Added tests, added undefined default type.

* Code reformat.

* Fixed serialization unit tests.

* Added comment.

* Added comment.

* Small test correction.

* Changed default values of old_api_map to values from old API IR.

* np.array -> int64_array

* Update MO to use FE to read IR; Swith MO IR version to 11

* Preserve output node name when inserting Transpose

* Codestyle

* Fix layer tests

* Pylint fix

* Disable ref_graphs comparision in layer tests

* codestyle

* Updated MO IR reader.

* Moved version initialization to constructor of OldApiMap.

* Added shape infer after transpose insertion.

* Fixed Pylint

* Removed wrong attribute removal.

* Serialize fix.

Co-authored-by: Gleb Kazantaev <[email protected]>
  • Loading branch information
popovaan and Gleb Kazantaev authored Oct 27, 2021
1 parent d5a5dc1 commit bf8f916
Show file tree
Hide file tree
Showing 23 changed files with 568 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,9 @@ def GenerateMappingFile(IENetwork network, string path, bool extract_names):
C.GenerateMappingFile(network.impl, path, extract_names)


def Serialize(IENetwork network, string path_to_xml, string path_to_bin):
C.Serialize(network.impl, path_to_xml, path_to_bin)


def CheckAPI():
C.CheckAPI()
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <pruning.hpp>
#include <transformations/common_optimizations/moc_transformations.hpp>
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
#include <transformations/serialize.hpp>

void InferenceEnginePython::ApplyMOCTransformations(InferenceEnginePython::IENetwork network, bool cf) {
ngraph::pass::Manager manager;
Expand Down Expand Up @@ -55,6 +56,14 @@ void InferenceEnginePython::GenerateMappingFile(InferenceEnginePython::IENetwork
manager.run_passes(network.actual->getFunction());
}

void InferenceEnginePython::Serialize(InferenceEnginePython::IENetwork network,
std::string path_to_xml,
std::string path_to_bin) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::Serialize>(path_to_xml, path_to_bin);
manager.run_passes(network.actual->getFunction());
}

void InferenceEnginePython::CheckAPI() {
std::shared_ptr<ngraph::Function> f;
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ void ApplyPruningTransformation(InferenceEnginePython::IENetwork network);

void GenerateMappingFile(InferenceEnginePython::IENetwork network, std::string path, bool extract_names);

void Serialize(InferenceEnginePython::IENetwork network, std::string path_to_xml, std::string path_to_bin);

void CheckAPI();

}; // namespace InferenceEnginePython
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ cdef extern from "offline_transformations_api_impl.hpp" namespace "InferenceEngi

cdef void GenerateMappingFile(IENetwork network, string path, bool extract_names)

cdef void Serialize(IENetwork network, string path_to_xml, string path_to_bin)

cdef void CheckAPI()
3 changes: 3 additions & 0 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ extensions/middle/LeakyReluPattern.py
extensions/middle/LSTMRNNSequenceToTensorIterator.py
extensions/middle/MakeKaldiConstReshapable.py
extensions/middle/MarkSubgraphsWithCorrectLayout.py
extensions/middle/MergeNodesPermutations.py
extensions/middle/MoveConstToLoopBody.py
extensions/middle/MulFakeQuantizeFuse.py
extensions/middle/MXNetRNNSequenceNormalize.py
Expand All @@ -612,6 +613,7 @@ extensions/middle/pass_separator.py
extensions/middle/permute_tensor_iterator.py
extensions/middle/PoolV2ToAttributedPool.py
extensions/middle/preprocessing.py
extensions/middle/PreserveRuntimeInfo.py
extensions/middle/quantize_fuses.py
extensions/middle/quantize_linear_resolver.py
extensions/middle/ReluQuantizeFuse.py
Expand Down Expand Up @@ -1073,6 +1075,7 @@ mo/utils/logger.py
mo/utils/model_analysis.py
mo/utils/pipeline_config.py
mo/utils/replacement_pattern.py
mo/utils/runtime_info.py
mo/utils/shape.py
mo/utils/simple_proto_parser.py
mo/utils/str_to.py
Expand Down
32 changes: 13 additions & 19 deletions model-optimizer/extensions/front/ChangePlaceholderTypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from mo.front.common.replacement import FrontReplacementPattern
from mo.graph.graph import Graph, Node
from mo.utils.runtime_info import OldAPIMap


class ChangePlaceholderTypes(FrontReplacementPattern):
Expand All @@ -18,29 +19,22 @@ def is_node_casts_to_float_or_shapeof(node: Node):
return (node.soft_get('type') == 'Convert' and node.soft_get('dst_type') == np.float32) or \
node.soft_get('type') == 'ShapeOf'

@staticmethod
def update_type(node: Node, new_type: np.array):
assert node.has_valid('rt_info')
old_api_map = OldAPIMap(version=0)
if ('old_api_map', old_api_map.get_version()) not in node.rt_info.info:
node.rt_info.info[('old_api_map', old_api_map.get_version())] = old_api_map
node.rt_info.info[('old_api_map', old_api_map.get_version())].old_api_convert(new_type)

def find_and_replace_pattern(self, graph: Graph):
for op in graph.get_op_nodes(type='Parameter'):
consumer_nodes = [p.node for p in op.out_port(0).get_destinations()]
if all([ChangePlaceholderTypes.is_node_casts_to_float_or_shapeof(consumer) for consumer in consumer_nodes]):
log.debug('Convert data type of Parameter "{}" to float32'.format(op.soft_get('name', op.id)))
op.data_type = np.float32
for convert_node in consumer_nodes:
if convert_node.soft_get('type') == 'Convert':
log.debug('Removing "Convert" node "{}"'.format(convert_node.soft_get('name', convert_node.id)))

# disconnect consumer ports of Convert operations. Then connect them with an output of Parameter
convert_destinations = convert_node.out_port(0).get_destinations()
for dst_port in convert_destinations:
dst_port.disconnect()
for dst_port in convert_destinations:
op.out_port(0).connect(dst_port)

graph.remove_node(convert_node.id)
self.update_type(op, np.float32)

if op.soft_get('data_type') == np.int64:
op.data_type = np.int32
log.error('Convert data type of Parameter "{}" to int32'.format(op.soft_get('name', op.id)),
extra={'is_warning': True})
self.update_type(op, np.int32)

if op.soft_get('data_type') == np.uint8:
op.data_type = np.float32
log.debug('Convert data type of Parameter "{}" to float'.format(op.soft_get('name', op.id)))
self.update_type(op, np.float32)
52 changes: 4 additions & 48 deletions model-optimizer/extensions/middle/ApplyPermutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

import numpy as np

from extensions.middle.ApplyNHWCtoNCHWpermutation import ApplyNHWCtoNCHWpermutation
from extensions.middle.InsertLayoutPropagationTransposes import is_input_data_in_correct_layout, \
is_output_data_in_correct_layout
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
from extensions.middle.pass_separator import PostMiddleStart
from mo.front.common.partial_infer.utils import int64_array, shape_array
from mo.graph.graph import Graph, Node
from extensions.middle.PreserveRuntimeInfo import PreserveRuntimeInfo
from mo.front.common.partial_infer.utils import shape_array
from mo.graph.graph import Graph
from mo.graph.perm_inputs import get_node_with_permutation
from mo.graph.port import Port
from mo.middle.replacement import MiddleReplacementPattern
Expand All @@ -25,61 +24,18 @@ class ApplyPermutation(MiddleReplacementPattern):
graph_condition = [lambda graph: graph.graph['fw'] != 'kaldi']

def run_after(self):
return [ApplyNHWCtoNCHWpermutation, PostMiddleStart]
return [PreserveRuntimeInfo]

def run_before(self):
return []

def find_and_replace_pattern(self, graph: Graph):
self.merge_nodes_permutations(graph)
self.permute_data_nodes_attrs(graph)
self.permute_op_nodes_attrs(graph)
self.shape_of_sub_graph_reinference(graph)
self.permute_input_data(graph)
graph.graph['layout'] = 'NCHW'

@staticmethod
def merge_nodes_permutations(graph: Graph):
# Iterate over all data nodes and check all permutations for similarity
# In case of equal permutations, this permutation will be set as attribute for data node
# otherwise exception will be raised
for node in graph.nodes():
node = Node(graph, node)
if node.kind != 'data':
continue

permutations = []

# Get all permutations from in edges
for in_node in node.in_nodes():
edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0]
if 'permutation' in edge_attrs:
permutations.append(edge_attrs['permutation'])

# Get all permutations from out edges
for out_node in node.out_nodes():
edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0]
if 'permutation' in edge_attrs:
permutations.append(edge_attrs['permutation'])

# Check that all permutations are equal
final_permutations = []
for p in permutations:
if p is not None:
final_permutations.append(p.perm)
else:
final_permutations.append(int64_array(np.arange(node.shape.size)))

if len(final_permutations) == 0:
continue

if not all([np.array_equal(final_permutations[0], perm) for perm in final_permutations]):
raise Error('Permutations requested for {} data node are not equal! List of permutations: {}'
''.format(node.name, [p.perm for p in permutations]))

assert not node.has_valid('permutation') or np.array_equal(node.permutation, permutations[0])
node['permutation'] = permutations[0]

@staticmethod
def permute_data_nodes_attrs(graph: Graph):
# Iterate over all data nodes and apply permutation if exists
Expand Down
65 changes: 65 additions & 0 deletions model-optimizer/extensions/middle/MergeNodesPermutations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np

from extensions.middle.ApplyNHWCtoNCHWpermutation import ApplyNHWCtoNCHWpermutation
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.utils.error import Error


class MergeNodesPermutations(MiddleReplacementPattern):
enabled = True

def run_after(self):
return [ApplyNHWCtoNCHWpermutation]

def run_before(self):
return []

def find_and_replace_pattern(self, graph: Graph):
self.merge_nodes_permutations(graph)

@staticmethod
def merge_nodes_permutations(graph: Graph):
# Iterate over all data nodes and check all permutations for similarity
# In case of equal permutations, this permutation will be set as attribute for data node
# otherwise exception will be raised
for node in graph.nodes():
node = Node(graph, node)
if node.kind != 'data':
continue

permutations = []

# Get all permutations from in edges
for in_node in node.in_nodes():
edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0]
if 'permutation' in edge_attrs:
permutations.append(edge_attrs['permutation'])

# Get all permutations from out edges
for out_node in node.out_nodes():
edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0]
if 'permutation' in edge_attrs:
permutations.append(edge_attrs['permutation'])

final_permutations = []
for p in permutations:
if p is not None:
final_permutations.append(p.perm)
else:
final_permutations.append(int64_array(np.arange(node.shape.size)))

if len(final_permutations) == 0:
continue

# Check that all permutations are equal
if not all([np.array_equal(final_permutations[0], perm) for perm in final_permutations]):
raise Error('Permutations requested for {} data node are not equal! List of permutations: {}'
''.format(node.name, [p.perm for p in permutations]))

assert not node.has_valid('permutation') or np.array_equal(node.permutation, permutations[0])
node['permutation'] = permutations[0]
98 changes: 98 additions & 0 deletions model-optimizer/extensions/middle/PreserveRuntimeInfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np

from extensions.middle.MergeNodesPermutations import MergeNodesPermutations
from extensions.ops.transpose import Transpose
from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.utils.runtime_info import OldAPIMap


class PreserveRuntimeInfo(MiddleReplacementPattern):
""" This transformation preserves original layout for Parameter and Result nodes
and adds old_api_map attribute in rt_info which stores the following information:
Parameter:
Order of the transpose which should be applied to Parameter with old API layout to
obtain Parameter with new API layout.
Result:
Order of the transpose which should be applied to Result with new API layout to
obtain Result with old API layout.
This transformation shouldn't be applied for Parameter or Result nodes inside
body graphs of any operations like If, TensorIterator, Loop etc. For this reason
transformation should be executed non-recursively.
"""
enabled = True
run_not_recursively = True

def run_after(self):
return [MergeNodesPermutations]

def run_before(self):
return []

def find_and_replace_pattern(self, graph: Graph):
self.preserve_rt_info(graph)

@staticmethod
def preserve_rt_info(graph: Graph):
for op in graph.get_op_nodes():
op_name = op.soft_get('name', op.id)
op_type = op.soft_get('type')
if op_type == 'Parameter' and op.has_valid('permute_attrs') and not op.has_and_set('nchw_layout'):
if not op.out_node(0).has_valid('permutation'):
continue
permutation = op.out_node(0).permutation
if np.array_equal(permutation.inv, range(len(permutation.inv))):
continue

# rt info update
assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op_name)

old_api_map = OldAPIMap(version=0)
if ('old_api_map', old_api_map.get_version()) not in op.rt_info.info:
op.rt_info.info[('old_api_map', old_api_map.get_version())] = old_api_map
op.rt_info.info[('old_api_map', old_api_map.get_version())].old_api_transpose_parameter(permutation.inv)

# keep input in the framework format
transpose = create_op_node_with_second_input(
graph, Transpose, permutation.perm, {'name': op_name + '/Transpose({})'.format(permutation.perm)})

# source mode is used to keep tensor names at Parameter node
op.out_port(0).get_connection().insert_node(transpose, "source")

if op.has_valid('permute_attrs'):
del op['permute_attrs']
if op.out_node(0).has_valid('permutation'):
del op.out_node(0)['permutation']

elif op_type == 'Result' and op.in_ports():
prev_node_out_port = op.in_port(0).get_connection().get_source()
if prev_node_out_port is None:
continue
in_node = prev_node_out_port.node
in_data_node = in_node.out_node(prev_node_out_port.idx)
if in_data_node.has_and_set('permutation'):
permutation = in_data_node['permutation']
if np.array_equal(permutation.perm, range(len(permutation.perm))):
continue

# rt info update
assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op)
old_api_map = OldAPIMap(version=0)
if ('old_api_map', old_api_map.get_version()) not in op.rt_info.info:
op.rt_info.info[('old_api_map', old_api_map.get_version())] = old_api_map
op.rt_info.info[('old_api_map', old_api_map.get_version())].old_api_transpose_result(permutation.perm)

# keep result in the framework format
transpose = create_op_node_with_second_input(graph, Transpose, permutation.inv)
# preserve output node name as it is used as output name in legacy IE API
transpose.name = in_node.name
in_node.name += "/prev"

prev_node_out_port.get_connection().insert_node(transpose)
Loading

0 comments on commit bf8f916

Please sign in to comment.