From e03e90cfbea3534a950e31ad1255391f7ecdfcb6 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Fri, 5 Feb 2021 12:48:55 +0300 Subject: [PATCH] Support TF2 Keras ConvLSTM2D operation Signed-off-by: Roman Kazantsev --- .../middle/ApplyNHWCtoNCHWpermutation.py | 2 +- .../extensions/middle/ApplyPermutations.py | 2 +- .../extensions/middle/MoveConstToLoopBody.py | 47 +++++++++++++++++++ model-optimizer/extensions/middle/fusings.py | 2 +- model-optimizer/extensions/ops/loop.py | 5 +- model-optimizer/mo/ops/op.py | 2 +- 6 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 model-optimizer/extensions/middle/MoveConstToLoopBody.py diff --git a/model-optimizer/extensions/middle/ApplyNHWCtoNCHWpermutation.py b/model-optimizer/extensions/middle/ApplyNHWCtoNCHWpermutation.py index 5d3561ed739904..682331eabc8fe7 100644 --- a/model-optimizer/extensions/middle/ApplyNHWCtoNCHWpermutation.py +++ b/model-optimizer/extensions/middle/ApplyNHWCtoNCHWpermutation.py @@ -44,7 +44,7 @@ def find_and_replace_pattern(self, graph: Graph): if 'permutation' in edge_attrs: skip_permutation = True for out_node in node.out_nodes(): - edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0] + edge_attrs = list(node.graph.get_edge_data(node.id, out_node.id).values())[0] if 'permutation' in edge_attrs: skip_permutation = True diff --git a/model-optimizer/extensions/middle/ApplyPermutations.py b/model-optimizer/extensions/middle/ApplyPermutations.py index d35d37cc86b2fd..6a4dea51a957f1 100644 --- a/model-optimizer/extensions/middle/ApplyPermutations.py +++ b/model-optimizer/extensions/middle/ApplyPermutations.py @@ -70,7 +70,7 @@ def merge_nodes_permutations(graph: Graph): # Get all permutations from out edges for out_node in node.out_nodes(): - edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0] + edge_attrs = list(node.graph.get_edge_data(node.id, out_node.id).values())[0] if 'permutation' in edge_attrs: permutations.append(edge_attrs['permutation']) diff --git a/model-optimizer/extensions/middle/MoveConstToLoopBody.py b/model-optimizer/extensions/middle/MoveConstToLoopBody.py new file mode 100644 index 00000000000000..ec89743dad5365 --- /dev/null +++ b/model-optimizer/extensions/middle/MoveConstToLoopBody.py @@ -0,0 +1,47 @@ +""" + Copyright (C) 2017-2021 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import numpy as np + +from extensions.ops.loop import Loop +from mo.middle.replacement import MiddleReplacementPattern +from mo.graph.graph import Graph, Node +from mo.ops.const import Const + + +class MoveConstToLoopBody(MiddleReplacementPattern): + """ + """ + enabled = True + + def run_after(self): + from extensions.middle.PartialInfer import PartialInfer + return [PartialInfer] + + def run_before(self): + from extensions.middle.ApplyPermutations import ApplyPermutation + return [ApplyPermutation] + + def find_and_replace_pattern(self, graph: Graph): + cleanup_called_once = False + # walk through all Loop nodes and find Const inputs + for loop_node in graph.get_op_nodes(op='Loop'): + if not cleanup_called_once: + graph.clean_up() + cleanup_called_once = True + Loop.pull_constant_inputs_into_body(loop_node) + Loop.normalize_input_output_ports(loop_node) + Loop.infer(loop_node) diff --git a/model-optimizer/extensions/middle/fusings.py b/model-optimizer/extensions/middle/fusings.py index 55abb15499a880..7953f51b5217f8 100644 --- a/model-optimizer/extensions/middle/fusings.py +++ b/model-optimizer/extensions/middle/fusings.py @@ -96,7 +96,7 @@ def find_and_replace_pattern(self, graph: Graph): for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops) for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up()) - normalize_eltwise_inputs(graph) + for_graph_and_each_sub_graph_recursively(graph, normalize_eltwise_inputs) for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up()) MarkNodesToFuseUpToFakeQuantize().find_and_replace_pattern(graph) diff --git a/model-optimizer/extensions/ops/loop.py b/model-optimizer/extensions/ops/loop.py index 380282841555dc..53af05d219b6bf 100644 --- a/model-optimizer/extensions/ops/loop.py +++ b/model-optimizer/extensions/ops/loop.py @@ -315,8 +315,7 @@ def add_back_edge(loop_node: Node, internal_parameter: Node, internal_result: No @staticmethod def pull_constant_inputs_into_body(loop_node: Node): for port_idx, in_port in reversed(loop_node.in_ports().items()): - # TODO add a check that the input does not correspond to execution_condition - if not in_port.disconnected() and in_port.get_source().node.soft_get('type') == 'Const': + if port_idx > 1 and not in_port.disconnected() and in_port.get_source().node.soft_get('type') == 'Const': original_const_node = in_port.get_source().node new_const_node = Const(loop_node.body, original_const_node.attrs()).create_node() @@ -463,7 +462,7 @@ def remove_unused_ops_from_port_map(loop_node: Node, port_map: dict, port_map_at port_to_remove = port_map[record_id_to_remove]['external_port_id'] if port_to_remove != -1: if dir == 'in': - if port_to_remove not in [0, 1]: # input port 0 and 1 are mandatory for the Loop node + if port_to_remove not in [0, 1] and port_to_remove in loop_node.in_ports().keys(): # input port 0 and 1 are mandatory for the Loop node loop_node.delete_input_port(port_to_remove) elif dir == 'out' and port_to_remove in loop_node.out_ports(): loop_node.delete_output_port(port_to_remove) diff --git a/model-optimizer/mo/ops/op.py b/model-optimizer/mo/ops/op.py index cf4c44aa4204fe..9314631fa652db 100644 --- a/model-optimizer/mo/ops/op.py +++ b/model-optimizer/mo/ops/op.py @@ -419,7 +419,7 @@ def create_permute_attrs(node, attrs=None): @staticmethod def set_permutation(node1, node2, permutation, override=False): # This function creates permutation on edge between node1->node2 - edge_attrs = node1.graph.get_edge_data(node1.id, node2.id)[0] + edge_attrs = list(node1.graph.get_edge_data(node1.id, node2.id).values())[0] if 'permutation' not in edge_attrs or override: nx.set_edge_attributes(G=node1.graph, values={(node1.id, node2.id, 0): permutation}, name='permutation') else: