Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unification of layout and data types of Parameter and Result in MO #7630

Merged
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
1f77e23
Added transposes insertion for Parameter and Result.
popovaan Sep 16, 2021
881e50e
Separated into several transformations.
popovaan Sep 20, 2021
8ee8b78
Corrected runtime_info format.
popovaan Sep 22, 2021
a5766dc
Fixed runtime info serialization.
popovaan Sep 22, 2021
ca2b693
Code refactoring.
popovaan Sep 22, 2021
ec7cba6
Corrected checks.
popovaan Sep 23, 2021
a8b71e4
Added debug output.
popovaan Sep 23, 2021
d1d6736
Added check.
popovaan Sep 24, 2021
170e0fe
Fixed unit tests.
popovaan Sep 27, 2021
802aec6
Changed old api map format, removed debug output.
popovaan Sep 28, 2021
fd96e43
Moved serialize to rt_info property, code corrections.
popovaan Sep 30, 2021
04b34be
Refactored RTInfo class.
popovaan Sep 30, 2021
e688008
Small corrections.
popovaan Sep 30, 2021
5eb1979
Small corrections.
popovaan Sep 30, 2021
3118ac3
Removed redurant import.
popovaan Oct 1, 2021
303ab08
Added tests, added undefined default type.
popovaan Oct 5, 2021
cd11ca9
Code reformat.
popovaan Oct 5, 2021
847cd62
Fixed serialization unit tests.
popovaan Oct 5, 2021
7667b04
Added comment.
popovaan Oct 6, 2021
8f5758e
Added comment.
popovaan Oct 6, 2021
390e31a
Small test correction.
popovaan Oct 6, 2021
b21d370
Changed default values of old_api_map to values from old API IR.
popovaan Oct 7, 2021
01129b3
np.array -> int64_array
popovaan Oct 7, 2021
87f8139
Update MO to use FE to read IR; Swith MO IR version to 11
Oct 13, 2021
07b7e2d
Preserve output node name when inserting Transpose
Oct 15, 2021
a3a8f5a
Codestyle
Oct 15, 2021
0d0fb34
Fix layer tests
Oct 15, 2021
3efc3da
Pylint fix
Oct 15, 2021
168cf53
Disable ref_graphs comparision in layer tests
Oct 17, 2021
ff7bb8c
codestyle
Oct 17, 2021
e660832
Updated MO IR reader.
popovaan Oct 18, 2021
84fd53b
Moved version initialization to constructor of OldApiMap.
popovaan Oct 18, 2021
8d3c9d2
Added shape infer after transpose insertion.
popovaan Oct 19, 2021
d5e80a2
Fixed Pylint
popovaan Oct 19, 2021
99c521b
Removed wrong attribute removal.
popovaan Oct 19, 2021
3691c06
Added transposes insertion for Parameter and Result.
popovaan Sep 16, 2021
42d9bfa
Separated into several transformations.
popovaan Sep 20, 2021
00a846e
Corrected runtime_info format.
popovaan Sep 22, 2021
80ed247
Fixed runtime info serialization.
popovaan Sep 22, 2021
ddc870f
Code refactoring.
popovaan Sep 22, 2021
ab7de5b
Corrected checks.
popovaan Sep 23, 2021
0360576
Added debug output.
popovaan Sep 23, 2021
9726d5e
Added check.
popovaan Sep 24, 2021
d443643
Fixed unit tests.
popovaan Sep 27, 2021
d5fe8fe
Changed old api map format, removed debug output.
popovaan Sep 28, 2021
23e16dd
Moved serialize to rt_info property, code corrections.
popovaan Sep 30, 2021
17a1bde
Refactored RTInfo class.
popovaan Sep 30, 2021
33ac6ab
Small corrections.
popovaan Sep 30, 2021
df44183
Small corrections.
popovaan Sep 30, 2021
3b52039
Removed redurant import.
popovaan Oct 1, 2021
57020f6
Added tests, added undefined default type.
popovaan Oct 5, 2021
822c08e
Code reformat.
popovaan Oct 5, 2021
8dcfb9e
Fixed serialization unit tests.
popovaan Oct 5, 2021
4bbe4c8
Added comment.
popovaan Oct 6, 2021
a0a9106
Added comment.
popovaan Oct 6, 2021
29cf979
Small test correction.
popovaan Oct 6, 2021
eaaf93b
Changed default values of old_api_map to values from old API IR.
popovaan Oct 7, 2021
1064911
np.array -> int64_array
popovaan Oct 7, 2021
2dd6433
Update MO to use FE to read IR; Swith MO IR version to 11
Oct 13, 2021
3d20e29
Preserve output node name when inserting Transpose
Oct 15, 2021
de59993
Codestyle
Oct 15, 2021
1fa2e9a
Fix layer tests
Oct 15, 2021
a976c59
Pylint fix
Oct 15, 2021
8164823
Disable ref_graphs comparision in layer tests
Oct 17, 2021
44f2e77
codestyle
Oct 17, 2021
a398a95
Updated MO IR reader.
popovaan Oct 18, 2021
c813c9a
Moved version initialization to constructor of OldApiMap.
popovaan Oct 18, 2021
b74b868
Added shape infer after transpose insertion.
popovaan Oct 19, 2021
4b848c3
Fixed Pylint
popovaan Oct 19, 2021
673214e
Removed wrong attribute removal.
popovaan Oct 19, 2021
d6b91b3
Merge branch 'old_api_map_transposes_insertion' of https://github.com…
popovaan Oct 21, 2021
b5709c5
Serialize fix.
popovaan Oct 21, 2021
4cba911
Merge remote-tracking branch 'upstream/master' into old_api_map_trans…
popovaan Oct 25, 2021
455f18a
Merge remote-tracking branch 'upstream/master' into old_api_map_trans…
popovaan Oct 26, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,9 @@ def GenerateMappingFile(IENetwork network, string path, bool extract_names):
C.GenerateMappingFile(network.impl, path, extract_names)


def Serialize(IENetwork network, string path_to_xml, string path_to_bin):
C.Serialize(network.impl, path_to_xml, path_to_bin)


def CheckAPI():
C.CheckAPI()
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <pruning.hpp>
#include <transformations/common_optimizations/moc_transformations.hpp>
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
#include <transformations/serialize.hpp>

void InferenceEnginePython::ApplyMOCTransformations(InferenceEnginePython::IENetwork network, bool cf) {
ngraph::pass::Manager manager;
Expand Down Expand Up @@ -55,6 +56,14 @@ void InferenceEnginePython::GenerateMappingFile(InferenceEnginePython::IENetwork
manager.run_passes(network.actual->getFunction());
}

void InferenceEnginePython::Serialize(InferenceEnginePython::IENetwork network,
std::string path_to_xml,
std::string path_to_bin) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::Serialize>(path_to_xml, path_to_bin);
manager.run_passes(network.actual->getFunction());
GlebKazantaev marked this conversation as resolved.
Show resolved Hide resolved
}

void InferenceEnginePython::CheckAPI() {
std::shared_ptr<ngraph::Function> f;
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ void ApplyPruningTransformation(InferenceEnginePython::IENetwork network);

void GenerateMappingFile(InferenceEnginePython::IENetwork network, std::string path, bool extract_names);

void Serialize(InferenceEnginePython::IENetwork network, std::string path_to_xml, std::string path_to_bin);

void CheckAPI();

}; // namespace InferenceEnginePython
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ cdef extern from "offline_transformations_api_impl.hpp" namespace "InferenceEngi

cdef void GenerateMappingFile(IENetwork network, string path, bool extract_names)

cdef void Serialize(IENetwork network, string path_to_xml, string path_to_bin)

