-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Unification of layout and data types of Parameter and Result in MO (#…
…7630) * Added transposes insertion for Parameter and Result. * Separated into several transformations. * Corrected runtime_info format. * Fixed runtime info serialization. * Code refactoring. * Corrected checks. * Added debug output. * Added check. * Fixed unit tests. * Changed old api map format, removed debug output. * Moved serialize to rt_info property, code corrections. * Refactored RTInfo class. * Small corrections. * Small corrections. * Removed redurant import. * Added tests, added undefined default type. * Code reformat. * Fixed serialization unit tests. * Added comment. * Added comment. * Small test correction. * Changed default values of old_api_map to values from old API IR. * np.array -> int64_array * Update MO to use FE to read IR; Swith MO IR version to 11 * Preserve output node name when inserting Transpose * Codestyle * Fix layer tests * Pylint fix * Disable ref_graphs comparision in layer tests * codestyle * Updated MO IR reader. * Moved version initialization to constructor of OldApiMap. * Added shape infer after transpose insertion. * Fixed Pylint * Removed wrong attribute removal. * Added transposes insertion for Parameter and Result. * Separated into several transformations. * Corrected runtime_info format. * Fixed runtime info serialization. * Code refactoring. * Corrected checks. * Added debug output. * Added check. * Fixed unit tests. * Changed old api map format, removed debug output. * Moved serialize to rt_info property, code corrections. * Refactored RTInfo class. * Small corrections. * Small corrections. * Removed redurant import. * Added tests, added undefined default type. * Code reformat. * Fixed serialization unit tests. * Added comment. * Added comment. * Small test correction. * Changed default values of old_api_map to values from old API IR. * np.array -> int64_array * Update MO to use FE to read IR; Swith MO IR version to 11 * Preserve output node name when inserting Transpose * Codestyle * Fix layer tests * Pylint fix * Disable ref_graphs comparision in layer tests * codestyle * Updated MO IR reader. * Moved version initialization to constructor of OldApiMap. * Added shape infer after transpose insertion. * Fixed Pylint * Removed wrong attribute removal. * Serialize fix. Co-authored-by: Gleb Kazantaev <[email protected]>
- Loading branch information
Showing
23 changed files
with
568 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
65 changes: 65 additions & 0 deletions
65
model-optimizer/extensions/middle/MergeNodesPermutations.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright (C) 2018-2021 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import numpy as np | ||
|
||
from extensions.middle.ApplyNHWCtoNCHWpermutation import ApplyNHWCtoNCHWpermutation | ||
from mo.front.common.partial_infer.utils import int64_array | ||
from mo.graph.graph import Graph, Node | ||
from mo.middle.replacement import MiddleReplacementPattern | ||
from mo.utils.error import Error | ||
|
||
|
||
class MergeNodesPermutations(MiddleReplacementPattern): | ||
enabled = True | ||
|
||
def run_after(self): | ||
return [ApplyNHWCtoNCHWpermutation] | ||
|
||
def run_before(self): | ||
return [] | ||
|
||
def find_and_replace_pattern(self, graph: Graph): | ||
self.merge_nodes_permutations(graph) | ||
|
||
@staticmethod | ||
def merge_nodes_permutations(graph: Graph): | ||
# Iterate over all data nodes and check all permutations for similarity | ||
# In case of equal permutations, this permutation will be set as attribute for data node | ||
# otherwise exception will be raised | ||
for node in graph.nodes(): | ||
node = Node(graph, node) | ||
if node.kind != 'data': | ||
continue | ||
|
||
permutations = [] | ||
|
||
# Get all permutations from in edges | ||
for in_node in node.in_nodes(): | ||
edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0] | ||
if 'permutation' in edge_attrs: | ||
permutations.append(edge_attrs['permutation']) | ||
|
||
# Get all permutations from out edges | ||
for out_node in node.out_nodes(): | ||
edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0] | ||
if 'permutation' in edge_attrs: | ||
permutations.append(edge_attrs['permutation']) | ||
|
||
final_permutations = [] | ||
for p in permutations: | ||
if p is not None: | ||
final_permutations.append(p.perm) | ||
else: | ||
final_permutations.append(int64_array(np.arange(node.shape.size))) | ||
|
||
if len(final_permutations) == 0: | ||
continue | ||
|
||
# Check that all permutations are equal | ||
if not all([np.array_equal(final_permutations[0], perm) for perm in final_permutations]): | ||
raise Error('Permutations requested for {} data node are not equal! List of permutations: {}' | ||
''.format(node.name, [p.perm for p in permutations])) | ||
|
||
assert not node.has_valid('permutation') or np.array_equal(node.permutation, permutations[0]) | ||
node['permutation'] = permutations[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright (C) 2018-2021 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import numpy as np | ||
|
||
from extensions.middle.MergeNodesPermutations import MergeNodesPermutations | ||
from extensions.ops.transpose import Transpose | ||
from mo.front.tf.graph_utils import create_op_node_with_second_input | ||
from mo.graph.graph import Graph | ||
from mo.middle.replacement import MiddleReplacementPattern | ||
from mo.utils.runtime_info import OldAPIMap | ||
|
||
|
||
class PreserveRuntimeInfo(MiddleReplacementPattern): | ||
""" This transformation preserves original layout for Parameter and Result nodes | ||
and adds old_api_map attribute in rt_info which stores the following information: | ||
Parameter: | ||
Order of the transpose which should be applied to Parameter with old API layout to | ||
obtain Parameter with new API layout. | ||
Result: | ||
Order of the transpose which should be applied to Result with new API layout to | ||
obtain Result with old API layout. | ||
This transformation shouldn't be applied for Parameter or Result nodes inside | ||
body graphs of any operations like If, TensorIterator, Loop etc. For this reason | ||
transformation should be executed non-recursively. | ||
""" | ||
enabled = True | ||
run_not_recursively = True | ||
|
||
def run_after(self): | ||
return [MergeNodesPermutations] | ||
|
||
def run_before(self): | ||
return [] | ||
|
||
def find_and_replace_pattern(self, graph: Graph): | ||
self.preserve_rt_info(graph) | ||
|
||
@staticmethod | ||
def preserve_rt_info(graph: Graph): | ||
for op in graph.get_op_nodes(): | ||
op_name = op.soft_get('name', op.id) | ||
op_type = op.soft_get('type') | ||
if op_type == 'Parameter' and op.has_valid('permute_attrs') and not op.has_and_set('nchw_layout'): | ||
if not op.out_node(0).has_valid('permutation'): | ||
continue | ||
permutation = op.out_node(0).permutation | ||
if np.array_equal(permutation.inv, range(len(permutation.inv))): | ||
continue | ||
|
||
# rt info update | ||
assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op_name) | ||
|
||
old_api_map = OldAPIMap(version=0) | ||
if ('old_api_map', old_api_map.get_version()) not in op.rt_info.info: | ||
op.rt_info.info[('old_api_map', old_api_map.get_version())] = old_api_map | ||
op.rt_info.info[('old_api_map', old_api_map.get_version())].old_api_transpose_parameter(permutation.inv) | ||
|
||
# keep input in the framework format | ||
transpose = create_op_node_with_second_input( | ||
graph, Transpose, permutation.perm, {'name': op_name + '/Transpose({})'.format(permutation.perm)}) | ||
|
||
# source mode is used to keep tensor names at Parameter node | ||
op.out_port(0).get_connection().insert_node(transpose, "source") | ||
|
||
if op.has_valid('permute_attrs'): | ||
del op['permute_attrs'] | ||
if op.out_node(0).has_valid('permutation'): | ||
del op.out_node(0)['permutation'] | ||
|
||
elif op_type == 'Result' and op.in_ports(): | ||
prev_node_out_port = op.in_port(0).get_connection().get_source() | ||
if prev_node_out_port is None: | ||
continue | ||
in_node = prev_node_out_port.node | ||
in_data_node = in_node.out_node(prev_node_out_port.idx) | ||
if in_data_node.has_and_set('permutation'): | ||
permutation = in_data_node['permutation'] | ||
if np.array_equal(permutation.perm, range(len(permutation.perm))): | ||
continue | ||
|
||
# rt info update | ||
assert op.has('rt_info'), 'Unable to preserve runtime information for node with name={}'.format(op) | ||
old_api_map = OldAPIMap(version=0) | ||
if ('old_api_map', old_api_map.get_version()) not in op.rt_info.info: | ||
op.rt_info.info[('old_api_map', old_api_map.get_version())] = old_api_map | ||
op.rt_info.info[('old_api_map', old_api_map.get_version())].old_api_transpose_result(permutation.perm) | ||
|
||
# keep result in the framework format | ||
transpose = create_op_node_with_second_input(graph, Transpose, permutation.inv) | ||
# preserve output node name as it is used as output name in legacy IE API | ||
transpose.name = in_node.name | ||
in_node.name += "/prev" | ||
|
||
prev_node_out_port.get_connection().insert_node(transpose) |
Oops, something went wrong.