cdef void CheckAPI()
3 changes: 3 additions & 0 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ extensions/middle/LeakyReluPattern.py
extensions/middle/LSTMRNNSequenceToTensorIterator.py
extensions/middle/MakeKaldiConstReshapable.py
extensions/middle/MarkSubgraphsWithCorrectLayout.py
extensions/middle/MergeNodesPermutations.py
extensions/middle/MoveConstToLoopBody.py
extensions/middle/MulFakeQuantizeFuse.py
extensions/middle/MXNetRNNSequenceNormalize.py
Expand All @@ -612,6 +613,7 @@ extensions/middle/pass_separator.py
extensions/middle/permute_tensor_iterator.py
extensions/middle/PoolV2ToAttributedPool.py
extensions/middle/preprocessing.py
extensions/middle/PreserveRuntimeInfo.py
extensions/middle/quantize_fuses.py
extensions/middle/quantize_linear_resolver.py
extensions/middle/ReluQuantizeFuse.py
Expand Down Expand Up @@ -1073,6 +1075,7 @@ mo/utils/logger.py
mo/utils/model_analysis.py
mo/utils/pipeline_config.py
mo/utils/replacement_pattern.py
mo/utils/runtime_info.py
mo/utils/shape.py
mo/utils/simple_proto_parser.py
mo/utils/str_to.py
Expand Down
32 changes: 13 additions & 19 deletions model-optimizer/extensions/front/ChangePlaceholderTypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from mo.front.common.replacement import FrontReplacementPattern
from mo.graph.graph import Graph, Node
from mo.utils.runtime_info import OldAPIMap


class ChangePlaceholderTypes(FrontReplacementPattern):
Expand All @@ -18,29 +19,22 @@ def is_node_casts_to_float_or_shapeof(node: Node):
return (node.soft_get('type') == 'Convert' and node.soft_get('dst_type') == np.float32) or \
node.soft_get('type') == 'ShapeOf'

@staticmethod
def update_type(node: Node, new_type: np.array):
assert node.has_valid('rt_info')
old_api_map = OldAPIMap(version=0)
if ('old_api_map', old_api_map.get_version()) not in node.rt_info.info:
node.rt_info.info[('old_api_map', old_api_map.get_version())] = old_api_map
node.rt_info.info[('old_api_map', old_api_map.get_version())].old_api_convert(new_type)

def find_and_replace_pattern(self, graph: Graph):
for op in graph.get_op_nodes(type='Parameter'):
consumer_nodes = [p.node for p in op.out_port(0).get_destinations()]
if all([ChangePlaceholderTypes.is_node_casts_to_float_or_shapeof(consumer) for consumer in consumer_nodes]):
log.debug('Convert data type of Parameter "{}" to float32'.format(op.soft_get('name', op.id)))
op.data_type = np.float32
for convert_node in consumer_nodes:
if convert_node.soft_get('type') == 'Convert':
log.debug('Removing "Convert" node "{}"'.format(convert_node.soft_get('name', convert_node.id)))

# disconnect consumer ports of Convert operations. Then connect them with an output of Parameter
convert_destinations = convert_node.out_port(0).get_destinations()
for dst_port in convert_destinations:
dst_port.disconnect()
for dst_port in convert_destinations:
op.out_port(0).connect(dst_port)

graph.remove_node(convert_node.id)
self.update_type(op, np.float32)

if op.soft_get('data_type') == np.int64:
op.data_type = np.int32
log.error('Convert data type of Parameter "{}" to int32'.format(op.soft_get('name', op.id)),
extra={'is_warning': True})
self.update_type(op, np.int32)

if op.soft_get('data_type') == np.uint8:
op.data_type = np.float32
log.debug('Convert data type of Parameter "{}" to float'.format(op.soft_get('name', op.id)))
self.update_type(op, np.float32)
52 changes: 4 additions & 48 deletions model-optimizer/extensions/middle/ApplyPermutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

import numpy as np

from extensions.middle.ApplyNHWCtoNCHWpermutation import ApplyNHWCtoNCHWpermutation
from extensions.middle.InsertLayoutPropagationTransposes import is_input_data_in_correct_layout, \
is_output_data_in_correct_layout
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
from extensions.middle.pass_separator import PostMiddleStart
from mo.front.common.partial_infer.utils import int64_array, shape_array
from mo.graph.graph import Graph, Node
from extensions.middle.PreserveRuntimeInfo import PreserveRuntimeInfo
from mo.front.common.partial_infer.utils import shape_array
from mo.graph.graph import Graph
from mo.graph.perm_inputs import get_node_with_permutation
from mo.graph.port import Port
from mo.middle.replacement import MiddleReplacementPattern
Expand All @@ -25,61 +24,18 @@ class ApplyPermutation(MiddleReplacementPattern):
graph_condition = [lambda graph: graph.graph['fw'] != 'kaldi']

def run_after(self):
return [ApplyNHWCtoNCHWpermutation, PostMiddleStart]
return [PreserveRuntimeInfo]

def run_before(self):
return []

def find_and_replace_pattern(self, graph: Graph):
self.merge_nodes_permutations(graph)
self.permute_data_nodes_attrs(graph)
self.permute_op_nodes_attrs(graph)
self.shape_of_sub_graph_reinference(graph)
self.permute_input_data(graph)
graph.graph['layout'] = 'NCHW'

@staticmethod
def merge_nodes_permutations(graph: Graph):
# Iterate over all data nodes and check all permutations for similarity
# In case of equal permutations, this permutation will be set as attribute for data node
# otherwise exception will be raised
for node in graph.nodes():
node = Node(graph, node)
if node.kind != 'data':
continue

permutations = []

# Get all permutations from in edges
for in_node in node.in_nodes():
edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0]
if 'permutation' in edge_attrs:
permutations.append(edge_attrs['permutation'])

# Get all permutations from out edges
for out_node in node.out_nodes():
edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0]
if 'permutation' in edge_attrs:
permutations.append(edge_attrs['permutation'])

# Check that all permutations are equal
final_permutations = []
for p in permutations:
if p is not None:
final_permutations.append(p.perm)
else:
final_permutations.append(int64_array(np.arange(node.shape.size)))

if len(final_permutations) == 0:
continue

if not all([np.array_equal(final_permutations[0], perm) for perm in final_permutations]):
raise Error('Permutations requested for {} data node are not equal! List of permutations: {}'
''.format(node.name, [p.perm for p in permutations]))

assert not node.has_valid('permutation') or np.array_equal(node.permutation, permutations[0])
node['permutation'] = permutations[0]

@staticmethod
def permute_data_nodes_attrs(graph: Graph):
# Iterate over all data nodes and apply permutation if exists
Expand Down
65 changes: 65 additions & 0 deletions model-optimizer/extensions/middle/MergeNodesPermutations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np

from extensions.middle.ApplyNHWCtoNCHWpermutation import ApplyNHWCtoNCHWpermutation
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.utils.error import Error


class MergeNodesPermutations(MiddleReplacementPattern):
enabled = True

def run_after(self):
return [ApplyNHWCtoNCHWpermutation]
jane-intel marked this conversation as resolved.
Show resolved Hide resolved

def run_before(self):
return []
jane-intel marked this conversation as resolved.
Show resolved Hide resolved

def find_and_replace_pattern(self, graph: Graph):
self.merge_nodes_permutations(graph)

@staticmethod
def merge_nodes_permutations(graph: Graph):
# Iterate over all data nodes and check all permutations for similarity
# In case of equal permutations, this permutation will be set as attribute for data node
# otherwise exception will be raised
for node in graph.nodes():
node = Node(graph, node)
if node.kind != 'data':
continue

permutations = []

# Get all permutations from in edges
for in_node in node.in_nodes():
edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0]
if 'permutation' in edge_attrs:
permutations.append(edge_attrs['permutation'])

# Get all permutations from out edges
for out_node in node.out_nodes():
edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0]
if 'permutation' in edge_attrs:
permutations.append(edge_attrs['permutation'])

final_permutations = []
for p in permutations:
if p is not None:
final_permutations.append(p.perm)
else:
final_permutations.append(int64_array(np.arange(node.shape.size)))

if len(final_permutations) == 0:
continue

# Check that all permutations are equal
if not all([np.array_equal(final_permutations[0], perm) for perm in final_permutations]):
raise Error('Permutations requested for {} data node are not equal! List of permutations: {}'
''.format(node.name, [p.perm for p in permutations]))

assert not node.has_valid('permutation') or np.array_equal(node.permutation, permutations[0])
node['permutation'] = permutations[0]
105 changes: 105 additions & 0 deletions model-optimizer/extensions/middle/PreserveRuntimeInfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np

from extensions.middle.MergeNodesPermutations import MergeNodesPermutations
from extensions.ops.transpose import Transpose
from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.utils.runtime_info import OldAPIMap


class PreserveRuntimeInfo(MiddleReplacementPattern):
popovaan marked this conversation as resolved.
Show resolved Hide resolved
""" This transformation preserves original layout for Parameter and Result nodes
and adds old_api_map attribute in rt_info which stores the following information:

Parameter:
Order of the transpose which should be applied to Parameter with old API layout to
obtain Parameter with new API layout.

Result:
Order of the transpose which should be applied to Result with new API layout to
obtain Result with old API layout.

This transformation shouldn't be applied for Parameter or Result nodes inside
body graphs of any operations like If, TensorIterator, Loop etc. For this reason
transformation should be executed non-recursively.
"""
enabled = True
run_not_recursively = True
popovaan marked this conversation as resolved.
Show resolved Hide resolved

def run_after(self):
return [MergeNodesPermutations]

def run_before(self):
return []

def find_and_replace_pattern(self, graph: Graph):
self.preserve_rt_info(graph)

@staticmethod
def preserve_rt_info(graph: Graph):
for op in graph.get_op_nodes():
op_name = op.soft_get('name', op.id)
op_type = op.soft_get('type')
if op_type == 'Parameter' and op.has_valid('permute_attrs') and not op.has_and_set('nchw_layout'):
if not op.out_node(0).has_valid('permutation'):
continue
permutation = op.out_node(0).permutation
if np.array_equal(permutation.inv, range(len(permutation.inv))):
continue

# rt info update
assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op_name)

old_api_map = OldAPIMap(version=0)
if ('old_api_map', old_api_map.get_version()) not in op.rt_info.info:
op.rt_info.info[('old_api_map', old_api_map.get_version())] = old_api_map
op.rt_info.info[('old_api_map', old_api_map.get_version())].old_api_transpose_parameter(permutation.inv)

# keep input in the framework format
transpose = create_op_node_with_second_input(
graph, Transpose, permutation.perm, {'name': op_name + '/Transpose({})'.format(permutation.perm)})

# source mode is used to keep tensor names at Parameter node
op.out_port(0).get_connection().insert_node(transpose, "source")

transpose.infer(transpose)

if op.has_valid('permute_attrs'):
del op['permute_attrs']
if op.out_node(0).has_valid('permutation'):
del op.out_node(0)['permutation']

elif op_type == 'Result' and op.in_ports():
prev_node_out_port = op.in_port(0).get_connection().get_source()
if prev_node_out_port is None:
continue
in_node = prev_node_out_port.node
in_data_node = in_node.out_node(prev_node_out_port.idx)
if in_data_node.has_and_set('permutation'):
permutation = in_data_node['permutation']
if np.array_equal(permutation.perm, range(len(permutation.perm))):
continue

# rt info update
assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op)
old_api_map = OldAPIMap(version=0)
if ('old_api_map', old_api_map.get_version()) not in op.rt_info.info:
op.rt_info.info[('old_api_map', old_api_map.get_version())] = old_api_map
op.rt_info.info[('old_api_map', old_api_map.get_version())].old_api_transpose_result(permutation.perm)

if in_data_node.has_valid('permutation'):
del in_data_node['permutation']

# keep result in the framework format
transpose = create_op_node_with_second_input(graph, Transpose, permutation.inv)
# preserve output node name as it is used as output name in legacy IE API
transpose.name = in_node.name
in_node.name += "/prev"

prev_node_out_port.get_connection().insert_node(transpose)
in_node.infer(in_node)
transpose.infer(transpose)
Loading